Skip to content

Commit

Permalink
Begin hyper upgrade in lading/
Browse files Browse the repository at this point in the history
This commit begins the hyper upgrade in lading/ itself. This will be a
slow process spread over many commits.

Signed-off-by: Brian L. Troutwine <[email protected]>
  • Loading branch information
blt committed Dec 24, 2024
1 parent bfbdbb2 commit 8622df9
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 105 deletions.
12 changes: 7 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions lading/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ flate2 = { version = "1.0.34", default-features = false, features = [
futures = "0.3.31"
fuser = { version = "0.15", optional = true }
heck = { version = "0.5", default-features = false }
http = "0.2"
http-serde = "1.1"
hyper = { workspace = true, features = ["backports", "client", "deprecated", "http1", "http2", "server"] }
http = "1.2"
http-serde = "2.1"
hyper = { version = "1.5", features = ["client", "http1", "http2", "server"] }
http-body-util = "0.1"
hyper-util = "0.1"
is_executable = "1.0.4"
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
Expand Down
133 changes: 79 additions & 54 deletions lading/src/blackhole/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
//! `requests_received`: Total requests received
//!
use std::{net::SocketAddr, time::Duration};
use std::{net::SocketAddr, sync::Arc, time::Duration};

use bytes::Bytes;
use http::{header::InvalidHeaderValue, status::InvalidStatusCode, HeaderMap};
use hyper::{
body::HttpBody,
header,
server::conn::{AddrIncoming, AddrStream},
service::{make_service_fn, service_fn},
Body, Request, Response, Server, StatusCode,
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{header, service::service_fn, Request, Response, StatusCode};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto,
};
use metrics::counter;
use serde::{Deserialize, Serialize};
use tower::ServiceBuilder;
use tokio::{sync::Semaphore, task::JoinSet};
use tracing::{debug, error, info};

use super::General;
Expand All @@ -42,6 +42,9 @@ pub enum Error {
/// Failed to deserialize the configuration.
#[error("Failed to deserialize the configuration: {0}")]
Serde(#[from] serde_json::Error),
/// Wrapper for [`std::io::Error`].
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}

/// Body variant supported by this blackhole.
Expand Down Expand Up @@ -129,18 +132,21 @@ async fn srv(
status: StatusCode,
metric_labels: Vec<(String, String)>,
body_bytes: Vec<u8>,
req: Request<Body>,
req: Request<hyper::body::Incoming>,
headers: HeaderMap,
response_delay: Duration,
) -> Result<Response<Body>, hyper::Error> {
) -> Result<hyper::Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
counter!("requests_received", &metric_labels).increment(1);

// Split into parts
let (parts, body) = req.into_parts();

let bytes = body.collect().await?.to_bytes();
counter!("bytes_received", &metric_labels).increment(bytes.len() as u64);
// Convert the `Body` into `Bytes`
let body: Bytes = body.boxed().collect().await?.to_bytes();

match crate::codec::decode(parts.headers.get(hyper::header::CONTENT_ENCODING), bytes) {
counter!("bytes_received", &metric_labels).increment(body.len() as u64);

match crate::codec::decode(parts.headers.get(hyper::header::CONTENT_ENCODING), body) {
Err(response) => Ok(response),
Ok(body) => {
counter!("decoded_bytes_received", &metric_labels).increment(body.len() as u64);
Expand All @@ -150,7 +156,7 @@ async fn srv(
let mut okay = Response::default();
*okay.status_mut() = status;
*okay.headers_mut() = headers;
*okay.body_mut() = Body::from(body_bytes);
*okay.body_mut() = crate::full(body);
Ok(okay)
}
}
Expand Down Expand Up @@ -234,48 +240,67 @@ impl Http {
/// Function will return an error if the configuration is invalid or if
/// receiving a packet fails.
pub async fn run(self) -> Result<(), Error> {
let service = make_service_fn(|_: &AddrStream| {
let metric_labels = self.metric_labels.clone();
let body_bytes = self.body_bytes.clone();
let headers = self.headers.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |request| {
debug!("REQUEST: {:?}", request);
srv(
self.status,
metric_labels.clone(),
body_bytes.clone(),
request,
headers.clone(),
self.response_delay,
)
}))
}
});
let svc = ServiceBuilder::new()
.load_shed()
.concurrency_limit(self.concurrency_limit)
.timeout(Duration::from_secs(1))
.service(service);

let addr = AddrIncoming::bind(&self.httpd_addr)
.map(|mut addr| {
addr.set_keepalive(Some(Duration::from_secs(60)));
addr
})
.map_err(Error::Hyper)?;

let server = Server::builder(addr).serve(svc);
tokio::select! {
res = server => {
error!("server shutdown unexpectedly");
res.map_err(Error::Hyper)
}
() = self.shutdown.recv() => {
info!("shutdown signal received");
Ok(())
let listener = tokio::net::TcpListener::bind(self.httpd_addr).await?;
let sem = Arc::new(Semaphore::new(self.concurrency_limit));
let mut join_set = JoinSet::new();
loop {
tokio::select! {
_ = self.shutdown.recv() => {
info!("shutdown signal received");
break;
}

incoming = listener.accept() => {
let (stream, addr) = match incoming {
Ok((s,a)) => (s,a),
Err(e) => {
error!("accept error: {e}");
continue;
}
};

let metric_labels = self.metric_labels.clone();
let body_bytes = self.body_bytes.clone();
let headers = self.headers.clone();
let status = self.status;
let response_delay = self.response_delay;

join_set.spawn(async move {
debug!("Accepted connection from {addr}");
let permit = match sem.acquire_owned().await {
Ok(p) => p,
Err(e) => {
error!("Semaphore closed: {e}");
return;
}
};
let serve_future = auto::Builder::new(TokioExecutor::new())
.serve_connection(
TokioIo::new(stream),
service_fn(move |req: Request<hyper::body::Incoming>| {
debug!("REQUEST: {:?}", req);
srv(
status,
metric_labels.clone(),
body_bytes.clone(),
req,
headers.clone(),
response_delay,
)
})
);

if let Err(e) = serve_future.await {
error!("Error serving {addr}: {e}");
}
drop(permit);
});
}
}
}
drop(listener);
while join_set.join_next().await.is_some() {}
Ok(())
}
}

Expand Down
103 changes: 60 additions & 43 deletions lading/src/blackhole/splunk_hec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@ use std::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};

use hyper::{
body::HttpBody,
header,
server::conn::{AddrIncoming, AddrStream},
service::{make_service_fn, service_fn},
Body, Method, Request, Response, Server, StatusCode,
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{header, service::service_fn, Method, Request, Response, StatusCode};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto,
};
use metrics::counter;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tower::ServiceBuilder;
use tracing::{error, info};
use tokio::task::JoinSet;
use tracing::{debug, error, info};

use super::General;

Expand All @@ -45,6 +44,9 @@ pub enum Error {
/// Deserialization Error
#[error(transparent)]
SerdeJson(#[from] serde_json::Error),
/// Wrapper for [`std::io::Error`].
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}

#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -89,13 +91,13 @@ struct HecResponse {
}

async fn srv(
req: Request<Body>,
req: Request<hyper::body::Incoming>,
labels: Arc<Vec<(String, String)>>,
) -> Result<Response<Body>, Error> {
) -> Result<hyper::Response<BoxBody<Bytes, hyper::Error>>, Error> {
counter!("requests_received", &*labels).increment(1);

let (parts, body) = req.into_parts();
let bytes = body.collect().await?.to_bytes();
let bytes = body.boxed().collect().await?.to_bytes();
counter!("bytes_received", &*labels).increment(bytes.len() as u64);

match crate::codec::decode(parts.headers.get(hyper::header::CONTENT_ENCODING), bytes) {
Expand Down Expand Up @@ -125,15 +127,15 @@ async fn srv(
code: 0,
ack_id,
})?;
*okay.body_mut() = Body::from(body_bytes);
*okay.body_mut() = crate::full(body_bytes);
}
// Path for querying indexer acknowledgements
(Method::POST, "/services/collector/ack") => {
match serde_json::from_slice::<HecAckRequest>(&body) {
Ok(ack_request) => {
let body_bytes =
serde_json::to_vec(&HecAckResponse::from(ack_request))?;
*okay.body_mut() = Body::from(body_bytes);
*okay.body_mut() = crate::full(body_bytes);
}
Err(_) => {
*okay.status_mut() = StatusCode::BAD_REQUEST;
Expand Down Expand Up @@ -190,38 +192,53 @@ impl SplunkHec {
///
/// None known.
pub async fn run(self) -> Result<(), Error> {
let listener = tokio::net::TcpListener::bind(&self.httpd_addr).await?;
let sem = Arc::new(tokio::sync::Semaphore::new(self.concurrency_limit));
let mut join_set = JoinSet::new();
let labels = Arc::new(self.metric_labels.clone());
let service = make_service_fn(|_: &AddrStream| {
let labels = labels.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req| {

loop {
tokio::select! {
_ = self.shutdown.recv() => {
info!("shutdown signal received");
break;
}
incoming = listener.accept() => {
let (stream, addr) = match incoming {
Ok((s,a)) => (s,a),
Err(e) => {
error!("accept error: {e}");
continue;
}
};

let labels = Arc::clone(&labels);
srv(req, labels)
}))
}
});
let svc = ServiceBuilder::new()
.load_shed()
.concurrency_limit(self.concurrency_limit)
.timeout(Duration::from_secs(1))
.service(service);

let addr = AddrIncoming::bind(&self.httpd_addr)
.map(|mut addr| {
addr.set_keepalive(Some(Duration::from_secs(60)));
addr
})
.map_err(Error::Hyper)?;
let server = Server::builder(addr).serve(svc);
tokio::select! {
res = server => {
error!("server shutdown unexpectedly");
res.map_err(Error::Hyper)
}
() = self.shutdown.recv() => {
info!("shutdown signal received");
Ok(())
join_set.spawn(async move {
debug!("Accepted connection from {addr}");
let permit = match sem.acquire_owned().await {
Ok(p) => p,
Err(e) => {
error!("Semaphore closed: {e}");
return;
}
};
let serve_future = auto::Builder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(stream), service_fn(move |req| {
let labels = Arc::clone(&labels);
srv(req, labels)
}));

if let Err(e) = serve_future.await {
error!("Error serving: {e}");
}
drop(permit);
});
}
}
}

drop(listener);
while join_set.join_next().await.is_some() {}
Ok(())
}
}
Loading

0 comments on commit 8622df9

Please sign in to comment.