diff --git a/rust/crates/langsmith-tracing-client/src/client/blocking/processor.rs b/rust/crates/langsmith-tracing-client/src/client/blocking/processor.rs index 2063516bf..bb23cd858 100644 --- a/rust/crates/langsmith-tracing-client/src/client/blocking/processor.rs +++ b/rust/crates/langsmith-tracing-client/src/client/blocking/processor.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::mpsc::{Receiver, Sender}; use std::sync::{mpsc, Arc, Mutex}; use std::time::{Duration, Instant}; @@ -97,6 +98,62 @@ impl RunProcessor { Ok(()) } + // If we have a `QueuedRun::Create` and `QueuedRun::Update` for the same run ID in the batch, + // combine the update data into the create so we can send just one operation instead of two. + fn combine_batch_operations(batch: Vec) -> Vec { + let mut output = Vec::with_capacity(batch.len()); + let mut id_to_index = HashMap::with_capacity(batch.len()); + + for queued_run in batch { + match queued_run { + QueuedRun::Create(ref run_create_extended) => { + // Record the `Create` operation's ID and index, + // in case we need to modify it later. + let RunCreateExtended { run_create, .. } = run_create_extended; + let run_id = run_create.common.id.clone(); + let index = output.len(); + id_to_index.insert(run_id, index); + output.push(queued_run); + } + QueuedRun::Update(run_update_extended) => { + let run_id = run_update_extended.run_update.common.id.as_str(); + if let Some(create_index) = id_to_index.get(run_id) { + // This `run_id` matches a `Create` in this batch. + // Merge the `Update` data into the `Create` and + // drop the separate `Update` operation from the batch. + let RunUpdateExtended { run_update, io, attachments } = run_update_extended; + let QueuedRun::Create(matching_create) = &mut output[*create_index] else { + panic!("index {create_index} did not point to a Create operation in {output:?}"); + }; + debug_assert_eq!( + run_update.common.id, matching_create.run_create.common.id, + "Create operation at index {create_index} did not have expected ID {}: {matching_create:?}", + run_update.common.id, + ); + + matching_create.run_create.common.merge(run_update.common); + matching_create.run_create.end_time = Some(run_update.end_time); + matching_create.io.merge(io); + if let Some(mut _existing_attachments) = + matching_create.attachments.as_mut() + { + unimplemented!("figure out how to merge attachments -- in Python they are a dict but here they are a Vec"); + } else { + matching_create.attachments = attachments; + } + } else { + // No matching `Create` operations for this `Update`, add it as-is. + output.push(QueuedRun::Update(run_update_extended)); + } + } + // Allow other operations to pass through unchanged. + _ => output.push(queued_run), + } + } + + output + } + #[expect(unused_variables)] fn send_batch(&self, batch: Vec) -> Result<(), TracingClientError> { //println!("Handling a batch of {} runs", batch.len()); @@ -104,6 +161,8 @@ impl RunProcessor { let mut json_data = Vec::new(); let mut attachment_parts = Vec::new(); + let batch = Self::combine_batch_operations(batch); + let start_iter = Instant::now(); for queued_run in batch { match queued_run { diff --git a/rust/crates/langsmith-tracing-client/src/client/run.rs b/rust/crates/langsmith-tracing-client/src/client/run.rs index b5308bc71..4750c1028 100644 --- a/rust/crates/langsmith-tracing-client/src/client/run.rs +++ b/rust/crates/langsmith-tracing-client/src/client/run.rs @@ -24,6 +24,18 @@ pub struct RunIO { pub outputs: Option>, } +impl RunIO { + #[inline] + pub(crate) fn merge(&mut self, other: RunIO) { + if other.inputs.is_some() { + self.inputs = other.inputs; + } + if other.outputs.is_some() { + self.outputs = other.outputs; + } + } +} + #[derive(Serialize, Deserialize, PartialEq, Debug)] pub struct RunCommon { pub id: String, @@ -39,6 +51,36 @@ pub struct RunCommon { pub session_name: Option, } +impl RunCommon { + #[inline] + pub(crate) fn merge(&mut self, other: RunCommon) { + if other.parent_run_id.is_some() { + self.parent_run_id = other.parent_run_id; + } + if other.extra.is_some() { + self.extra = other.extra; + } + if other.error.is_some() { + self.error = other.error; + } + if other.serialized.is_some() { + self.serialized = other.serialized; + } + if other.events.is_some() { + self.events = other.events; + } + if other.tags.is_some() { + self.tags = other.tags; + } + if other.session_id.is_some() { + self.session_id = other.session_id; + } + if other.session_name.is_some() { + self.session_name = other.session_name; + } + } +} + #[derive(Serialize, Deserialize, PartialEq, Debug)] pub struct RunCreate { #[serde(flatten)]