Skip to content

Commit

Permalink
Merge pull request #32 from bspeice/hyper
Browse files Browse the repository at this point in the history
Add a `hyper` client implementation
  • Loading branch information
yoshuawuyts authored Aug 7, 2020
2 parents ea2cff6 + 24f1b20 commit a1043a5
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 2 deletions.
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@ h1_client = ["async-h1", "async-std", "async-native-tls"]
native_client = ["curl_client", "wasm_client"]
curl_client = ["isahc", "async-std"]
wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures"]
hyper_client = ["hyper", "hyper-tls"]

[dependencies]
futures = { version = "0.3.1" }
http-types = "2.3.0"
http-types = { version = "2.3.0", features = ["hyperium_http"] }
log = "0.4.7"

# h1-client
async-h1 = { version = "2.0.0", optional = true }
async-std = { version = "1.6.0", default-features = false, optional = true }
async-native-tls = { version = "0.3.1", optional = true }

# reqwest-client
hyper = { version = "0.13.6", features = ["tcp"], optional = true }
hyper-tls = { version = "0.4.3", optional = true }

# isahc-client
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
isahc = { version = "0.9", optional = true, default-features = false, features = ["http2"] }
Expand Down Expand Up @@ -63,5 +68,6 @@ features = [

[dev-dependencies]
async-std = { version = "1.6.0", features = ["unstable", "attributes"] }
tide = { version = "0.9.0" }
portpicker = "0.1.0"
tide = { version = "0.9.0" }
tokio = { version = "0.2.21", features = ["macros"] }
188 changes: 188 additions & 0 deletions src/hyper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
//! http-client implementation for reqwest

use super::{Error, HttpClient, Request, Response};
use http_types::headers::{HeaderName, HeaderValue};
use http_types::StatusCode;
use hyper::body::HttpBody;
use hyper_tls::HttpsConnector;
use std::convert::TryFrom;
use std::str::FromStr;

/// Hyper-based HTTP Client.
#[derive(Debug)]
pub struct HyperClient {}

impl HyperClient {
/// Create a new client.
///
/// There is no specific benefit to reusing instances of this client.
pub fn new() -> Self {
HyperClient {}
}
}

impl HttpClient for HyperClient {
fn send(&self, req: Request) -> futures::future::BoxFuture<'static, Result<Response, Error>> {
Box::pin(async move {
let req = HyperHttpRequest::try_from(req).await?.into_inner();
// UNWRAP: Scheme guaranteed to be "http" or "https" as part of conversion
let scheme = req.uri().scheme_str().unwrap();

let response = match scheme {
"http" => {
let client = hyper::Client::builder().build_http::<hyper::Body>();
client.request(req).await
}
"https" => {
let https = HttpsConnector::new();
let client = hyper::Client::builder().build::<_, hyper::Body>(https);
client.request(req).await
}
_ => unreachable!(),
}?;

let resp = HttpTypesResponse::try_from(response).await?.into_inner();
Ok(resp)
})
}
}

struct HyperHttpRequest {
inner: hyper::Request<hyper::Body>,
}

impl HyperHttpRequest {
async fn try_from(mut value: Request) -> Result<Self, Error> {
// UNWRAP: This unwrap is unjustified in `http-types`, need to check if it's actually safe.
let uri = hyper::Uri::try_from(&format!("{}", value.url())).unwrap();

// `HyperClient` depends on the scheme being either "http" or "https"
match uri.scheme_str() {
Some("http") | Some("https") => (),
_ => return Err(Error::from_str(StatusCode::BadRequest, "invalid scheme")),
};

let mut request = hyper::Request::builder();

// UNWRAP: Default builder is safe
let req_headers = request.headers_mut().unwrap();
for (name, values) in &value {
// UNWRAP: http-types and http have equivalent validation rules
let name = hyper::header::HeaderName::from_str(name.as_str()).unwrap();

for value in values.iter() {
// UNWRAP: http-types and http have equivalent validation rules
let value =
hyper::header::HeaderValue::from_bytes(value.as_str().as_bytes()).unwrap();
req_headers.append(&name, value);
}
}

let body = value.body_bytes().await?;
let body = hyper::Body::from(body);

let request = request
.method(value.method())
.version(value.version().map(|v| v.into()).unwrap_or_default())
.uri(uri)
.body(body)?;

Ok(HyperHttpRequest { inner: request })
}

fn into_inner(self) -> hyper::Request<hyper::Body> {
self.inner
}
}

struct HttpTypesResponse {
inner: Response,
}

impl HttpTypesResponse {
async fn try_from(value: hyper::Response<hyper::Body>) -> Result<Self, Error> {
let (parts, mut body) = value.into_parts();

let body = match body.data().await {
None => None,
Some(Ok(b)) => Some(b),
Some(Err(_)) => {
return Err(Error::from_str(
StatusCode::BadGateway,
"unable to read HTTP response body",
))
}
}
.map(|b| http_types::Body::from_bytes(b.to_vec()))
.unwrap_or(http_types::Body::empty());

let mut res = Response::new(parts.status);
res.set_version(Some(parts.version.into()));

for (name, value) in parts.headers {
let value = value.as_bytes().to_owned();
let value = HeaderValue::from_bytes(value)?;

if let Some(name) = name {
let name = name.as_str();
let name = HeaderName::from_str(name)?;
res.insert_header(name, value);
}
}

res.set_body(body);
Ok(HttpTypesResponse { inner: res })
}

fn into_inner(self) -> Response {
self.inner
}
}

#[cfg(test)]
mod tests {
use crate::{Error, HttpClient};
use http_types::{Method, Request, Url};
use hyper::service::{make_service_fn, service_fn};
use std::time::Duration;
use tokio::sync::oneshot::channel;

use super::HyperClient;

async fn echo(
req: hyper::Request<hyper::Body>,
) -> Result<hyper::Response<hyper::Body>, hyper::Error> {
Ok(hyper::Response::new(req.into_body()))
}

#[tokio::test]
async fn basic_functionality() {
let (send, recv) = channel::<()>();

let recv = async move { recv.await.unwrap_or(()) };

let addr = ([127, 0, 0, 1], portpicker::pick_unused_port().unwrap()).into();
let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) });
let server = hyper::Server::bind(&addr)
.serve(service)
.with_graceful_shutdown(recv);

let client = HyperClient::new();
let url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();
let mut req = Request::new(Method::Get, url);
req.set_body("hello");

let client = async move {
tokio::time::delay_for(Duration::from_millis(100)).await;
let mut resp = client.send(req).await?;
send.send(()).unwrap();
assert_eq!(resp.body_string().await?, "hello");

Result::<(), Error>::Ok(())
};

let (client_res, server_res) = tokio::join!(client, server);
assert!(client_res.is_ok());
assert!(server_res.is_ok());
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ pub mod native;
#[cfg(feature = "h1_client")]
pub mod h1;

#[cfg_attr(feature = "docs", doc(cfg(hyper_client)))]
#[cfg(feature = "hyper_client")]
pub mod hyper;

/// An HTTP Request type with a streaming body.
pub type Request = http_types::Request;

Expand Down

0 comments on commit a1043a5

Please sign in to comment.