From 5657e940b96c1a1d7aa0420892ba094a32b28999 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Mon, 14 Oct 2024 20:43:15 +0900
Subject: [PATCH] feature: middleware for setting otel context

---
 router/src/http/server.rs | 96 +++++++++++++++++++++++++++++----------
 router/src/logging.rs     | 55 ++++++++++++++++++++++
 2 files changed, 128 insertions(+), 23 deletions(-)

diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index e2623c38..0ddf2c27 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -10,7 +10,7 @@ use crate::http::types::{
     VertexResponse,
 };
 use crate::{
-    shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
+    logging, shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
     ResponseMetadata,
 };
 use ::http::HeaderMap;
@@ -38,6 +38,7 @@ use text_embeddings_core::TextEmbeddingsError;
 use tokio::sync::OwnedSemaphorePermit;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::instrument;
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
@@ -100,9 +101,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn predict(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<PredictRequest>,
 ) -> Result<(HeaderMap, Json<PredictResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();
 
     // Closure for predict
@@ -296,9 +302,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn rerank(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<RerankRequest>,
 ) -> Result<(HeaderMap, Json<RerankResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();
 
     if req.texts.is_empty() {
@@ -482,6 +493,7 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn similarity(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    context: Extension<Option<opentelemetry::Context>>,
     Json(req): Json<SimilarityRequest>,
 ) -> Result<(HeaderMap, Json<SimilarityResponse>), (StatusCode, Json<ErrorResponse>)> {
     if req.inputs.sentences.is_empty() {
@@ -528,7 +540,7 @@ async fn similarity(
     };
 
     // Get embeddings
-    let (header_map, embed_response) = embed(infer, info, Json(embed_req)).await?;
+    let (header_map, embed_response) = embed(infer, info, context, Json(embed_req)).await?;
     let embeddings = embed_response.0 .0;
 
     // Compute cosine
@@ -564,9 +576,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn embed(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<EmbedRequest>,
 ) -> Result<(HeaderMap, Json<EmbedResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();
 
     let truncate = req.truncate.unwrap_or(info.auto_truncate);
@@ -733,9 +750,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn embed_sparse(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<EmbedSparseRequest>,
 ) -> Result<(HeaderMap, Json<EmbedSparseResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();
 
     let sparsify = |values: Vec<f32>| {
@@ -911,9 +933,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn embed_all(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<EmbedAllRequest>,
 ) -> Result<(HeaderMap, Json<EmbedAllResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();
 
     let truncate = req.truncate.unwrap_or(info.auto_truncate);
@@ -1078,6 +1105,7 @@ example = json ! ({"message": "Batch size error", "type": "validation"})),
 async fn openai_embed(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<OpenAICompatRequest>,
 ) -> Result<(HeaderMap, Json<OpenAICompatResponse>), (StatusCode, Json<OpenAICompatErrorResponse>)>
 {
@@ -1097,6 +1125,10 @@ async fn openai_embed(
     };
 
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();
 
     let truncate = info.auto_truncate;
@@ -1462,36 +1494,47 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn vertex_compatibility(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    context: Extension<Option<opentelemetry::Context>>,
     Json(req): Json<VertexRequest>,
 ) -> Result<Json<VertexResponse>, (StatusCode, Json<ErrorResponse>)> {
-    let embed_future = move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedRequest| async move {
-        let result = embed(infer, info, Json(req)).await?;
+    let embed_future = move |infer: Extension<Infer>,
+                             info: Extension<Info>,
+                             context: Extension<Option<opentelemetry::Context>>,
+                             req: EmbedRequest| async move {
+        let result = embed(infer, info, context, Json(req)).await?;
         Ok(VertexPrediction::Embed(result.1 .0))
     };
-    let embed_sparse_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedSparseRequest| async move {
-            let result = embed_sparse(infer, info, Json(req)).await?;
-            Ok(VertexPrediction::EmbedSparse(result.1 .0))
-        };
-    let predict_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: PredictRequest| async move {
-            let result = predict(infer, info, Json(req)).await?;
-            Ok(VertexPrediction::Predict(result.1 .0))
-        };
-    let rerank_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: RerankRequest| async move {
-            let result = rerank(infer, info, Json(req)).await?;
-            Ok(VertexPrediction::Rerank(result.1 .0))
-        };
+    let embed_sparse_future = move |infer: Extension<Infer>,
+                                    info: Extension<Info>,
+                                    context: Extension<Option<opentelemetry::Context>>,
+                                    req: EmbedSparseRequest| async move {
+        let result = embed_sparse(infer, info, context, Json(req)).await?;
+        Ok(VertexPrediction::EmbedSparse(result.1 .0))
+    };
+    let predict_future = move |infer: Extension<Infer>,
+                               info: Extension<Info>,
+                               context: Extension<Option<opentelemetry::Context>>,
+                               req: PredictRequest| async move {
+        let result = predict(infer, info, context, Json(req)).await?;
+        Ok(VertexPrediction::Predict(result.1 .0))
+    };
+    let rerank_future = move |infer: Extension<Infer>,
+                              info: Extension<Info>,
+                              context: Extension<Option<opentelemetry::Context>>,
+                              req: RerankRequest| async move {
+        let result = rerank(infer, info, context, Json(req)).await?;
+        Ok(VertexPrediction::Rerank(result.1 .0))
+    };
 
     let mut futures = Vec::with_capacity(req.instances.len());
     for instance in req.instances {
         let local_infer = infer.clone();
         let local_info = info.clone();
+        let local_context = context.clone();
 
         // Rerank is the only payload that can me matched safely
         if let Ok(instance) = serde_json::from_value::<RerankRequest>(instance.clone()) {
-            futures.push(rerank_future(local_infer, local_info, instance).boxed());
+            futures.push(rerank_future(local_infer, local_info, local_context, instance).boxed());
             continue;
         }
 
@@ -1499,17 +1542,23 @@ async fn vertex_compatibility(
             ModelType::Classifier(_) | ModelType::Reranker(_) => {
                 let instance = serde_json::from_value::<PredictRequest>(instance)
                     .map_err(ErrorResponse::from)?;
-                futures.push(predict_future(local_infer, local_info, instance).boxed());
+                futures
+                    .push(predict_future(local_infer, local_info, local_context, instance).boxed());
             }
             ModelType::Embedding(_) => {
                 if infer.is_splade() {
                     let instance = serde_json::from_value::<EmbedSparseRequest>(instance)
                         .map_err(ErrorResponse::from)?;
-                    futures.push(embed_sparse_future(local_infer, local_info, instance).boxed());
+                    futures.push(
+                        embed_sparse_future(local_infer, local_info, local_context, instance)
+                            .boxed(),
+                    );
                 } else {
                     let instance = serde_json::from_value::<EmbedRequest>(instance)
                         .map_err(ErrorResponse::from)?;
-                    futures.push(embed_future(local_infer, local_info, instance).boxed());
+                    futures.push(
+                        embed_future(local_infer, local_info, local_context, instance).boxed(),
+                    );
                 }
             }
         }
@@ -1777,6 +1826,7 @@ pub async fn run(
         .layer(Extension(info))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
+        .layer(axum::middleware::from_fn(logging::trace_context_middleware))
        .layer(DefaultBodyLimit::max(payload_limit))
         .layer(cors_layer);
 
diff --git a/router/src/logging.rs b/router/src/logging.rs
index 7a8fe810..45b29dce 100644
--- a/router/src/logging.rs
+++ b/router/src/logging.rs
@@ -1,3 +1,6 @@
+use axum::{extract::Request, middleware::Next, response::Response};
+use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
+use opentelemetry::Context;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use opentelemetry_sdk::propagation::TraceContextPropagator;
@@ -7,6 +10,58 @@ use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::{EnvFilter, Layer};
 
+struct TraceParent {
+    #[allow(dead_code)]
+    version: u8,
+    trace_id: TraceId,
+    parent_id: SpanId,
+    trace_flags: TraceFlags,
+}
+
+fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
+    let parts: Vec<&str> = header_value.split('-').collect();
+    if parts.len() != 4 {
+        return None;
+    }
+
+    let version = u8::from_str_radix(parts[0], 16).ok()?;
+    if version == 0xff {
+        return None;
+    }
+
+    let trace_id = TraceId::from_hex(parts[1]).ok()?;
+    let parent_id = SpanId::from_hex(parts[2]).ok()?;
+    let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;
+
+    Some(TraceParent {
+        version,
+        trace_id,
+        parent_id,
+        trace_flags: TraceFlags::new(trace_flags),
+    })
+}
+
+pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
+    let context = request
+        .headers()
+        .get("traceparent")
+        .and_then(|v| v.to_str().ok())
+        .and_then(parse_traceparent)
+        .map(|traceparent| {
+            Context::new().with_remote_span_context(SpanContext::new(
+                traceparent.trace_id,
+                traceparent.parent_id,
+                traceparent.trace_flags,
+                true,
+                Default::default(),
+            ))
+        });
+
+    request.extensions_mut().insert(context);
+
+    next.run(request).await
+}
+
 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
 /// - otlp_endpoint is an optional URL to an Open Telemetry collector
 /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
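Not part of the patch above: a hypothetical unit-test sketch for the parse_traceparent helper added to router/src/logging.rs. It assumes a sibling #[cfg(test)] module in the same file and exercises the W3C traceparent format "{version}-{trace_id}-{parent_id}-{flags}" that trace_context_middleware reads from the request headers, plus a few rejection cases the parser is expected to handle.

#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::trace::{SpanId, TraceFlags, TraceId};

    #[test]
    fn parses_a_valid_traceparent_header() {
        // Sample value from the W3C Trace Context spec: version 00, sampled flag set.
        let header = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        let parsed = parse_traceparent(header).expect("valid header should parse");
        assert_eq!(
            parsed.trace_id,
            TraceId::from_hex("4bf92f3577b34da6a3ce929d0e0e4736").unwrap()
        );
        assert_eq!(parsed.parent_id, SpanId::from_hex("00f067aa0ba902b7").unwrap());
        assert_eq!(parsed.trace_flags, TraceFlags::SAMPLED);
    }

    #[test]
    fn rejects_malformed_traceparent_headers() {
        // Wrong segment count, non-hex trace id, and the reserved version 0xff.
        assert!(parse_traceparent("not-a-traceparent").is_none());
        assert!(parse_traceparent("00-nothex-00f067aa0ba902b7-01").is_none());
        assert!(
            parse_traceparent("ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01").is_none()
        );
    }
}

With the middleware installed, a client that sends such a traceparent header to an instrumented route (for example POST /embed) should see the server-side span reported under that trace id in the configured OTLP backend.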