Skip to content

Commit

Permalink
[Experimental] Move identity & origin event middleware config
Browse files Browse the repository at this point in the history
Restructure the config slightly for better management of http
specific configs allowing them to be configured per http service.
  • Loading branch information
allada committed Dec 13, 2024
1 parent f280e71 commit 5f751dc
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 106 deletions.
17 changes: 5 additions & 12 deletions nativelink-config/src/cas_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,6 @@ pub struct BepConfig {
/// The store name referenced in the `stores` map in the main config.
#[serde(deserialize_with = "convert_string_with_shellexpand")]
pub store: StoreRefName,

/// The config related to identifying the client.
/// The value of this header will be used to identify the caller and
/// will be added to the `BuildEvent::identity` field of the message.
/// Default: {see `IdentityHeaderSpec`}
#[serde(default)]
pub experimental_identity_header: IdentityHeaderSpec,
}

#[derive(Deserialize, Clone, Debug, Default)]
Expand Down Expand Up @@ -250,11 +243,6 @@ pub struct OriginEventsSpec {
/// Default: 65536 (zero defaults to this)
#[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
pub max_event_queue_size: usize,

/// The config related to identifying the client.
/// Default: {see `IdentityHeaderSpec`}
#[serde(default)]
pub identity_header: IdentityHeaderSpec,
}

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -452,6 +440,11 @@ pub struct ServerConfig {

/// Services to attach to server.
pub services: Option<ServicesConfig>,

/// The config related to identifying the client.
/// Default: {see `IdentityHeaderSpec`}
#[serde(default)]
pub experimental_identity_header: IdentityHeaderSpec,
}

