diff --git a/Cargo.toml b/Cargo.toml index c90d1eb..6295b66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,10 +22,11 @@ h1_client = ["async-h1", "async-std", "async-native-tls"] native_client = ["curl_client", "wasm_client"] curl_client = ["isahc", "async-std"] wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures"] +hyper_client = ["hyper", "hyper-tls"] [dependencies] futures = { version = "0.3.1" } -http-types = "2.3.0" +http-types = { version = "2.3.0", features = ["hyperium_http"] } log = "0.4.7" # h1-client @@ -33,6 +34,10 @@ async-h1 = { version = "2.0.0", optional = true } async-std = { version = "1.6.0", default-features = false, optional = true } async-native-tls = { version = "0.3.1", optional = true } +# reqwest-client +hyper = { version = "0.13.6", features = ["tcp"], optional = true } +hyper-tls = { version = "0.4.3", optional = true } + # isahc-client [target.'cfg(not(target_arch = "wasm32"))'.dependencies] isahc = { version = "0.9", optional = true, default-features = false, features = ["http2"] } @@ -63,5 +68,6 @@ features = [ [dev-dependencies] async-std = { version = "1.6.0", features = ["unstable", "attributes"] } -tide = { version = "0.9.0" } portpicker = "0.1.0" +tide = { version = "0.9.0" } +tokio = { version = "0.2.21", features = ["macros"] } diff --git a/src/hyper.rs b/src/hyper.rs new file mode 100644 index 0000000..7141a15 --- /dev/null +++ b/src/hyper.rs @@ -0,0 +1,188 @@ +//! http-client implementation for reqwest + +use super::{Error, HttpClient, Request, Response}; +use http_types::headers::{HeaderName, HeaderValue}; +use http_types::StatusCode; +use hyper::body::HttpBody; +use hyper_tls::HttpsConnector; +use std::convert::TryFrom; +use std::str::FromStr; + +/// Hyper-based HTTP Client. +#[derive(Debug)] +pub struct HyperClient {} + +impl HyperClient { + /// Create a new client. + /// + /// There is no specific benefit to reusing instances of this client. + pub fn new() -> Self { + HyperClient {} + } +} + +impl HttpClient for HyperClient { + fn send(&self, req: Request) -> futures::future::BoxFuture<'static, Result> { + Box::pin(async move { + let req = HyperHttpRequest::try_from(req).await?.into_inner(); + // UNWRAP: Scheme guaranteed to be "http" or "https" as part of conversion + let scheme = req.uri().scheme_str().unwrap(); + + let response = match scheme { + "http" => { + let client = hyper::Client::builder().build_http::(); + client.request(req).await + } + "https" => { + let https = HttpsConnector::new(); + let client = hyper::Client::builder().build::<_, hyper::Body>(https); + client.request(req).await + } + _ => unreachable!(), + }?; + + let resp = HttpTypesResponse::try_from(response).await?.into_inner(); + Ok(resp) + }) + } +} + +struct HyperHttpRequest { + inner: hyper::Request, +} + +impl HyperHttpRequest { + async fn try_from(mut value: Request) -> Result { + // UNWRAP: This unwrap is unjustified in `http-types`, need to check if it's actually safe. + let uri = hyper::Uri::try_from(&format!("{}", value.url())).unwrap(); + + // `HyperClient` depends on the scheme being either "http" or "https" + match uri.scheme_str() { + Some("http") | Some("https") => (), + _ => return Err(Error::from_str(StatusCode::BadRequest, "invalid scheme")), + }; + + let mut request = hyper::Request::builder(); + + // UNWRAP: Default builder is safe + let req_headers = request.headers_mut().unwrap(); + for (name, values) in &value { + // UNWRAP: http-types and http have equivalent validation rules + let name = hyper::header::HeaderName::from_str(name.as_str()).unwrap(); + + for value in values.iter() { + // UNWRAP: http-types and http have equivalent validation rules + let value = + hyper::header::HeaderValue::from_bytes(value.as_str().as_bytes()).unwrap(); + req_headers.append(&name, value); + } + } + + let body = value.body_bytes().await?; + let body = hyper::Body::from(body); + + let request = request + .method(value.method()) + .version(value.version().map(|v| v.into()).unwrap_or_default()) + .uri(uri) + .body(body)?; + + Ok(HyperHttpRequest { inner: request }) + } + + fn into_inner(self) -> hyper::Request { + self.inner + } +} + +struct HttpTypesResponse { + inner: Response, +} + +impl HttpTypesResponse { + async fn try_from(value: hyper::Response) -> Result { + let (parts, mut body) = value.into_parts(); + + let body = match body.data().await { + None => None, + Some(Ok(b)) => Some(b), + Some(Err(_)) => { + return Err(Error::from_str( + StatusCode::BadGateway, + "unable to read HTTP response body", + )) + } + } + .map(|b| http_types::Body::from_bytes(b.to_vec())) + .unwrap_or(http_types::Body::empty()); + + let mut res = Response::new(parts.status); + res.set_version(Some(parts.version.into())); + + for (name, value) in parts.headers { + let value = value.as_bytes().to_owned(); + let value = HeaderValue::from_bytes(value)?; + + if let Some(name) = name { + let name = name.as_str(); + let name = HeaderName::from_str(name)?; + res.insert_header(name, value); + } + } + + res.set_body(body); + Ok(HttpTypesResponse { inner: res }) + } + + fn into_inner(self) -> Response { + self.inner + } +} + +#[cfg(test)] +mod tests { + use crate::{Error, HttpClient}; + use http_types::{Method, Request, Url}; + use hyper::service::{make_service_fn, service_fn}; + use std::time::Duration; + use tokio::sync::oneshot::channel; + + use super::HyperClient; + + async fn echo( + req: hyper::Request, + ) -> Result, hyper::Error> { + Ok(hyper::Response::new(req.into_body())) + } + + #[tokio::test] + async fn basic_functionality() { + let (send, recv) = channel::<()>(); + + let recv = async move { recv.await.unwrap_or(()) }; + + let addr = ([127, 0, 0, 1], portpicker::pick_unused_port().unwrap()).into(); + let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) }); + let server = hyper::Server::bind(&addr) + .serve(service) + .with_graceful_shutdown(recv); + + let client = HyperClient::new(); + let url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap(); + let mut req = Request::new(Method::Get, url); + req.set_body("hello"); + + let client = async move { + tokio::time::delay_for(Duration::from_millis(100)).await; + let mut resp = client.send(req).await?; + send.send(()).unwrap(); + assert_eq!(resp.body_string().await?, "hello"); + + Result::<(), Error>::Ok(()) + }; + + let (client_res, server_res) = tokio::join!(client, server); + assert!(client_res.is_ok()); + assert!(server_res.is_ok()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3e5c8cb..490eef6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,10 @@ pub mod native; #[cfg(feature = "h1_client")] pub mod h1; +#[cfg_attr(feature = "docs", doc(cfg(hyper_client)))] +#[cfg(feature = "hyper_client")] +pub mod hyper; + /// An HTTP Request type with a streaming body. pub type Request = http_types::Request;