Skip to content

Commit

Permalink
Add Closed stream event to OriginEvents
Browse files Browse the repository at this point in the history
Adds an event for when a stream is closed.
  • Loading branch information
allada committed Jan 27, 2025
1 parent 0915e03 commit b0f7016
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 22 deletions.
6 changes: 6 additions & 0 deletions local-remote-execution/generated-cc/cc/MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
###############################################################################
# Bazel now uses Bzlmod by default to manage external dependencies.
# Please consider migrating your external dependencies from WORKSPACE to MODULE.bazel.
#
# For more details, please check https://github.com/bazelbuild/bazel/issues/18958
###############################################################################
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ message StreamEvent {
uint64 data_length = 3;
WriteRequestOverride write_request = 4;
google.longrunning.Operation operation = 5;
google.protobuf.Empty closed = 6;
}

reserved 6; // NextId.
reserved 7; // NextId.
}

message Event {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ pub mod response_event {
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StreamEvent {
#[prost(oneof = "stream_event::Event", tags = "1, 2, 3, 4, 5")]
#[prost(oneof = "stream_event::Event", tags = "1, 2, 3, 4, 5, 6")]
pub event: ::core::option::Option<stream_event::Event>,
}
/// Nested message and enum types in `StreamEvent`.
Expand All @@ -218,6 +218,8 @@ pub mod stream_event {
Operation(
super::super::super::super::super::super::google::longrunning::Operation,
),
#[prost(message, tag = "6")]
Closed(()),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
Expand Down
134 changes: 114 additions & 20 deletions nativelink-util/src/origin_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use std::pin::Pin;
use std::sync::{Arc, OnceLock};

use futures::future::ready;
use futures::task::{Context, Poll};
use futures::{Future, FutureExt, Stream, StreamExt};
use nativelink_proto::build::bazel::remote::execution::v2::{
ActionResult, BatchReadBlobsRequest, BatchReadBlobsResponse, BatchUpdateBlobsRequest,
Expand All @@ -35,13 +36,15 @@ use nativelink_proto::google::bytestream::{
};
use nativelink_proto::google::longrunning::Operation;
use nativelink_proto::google::rpc::Status;
use pin_project_lite::pin_project;
use rand::RngCore;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tonic::{Response, Status as TonicStatus, Streaming};
use uuid::Uuid;

use crate::make_symbol;
use crate::origin_context::ActiveOriginContext;
use crate::{background_spawn, make_symbol};

const ORIGIN_EVENT_VERSION: u32 = 0;

Expand Down Expand Up @@ -90,6 +93,7 @@ pub fn get_id_for_event(event: &Event) -> [u8; 2] {
Some(stream_event::Event::DataLength(_)) => [0x03, 0x03],
Some(stream_event::Event::WriteRequest(_)) => [0x03, 0x04],
Some(stream_event::Event::Operation(_)) => [0x03, 0x05],
Some(stream_event::Event::Closed(())) => [0x03, 0x06], // Special case when stream has terminated.
},
}
}
Expand Down Expand Up @@ -130,11 +134,17 @@ impl OriginEventCollector {
}
}

async fn publish_origin_event(&self, event: Event, parent_event_id: Option<Uuid>) -> Uuid {
let event_id = Uuid::now_v6(&get_node_id(Some(&event)));
/// Publishes an event to the origin event collector.
async fn publish_origin_event(
&self,
event: Event,
parent_event_id: Option<Uuid>,
maybe_event_id: Option<Uuid>,
) -> Uuid {
let event_id = maybe_event_id.unwrap_or_else(|| Uuid::now_v6(&get_node_id(Some(&event))));
let parent_event_id =
parent_event_id.map_or_else(String::new, |id| id.as_hyphenated().to_string());
// Failing to send this event means that the receiver has been dropped.
// Ignore cases when channel is dropped.
let _ = self
.sender
.send(OriginEvent {
Expand All @@ -148,6 +158,82 @@ impl OriginEventCollector {
.await;
event_id
}

/// Publishes an event to the origin event collector.
/// If the buffer is full, the event will be sent in a background spawn.
/// This is useful for cases where the event is critical and must be sent,
/// but cannot await the send operation.
fn publish_origin_event_now_or_in_spawn(
&self,
event: Event,
parent_event_id: Option<Uuid>,
) {
let event_id = Uuid::now_v6(&get_node_id(Some(&event)));
let parent_event_id =
parent_event_id.map_or_else(String::new, |id| id.as_hyphenated().to_string());
self.sender
.try_send(OriginEvent {
version: ORIGIN_EVENT_VERSION,
event_id: event_id.as_hyphenated().to_string(),
parent_event_id,
bazel_request_metadata: self.bazel_metadata.clone(),
identity: self.identity.clone(),
event: Some(event),
})
.map_or_else(
|e| match e {
TrySendError::Full(event) => {
let sender = self.sender.clone();
background_spawn!("send_end_stream_origin_event", async move {
// Ignore cases when channel is dropped.
let _ = sender.send(event).await;
});
},
// Ignore cases when channel is dropped.
TrySendError::Closed(_) => {},
},
|()| {},
)
}
}

pin_project! {
struct CloseEventStream<S> {
#[pin]
inner: S,
ctx_impl: OriginEventContextImpl,
}

impl <S> PinnedDrop for CloseEventStream<S> {
#[inline]
fn drop(this: Pin<&mut Self>) {
let event = Event {
event: Some(event::Event::Stream(StreamEvent {
event: Some(stream_event::Event::Closed(()))
})),
};
// Try to send the event immediately, if we cannot because
// the buffer is full, do it in a background spawn.
this.ctx_impl.origin_event_collector
.publish_origin_event_now_or_in_spawn(event, Some(this.ctx_impl.parent_event_id));
}
}
}

impl<'a, T, S> Stream for CloseEventStream<S>
where
S: Stream<Item = T> + Send + 'a,
{
type Item = S::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let project = self.project();
project.inner.poll_next(cx)
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}

make_symbol!(ORIGIN_EVENT_COLLECTOR, OriginEventCollector);
Expand Down Expand Up @@ -185,7 +271,7 @@ impl OriginEventContext<()> {
let event = source_cb().as_event();
async move {
let parent_event_id = origin_event_collector
.publish_origin_event(event, None)
.publish_origin_event(event, None, None)
.await;
OriginEventContext {
inner: Some(OriginEventContextImpl {
Expand Down Expand Up @@ -220,17 +306,20 @@ impl<U> OriginEventContext<U> {
O: OriginEventSource<U> + Send + 'a,
S: Stream<Item = O> + Send + 'a,
{
if self.inner.is_none() {
let Some(ctx_impl) = self.inner.clone() else {
return Box::pin(stream);
}
};
let ctx = self.clone();
Box::pin(stream.then(move |item| {
let ctx = ctx.clone();
async move {
ctx.emit(|| &item).await;
item
}
}))
Box::pin(CloseEventStream {
inner: stream.then(move |item| {
let ctx = ctx.clone();
async move {
ctx.emit(|| &item).await;
item
}
}),
ctx_impl,
})
}
}

Expand All @@ -246,7 +335,7 @@ pub trait OriginEventSource<Source>: Sized {
fn publish<'a>(&self, ctx: &'a OriginEventContextImpl) -> impl Future<Output = ()> + Send + 'a {
let event = self.as_event();
ctx.origin_event_collector
.publish_origin_event(event, Some(ctx.parent_event_id))
.publish_origin_event(event, Some(ctx.parent_event_id), None)
// We don't need the Uuid here.
.map(|_| ())
}
Expand Down Expand Up @@ -400,10 +489,15 @@ fn to_batch_read_blobs_response_override(val: &BatchReadBlobsResponse) -> event:
}

#[inline]
fn to_empty<T>(_: T) -> event::Event {
fn to_empty_response<T>(_: T) -> event::Event {
get_event_type!(Response, Empty, ())
}

#[inline]
fn to_empty_write_request<T>(_: T) -> event::Event {
get_event_type!(Request, WriteRequest, ())
}

// -- Requests --

impl_as_event! {Request, (), GetCapabilitiesRequest}
Expand All @@ -414,7 +508,7 @@ impl_as_event! {Request, (), BatchReadBlobsRequest}
impl_as_event! {Request, (), BatchUpdateBlobsRequest, BatchUpdateBlobsRequest, to_batch_update_blobs_request_override}
impl_as_event! {Request, (), GetTreeRequest}
impl_as_event! {Request, (), ReadRequest}
impl_as_event! {Request, (), Streaming<WriteRequest>, WriteRequest, to_empty}
impl_as_event! {Request, (), Streaming<WriteRequest>, WriteRequest, to_empty_write_request}
impl_as_event! {Request, (), QueryWriteStatusRequest}
impl_as_event! {Request, (), ExecuteRequest}
impl_as_event! {Request, (), WaitExecutionRequest}
Expand All @@ -425,13 +519,13 @@ impl_as_event! {Response, GetCapabilitiesRequest, ServerCapabilities}
impl_as_event! {Response, GetActionResultRequest, ActionResult}
impl_as_event! {Response, UpdateActionResultRequest, ActionResult}
impl_as_event! {Response, Streaming<WriteRequest>, WriteResponse}
impl_as_event! {Response, ReadRequest, Pin<Box<dyn Stream<Item = Result<ReadResponse, TonicStatus>> + Send + '_>>, Empty, to_empty}
impl_as_event! {Response, ReadRequest, Pin<Box<dyn Stream<Item = Result<ReadResponse, TonicStatus>> + Send + '_>>, Empty, to_empty_response}
impl_as_event! {Response, QueryWriteStatusRequest, QueryWriteStatusResponse}
impl_as_event! {Response, FindMissingBlobsRequest, FindMissingBlobsResponse}
impl_as_event! {Response, BatchUpdateBlobsRequest, BatchUpdateBlobsResponse}
impl_as_event! {Response, BatchReadBlobsRequest, BatchReadBlobsResponse, BatchReadBlobsResponseOverride, to_batch_read_blobs_response_override}
impl_as_event! {Response, GetTreeRequest, Pin<Box<dyn Stream<Item = Result<GetTreeResponse, TonicStatus>> + Send + '_>>, Empty, to_empty}
impl_as_event! {Response, ExecuteRequest, Pin<Box<dyn Stream<Item = Result<Operation, TonicStatus>> + Send + '_>>, Empty, to_empty}
impl_as_event! {Response, GetTreeRequest, Pin<Box<dyn Stream<Item = Result<GetTreeResponse, TonicStatus>> + Send + '_>>, Empty, to_empty_response}
impl_as_event! {Response, ExecuteRequest, Pin<Box<dyn Stream<Item = Result<Operation, TonicStatus>> + Send + '_>>, Empty, to_empty_response}

// -- Streams --

Expand Down
2 changes: 2 additions & 0 deletions nativelink-util/tests/origin_event_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ fn get_id_for_event_test() {
Some(stream_event::Event::DataLength(_)) => [0x03, 0x03],
Some(stream_event::Event::WriteRequest(_)) => [0x03, 0x04],
Some(stream_event::Event::Operation(_)) => [0x03, 0x05],
Some(stream_event::Event::Closed(())) => [0x03, 0x06],
// Don't forget to add new entries to test cases.
}
}
Expand Down Expand Up @@ -161,4 +162,5 @@ fn get_id_for_event_test() {
test_event!(Stream, DataLength);
test_event!(Stream, WriteRequest);
test_event!(Stream, Operation);
test_event!(Stream, Closed);
}

0 comments on commit b0f7016

Please sign in to comment.