#[allow(non_camel_case_types)]
Expand Down
51 changes: 10 additions & 41 deletions nativelink-service/src/bep_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::pin::Pin;
use bytes::BytesMut;
use futures::stream::unfold;
use futures::Stream;
use nativelink_config::cas_server::IdentityHeaderSpec;
use nativelink_error::{Error, ResultExt};
use nativelink_proto::com::github::trace_machina::nativelink::events::{bep_event, BepEvent};
use nativelink_proto::google::devtools::build::v1::publish_build_event_server::{
Expand All @@ -29,25 +28,25 @@ use nativelink_proto::google::devtools::build::v1::{
PublishLifecycleEventRequest,
};
use nativelink_store::store_manager::StoreManager;
use nativelink_util::origin_context::{ActiveOriginContext, ORIGIN_IDENTITY};
use nativelink_util::store_trait::{Store, StoreDriver, StoreKey, StoreLike};
use prost::Message;
use tonic::metadata::MetadataMap;
use tonic::{Request, Response, Result, Status, Streaming};
use tracing::{instrument, Level};

/// Current version of the BEP event. This might be used in the future if
/// there is a breaking change in the BEP event format.
const BEP_EVENT_VERSION: u32 = 0;

/// Default identity header name.
/// Note: If this is changed, the default value in the [`IdentityHeaderSpec`]
// TODO(allada) This has a mirror in origin_event_middleware.rs.
// We should consolidate these.
const DEFAULT_IDENTITY_HEADER: &str = "x-identity";
fn get_identity() -> Result<Option<String>, Status> {
ActiveOriginContext::get()
.map_or(Ok(None), |ctx| ctx.get_value(&ORIGIN_IDENTITY))
.err_tip(|| "In BepServer")
.map_or_else(|e| Err(e.into()), |v| Ok(v.map(|v| v.as_ref().clone())))
}

pub struct BepServer {
store: Store,
identity_header: IdentityHeaderSpec,
}

impl BepServer {
Expand All @@ -59,41 +58,13 @@ impl BepServer {
.get_store(&config.store)
.err_tip(|| format!("Expected store {} to exist in store manager", &config.store))?;

let mut identity_header = config.experimental_identity_header.clone();
if identity_header.header_name.is_none() {
identity_header.header_name = Some(DEFAULT_IDENTITY_HEADER.to_string());
}

Ok(Self {
store,
identity_header,
})
Ok(Self { store })
}

pub fn into_service(self) -> PublishBuildEventServer<BepServer> {
PublishBuildEventServer::new(self)
}

fn get_identity(&self, request_metadata: &MetadataMap) -> Result<Option<String>, Status> {
let header_name = self
.identity_header
.header_name
.as_deref()
.unwrap_or(DEFAULT_IDENTITY_HEADER);
if header_name.is_empty() {
return Ok(None);
}
let identity = request_metadata
.get(header_name)
.and_then(|header| header.to_str().ok().map(str::to_string));
if identity.is_none() && self.identity_header.required {
return Err(Status::unauthenticated(format!(
"'{header_name}' header is required"
)));
}
Ok(identity)
}

async fn inner_publish_lifecycle_event(
&self,
request: PublishLifecycleEventRequest,
Expand Down Expand Up @@ -243,8 +214,7 @@ impl PublishBuildEvent for BepServer {
&self,
grpc_request: Request<PublishLifecycleEventRequest>,
) -> Result<Response<()>, Status> {
let identity = self.get_identity(grpc_request.metadata())?;
self.inner_publish_lifecycle_event(grpc_request.into_inner(), identity)
self.inner_publish_lifecycle_event(grpc_request.into_inner(), get_identity()?)
.await
.map_err(Error::into)
}
Expand All @@ -260,8 +230,7 @@ impl PublishBuildEvent for BepServer {
&self,
grpc_request: Request<Streaming<PublishBuildToolEventStreamRequest>>,
) -> Result<Response<Self::PublishBuildToolEventStreamStream>, Status> {
let identity = self.get_identity(grpc_request.metadata())?;
self.inner_publish_build_tool_event_stream(grpc_request.into_inner(), identity)
self.inner_publish_build_tool_event_stream(grpc_request.into_inner(), get_identity()?)
.await
.map_err(Error::into)
}
Expand Down
3 changes: 1 addition & 2 deletions nativelink-service/tests/bep_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;

use futures::StreamExt;
use hyper::body::Frame;
use nativelink_config::cas_server::{BepConfig, IdentityHeaderSpec};
use nativelink_config::cas_server::BepConfig;
use nativelink_config::stores::{MemorySpec, StoreSpec};
use nativelink_error::{Error, ResultExt};
use nativelink_macro::nativelink_test;
Expand Down Expand Up @@ -69,7 +69,6 @@ fn make_bep_server(store_manager: &StoreManager) -> Result<BepServer, Error> {
BepServer::new(
&BepConfig {
store: BEP_STORE_NAME.to_string(),
experimental_identity_header: IdentityHeaderSpec::default(),
},
store_manager,
)
Expand Down
4 changes: 4 additions & 0 deletions nativelink-util/src/origin_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ macro_rules! make_symbol {
};
}

// Symbol that represents the identity of the origin of a request.
// See: IdentityHeaderSpec for details.
make_symbol!(ORIGIN_IDENTITY, String);

pub struct NLSymbol<T: Send + Sync + 'static> {
pub name: &'static str,
pub _phantom: std::marker::PhantomData<T>,
Expand Down
41 changes: 21 additions & 20 deletions nativelink-util/src/origin_event_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use tower::layer::Layer;
use tower::Service;
use tracing::trace_span;

use crate::origin_context::OriginContext;
use crate::origin_context::{ActiveOriginContext, ORIGIN_IDENTITY};
use crate::origin_event::{OriginEventCollector, ORIGIN_EVENT_COLLECTOR};

/// Default identity header name.
Expand All @@ -45,17 +45,17 @@ pub struct OriginRequestMetadata {

#[derive(Clone)]
pub struct OriginEventMiddlewareLayer {
origin_event_tx: mpsc::Sender<OriginEvent>,
maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
idenity_header_config: Arc<IdentityHeaderSpec>,
}

impl OriginEventMiddlewareLayer {
pub fn new(
origin_event_tx: mpsc::Sender<OriginEvent>,
maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
idenity_header_config: IdentityHeaderSpec,
) -> Self {
Self {
origin_event_tx,
maybe_origin_event_tx,
idenity_header_config: Arc::new(idenity_header_config),
}
}
Expand All @@ -67,7 +67,7 @@ impl<S> Layer<S> for OriginEventMiddlewareLayer {
fn layer(&self, service: S) -> Self::Service {
OriginEventMiddleware {
inner: service,
origin_event_tx: self.origin_event_tx.clone(),
maybe_origin_event_tx: self.maybe_origin_event_tx.clone(),
idenity_header_config: self.idenity_header_config.clone(),
}
}
Expand All @@ -76,7 +76,7 @@ impl<S> Layer<S> for OriginEventMiddlewareLayer {
#[derive(Clone)]
pub struct OriginEventMiddleware<S> {
inner: S,
origin_event_tx: mpsc::Sender<OriginEvent>,
maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
idenity_header_config: Arc<IdentityHeaderSpec>,
}

Expand All @@ -101,7 +101,8 @@ where
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);

let (identity, bazel_metadata) = {
let mut context = ActiveOriginContext::fork().unwrap_or_default();
let identity = {
let identity_header = self
.idenity_header_config
.header_name
Expand All @@ -124,24 +125,24 @@ where
.unwrap())
});
}

context.set_value(&ORIGIN_IDENTITY, Arc::new(identity.clone()));
identity
};
if let Some(origin_event_tx) = &self.maybe_origin_event_tx {
let bazel_metadata = req
.headers()
.get("build.bazel.remote.execution.v2.requestmetadata-bin")
.and_then(|header| BASE64_STANDARD_NO_PAD.decode(header.as_bytes()).ok())
.and_then(|data| RequestMetadata::decode(data.as_slice()).ok());
(identity, bazel_metadata)
};

let mut context = OriginContext::new();
context.set_value(
&ORIGIN_EVENT_COLLECTOR,
Arc::new(OriginEventCollector::new(
self.origin_event_tx.clone(),
identity,
bazel_metadata,
)),
);
context.set_value(
&ORIGIN_EVENT_COLLECTOR,
Arc::new(OriginEventCollector::new(
origin_event_tx.clone(),
identity,
bazel_metadata,
)),
);
}

Box::pin(async move {
Arc::new(context)
Expand Down
61 changes: 30 additions & 31 deletions src/bin/nativelink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,30 @@ async fn inner_main(
schedulers: action_schedulers.clone(),
}));

let maybe_origin_event_tx = cfg
.experimental_origin_events
.as_ref()
.map(|origin_events_cfg| {
let mut max_queued_events = origin_events_cfg.max_event_queue_size;
if max_queued_events == 0 {
max_queued_events = DEFAULT_MAX_QUEUE_EVENTS;
}
let (tx, rx) = mpsc::channel(max_queued_events);
let store_name = origin_events_cfg.publisher.store.as_str();
let store = store_manager.get_store(store_name).err_tip(|| {
format!("Could not get store {store_name} for origin event publisher")
})?;

root_futures.push(Box::pin(
OriginEventPublisher::new(store, rx, shutdown_tx.clone())
.run()
.map(Ok),
));

Ok::<_, Error>(tx)
})
.transpose()?;

for (server_cfg, connected_clients_mux) in servers_and_clients {
let services = server_cfg
.services
Expand Down Expand Up @@ -457,37 +481,12 @@ async fn inner_main(

let health_registry = health_registry_builder.lock().await.build();

let mut svc = Router::new();
let maybe_middleware_layer = cfg
.experimental_origin_events
.as_ref()
.map(|origin_events_cfg| {
let mut max_queued_events = origin_events_cfg.max_event_queue_size;
if max_queued_events == 0 {
max_queued_events = DEFAULT_MAX_QUEUE_EVENTS;
}
let (tx, rx) = mpsc::channel(max_queued_events);
let store_name = origin_events_cfg.publisher.store.as_str();
let store = store_manager.get_store(store_name).err_tip(|| {
format!("Could not get store {store_name} for origin event publisher")
})?;

root_futures.push(Box::pin(
OriginEventPublisher::new(store, rx, shutdown_tx.clone())
.run()
.map(Ok),
));

Ok::<_, Error>((tx, origin_events_cfg))
})
.transpose()?
.map(|(tx, cfg)| OriginEventMiddlewareLayer::new(tx, cfg.identity_header.clone()));
let tonic_axum_rounter = tonic_services.into_service().into_axum_router();
svc = if let Some(middleware) = maybe_middleware_layer {
svc.merge(tonic_axum_rounter.layer(middleware))
} else {
svc.merge(tonic_axum_rounter)
};
let mut svc = Router::new().merge(tonic_services.into_service().into_axum_router().layer(
OriginEventMiddlewareLayer::new(
maybe_origin_event_tx.clone(),
server_cfg.experimental_identity_header.clone(),
),
));

if let Some(health_cfg) = services.health {
let path = if health_cfg.path.is_empty() {
Expand Down

0 comments on commit 5f751dc

Please sign in to comment.