Commit

add tests

snowmead committed Sep 21, 2024
1 parent 1784431 commit 79ecb61

Showing 8 changed files with 878 additions and 682 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"rust-analyzer.cargo.features": ["rocksdb"]
}
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions Cargo.toml
@@ -36,5 +36,9 @@ bounded-integer = { version = "0.5.7", features = ["types", "num-traits02"] }
aquamarine = "0.3.2"
tiktoken-rs = "0.5.8"

[dev-dependencies]
futures = "0.3"
uuid = { version = "1.0", features = ["v4"] }

[features]
rocksdb = ["dep:rocksdb"]
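
The new dev-dependencies hint at async tests keyed by freshly generated identifiers. A minimal sketch of how they could be exercised — the test itself is illustrative and not part of this commit:

// Illustrative test scaffold only; it is not taken from the commit.
// `uuid` (v4 feature) supplies unique identifiers per test run, and
// `futures` drives async test bodies without pulling in a full runtime.
use uuid::Uuid;

#[test]
fn each_test_run_gets_a_unique_id() {
    // Unique ids keep state from separate runs from colliding in storage.
    let id_a = Uuid::new_v4();
    let id_b = Uuid::new_v4();
    assert_ne!(id_a, id_b);

    // Async assertions can be driven with the lightweight executor.
    futures::executor::block_on(async {
        assert!(!id_a.to_string().is_empty());
    });
}
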
124 changes: 4 additions & 120 deletions src/lib.rs
@@ -48,7 +48,6 @@
#![feature(once_cell_try)]

use std::{
collections::VecDeque,
fmt::{Debug, Display},
marker::PhantomData,
str::FromStr,
@@ -61,7 +60,7 @@ use num_traits::{
SaturatingSub, ToPrimitive, Unsigned,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing::{error, trace};
use tracing::trace;

pub mod architecture;
pub mod loom;
@@ -101,6 +100,7 @@ pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + S
/// format!("{}:{}", self.id, self.sub_id)
/// }
/// }
/// ```
pub trait TapestryId: Debug + Clone + Send + Sync + 'static {
/// Returns the base key.
///
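
For orientation (not part of the diff), a minimal implementor in the spirit of the doc example above might look like the following. The struct, its fields, and the crate path are assumptions, and the trait is assumed to expose a single `base_key(&self) -> String` method, as the truncated doc comment suggests:

use llm_weaver::TapestryId; // crate path assumed

// Hypothetical identifier type; only the trait bounds come from the crate.
#[derive(Debug, Clone)]
struct ConversationId {
    id: u64,
    sub_id: u64,
}

impl TapestryId for ConversationId {
    fn base_key(&self) -> String {
        // Mirrors the doc example: join the id parts into one storage key.
        format!("{}:{}", self.id, self.sub_id)
    }
}
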
@@ -256,7 +256,7 @@ pub trait Config: Debug + Sized + Clone + Default + Send + Sync + 'static {
}

/// Context message that represents a single message in a [`TapestryFragment`] instance.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ContextMessage<T: Config> {
pub role: WrapperRole,
pub content: String,
@@ -285,7 +285,7 @@ impl<T: Config> ContextMessage<T> {
/// The total number of `context_tokens` is tracked when [`Loom::weave`] is executed and if it
/// exceeds the maximum number of tokens allowed for the current GPT [`Config::PromptModel`], then a
/// summary is generated and a new [`TapestryFragment`] instance is created.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
pub struct TapestryFragment<T: Config> {
/// Total number of _GPT tokens_ in the `context_messages`.
pub context_tokens: <T::PromptModel as Llm<T>>::Tokens,
@@ -346,119 +346,3 @@ impl<T: Config> TapestryFragment<T> {
Ok(())
}
}

/// The machine that drives all of the core methods that should be used across any service
/// that needs to prompt LLM and receive a response.
///
/// This is implemented over the [`Config`] trait.
#[async_trait]
pub trait Loom<T: Config> {
/// Prompt LLM Weaver for a response for [`TapestryId`].
///
/// Prompts LLM with the current [`TapestryFragment`] instance and the new `msgs`.
///
/// A summary will be generated of the current [`TapestryFragment`] instance if the total number
/// of tokens in the `context_messages` exceeds the maximum number of tokens allowed for the
/// current [`Config::PromptModel`] or custom max tokens. This threshold is affected by the
/// [`Config::TOKEN_THRESHOLD_PERCENTILE`].
///
/// # Parameters
///
/// - `prompt_llm_config`: The [`Config::PromptModel`] to use for prompting LLM.
/// - `summary_llm_config`: The [`Config::SummaryModel`] to use for generating summaries.
/// - `tapestry_id`: The [`TapestryId`] to use for storing the [`TapestryFragment`] instance.
/// - `instructions`: The instruction message to be used for the current [`TapestryFragment`]
/// instance.
/// - `msgs`: The messages to prompt the LLM with.
async fn weave<TID: TapestryId>(
&self,
prompt_llm_config: LlmConfig<T, T::PromptModel>,
summary_llm_config: LlmConfig<T, T::SummaryModel>,
tapestry_id: TID,
instructions: String,
mut msgs: Vec<ContextMessage<T>>,
) -> Result<(<<T as Config>::PromptModel as Llm<T>>::Response, u64, bool)>;
/// Generates the summary of the current [`TapestryFragment`] instance.
///
/// Returns the summary message as a string.
async fn generate_summary(
summary_model_config: LlmConfig<T, T::SummaryModel>,
tapestry_fragment: &TapestryFragment<T>,
summary_max_tokens: SummaryModelTokens<T>,
) -> Result<String>;
/// Helper method to build a [`ContextMessage`]
fn build_context_message(
role: WrapperRole,
content: String,
account_id: Option<String>,
) -> ContextMessage<T>;
fn count_tokens_in_messages(
msgs: impl Iterator<Item = &ContextMessage<T>>,
) -> <T::PromptModel as Llm<T>>::Tokens;
}

/// A helper struct to manage the prompt messages in a deque while keeping track of the tokens
/// added or removed.
struct VecPromptMsgsDeque<T: Config, L: Llm<T>> {
tokens: <L as Llm<T>>::Tokens,
inner: VecDeque<<L as Llm<T>>::Request>,
}

impl<T: Config, L: Llm<T>> VecPromptMsgsDeque<T, L> {
fn new() -> Self {
Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::new() }
}

fn with_capacity(capacity: usize) -> Self {
Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::with_capacity(capacity) }
}

fn push_front(&mut self, msg_reqs: L::Request) {
let tokens = L::count_tokens(&msg_reqs.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&tokens);
self.inner.push_front(msg_reqs);
}

fn push_back(&mut self, msg_reqs: L::Request) {
let tokens = L::count_tokens(&msg_reqs.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&tokens);
self.inner.push_back(msg_reqs);
}

fn append(&mut self, msg_reqs: &mut VecDeque<L::Request>) {
msg_reqs.iter().for_each(|msg_req| {
let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&msg_tokens);
});
self.inner.append(msg_reqs);
}

fn truncate(&mut self, len: usize) {
let mut tokens = L::Tokens::from_u8(0).unwrap();
for msg_req in self.inner.iter().take(len) {
let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
tokens = tokens.saturating_add(&msg_tokens);
}
self.inner.truncate(len);
self.tokens = tokens;
}

fn extend(&mut self, msg_reqs: Vec<L::Request>) {
let mut tokens = L::Tokens::from_u8(0).unwrap();
for msg_req in &msg_reqs {
let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
tokens = tokens.saturating_add(&msg_tokens);
}
self.inner.extend(msg_reqs);
match self.tokens.checked_add(&tokens) {
Some(v) => self.tokens = v,
None => {
error!("Token overflow");
},
}
}

fn into_vec(self) -> Vec<L::Request> {
self.inner.into()
}
}
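
The removed `VecPromptMsgsDeque` is relocated rather than dropped, per the `src/loom.rs` import change below. It keeps a running token count next to the queued requests; the same bookkeeping pattern in isolation, with a plain `usize` counter and a whitespace split standing in for `Llm::Tokens` and `Llm::count_tokens`, looks roughly like this:

use std::collections::VecDeque;

// Simplified stand-in for the token-tracking deque above; not the crate's type.
struct CountingDeque {
    tokens: usize,
    inner: VecDeque<String>,
}

impl CountingDeque {
    fn new() -> Self {
        Self { tokens: 0, inner: VecDeque::new() }
    }

    fn push_back(&mut self, msg: String) {
        // Add the message's token estimate before storing it so the running
        // total always reflects the contents of `inner`.
        self.tokens = self.tokens.saturating_add(msg.split_whitespace().count());
        self.inner.push_back(msg);
    }

    fn truncate(&mut self, len: usize) {
        // Recount only the messages that survive, mirroring the original logic.
        self.inner.truncate(len);
        self.tokens = self.inner.iter().map(|m| m.split_whitespace().count()).sum();
    }
}
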
23 changes: 19 additions & 4 deletions src/loom.rs
@@ -5,19 +5,18 @@ use tracing::{debug, error, instrument, trace};

use crate::{
types::{
LoomError, PromptModelRequest, PromptModelTokens, SummaryModelTokens, WeaveError,
WrapperRole, ASSISTANT_ROLE, SYSTEM_ROLE,
LoomError, PromptModelRequest, PromptModelTokens, SummaryModelTokens, VecPromptMsgsDeque,
WeaveError, WrapperRole, ASSISTANT_ROLE, SYSTEM_ROLE,
},
Config, ContextMessage, Llm, LlmConfig, TapestryChestHandler, TapestryFragment, TapestryId,
VecPromptMsgsDeque,
};

/// The machine that drives all of the core methods that should be used across any service
/// that needs to prompt LLM and receive a response.
///
/// This is implemented over the [`Config`] trait.
pub struct Loom<T: Config> {
chest: T::Chest,
pub chest: T::Chest,
_phantom: PhantomData<T>,
}

@@ -28,6 +27,22 @@ impl<T: Config> Loom<T> {
}

/// Prompt LLM Weaver for a response for [`TapestryId`].
///
/// Prompts LLM with the current [`TapestryFragment`] instance and the new `msgs`.
///
/// A summary will be generated of the current [`TapestryFragment`] instance if the total number
/// of tokens in the `context_messages` exceeds the maximum number of tokens allowed for the
/// current [`Config::PromptModel`] or custom max tokens. This threshold is affected by the
/// [`Config::TOKEN_THRESHOLD_PERCENTILE`].
///
/// # Parameters
///
/// - `prompt_llm_config`: The [`Config::PromptModel`] to use for prompting LLM.
/// - `summary_llm_config`: The [`Config::SummaryModel`] to use for generating summaries.
/// - `tapestry_id`: The [`TapestryId`] to use for storing the [`TapestryFragment`] instance.
/// - `instructions`: The instruction message to be used for the current [`TapestryFragment`]
/// instance.
/// - `msgs`: The messages to prompt the LLM with.
#[instrument(skip(self, instructions, msgs))]
pub async fn weave<TID: TapestryId>(
&self,
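
The `weave` documentation above ties summarization to a token threshold scaled by `Config::TOKEN_THRESHOLD_PERCENTILE`. A standalone sketch of that decision, with the exact arithmetic assumed rather than taken from the crate:

// Illustrative threshold check; the crate's actual computation may differ.
fn should_summarize(context_tokens: u64, max_tokens: u64, threshold_percentile: u64) -> bool {
    // e.g. max_tokens = 4096 and threshold_percentile = 85 gives a cutoff of 3481.
    let cutoff = max_tokens.saturating_mul(threshold_percentile) / 100;
    context_tokens > cutoff
}

fn main() {
    assert!(!should_summarize(3_000, 4_096, 85));
    assert!(should_summarize(3_600, 4_096, 85));
}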
