From c86a6bcb4acb0f92e731ea2e4c1e4a839248a600 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 11 Oct 2024 11:28:29 -0700 Subject: [PATCH] fix(http1): send 'connection: close' when connection is ending (#3725) This includes conditions where hyper knows the connection will end after the response, such as a request error that ruins the connection, or when graceful shutdown is triggered. Closes #3720 --- benches/pipeline.rs | 2 +- benches/server.rs | 2 +- src/proto/h1/conn.rs | 25 +++++++++++++++++-------- tests/server.rs | 12 ++++++++++-- 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/benches/pipeline.rs b/benches/pipeline.rs index b79232de9b..d36f054a6c 100644 --- a/benches/pipeline.rs +++ b/benches/pipeline.rs @@ -76,7 +76,7 @@ fn hello_world_16(b: &mut test::Bencher) { tcp.write_all(b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n") .unwrap(); let mut buf = Vec::new(); - tcp.read_to_end(&mut buf).unwrap() + tcp.read_to_end(&mut buf).unwrap() - "connection: close\r\n".len() } * PIPELINED_REQUESTS; let mut tcp = TcpStream::connect(addr).unwrap(); diff --git a/benches/server.rs b/benches/server.rs index c5424105a8..6e9d3742cf 100644 --- a/benches/server.rs +++ b/benches/server.rs @@ -72,7 +72,7 @@ macro_rules! bench_server { tcp.write_all(b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n") .unwrap(); let mut buf = Vec::new(); - tcp.read_to_end(&mut buf).unwrap() + tcp.read_to_end(&mut buf).unwrap() - "connection: close\r\n".len() }; let mut tcp = TcpStream::connect(addr).unwrap(); diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 10f4f87b40..8ddf7558e1 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -21,7 +21,7 @@ use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext use crate::body::DecodedLength; #[cfg(feature = "server")] use crate::common::time::Time; -use crate::headers::connection_keep_alive; +use crate::headers; use crate::proto::{BodyLength, MessageHead}; #[cfg(feature = "server")] use crate::rt::Sleep; @@ -657,7 +657,7 @@ where let outgoing_is_keep_alive = head .headers .get(CONNECTION) - .map_or(false, connection_keep_alive); + .map_or(false, headers::connection_keep_alive); if !outgoing_is_keep_alive { match head.version { @@ -680,12 +680,21 @@ where // If we know the remote speaks an older version, we try to fix up any messages // to work with our older peer. fn enforce_version(&mut self, head: &mut MessageHead) { - if let Version::HTTP_10 = self.state.version { - // Fixes response or connection when keep-alive header is not present - self.fix_keep_alive(head); - // If the remote only knows HTTP/1.0, we should force ourselves - // to do only speak HTTP/1.0 as well. - head.version = Version::HTTP_10; + match self.state.version { + Version::HTTP_10 => { + // Fixes response or connection when keep-alive header is not present + self.fix_keep_alive(head); + // If the remote only knows HTTP/1.0, we should force ourselves + // to do only speak HTTP/1.0 as well. + head.version = Version::HTTP_10; + } + Version::HTTP_11 => { + if let KA::Disabled = self.state.keep_alive.status() { + head.headers + .insert(CONNECTION, HeaderValue::from_static("close")); + } + } + _ => (), } // If the remote speaks HTTP/1.1, then it *should* be fine with // both HTTP/1.0 and HTTP/1.1 from us. So again, we just let diff --git a/tests/server.rs b/tests/server.rs index 5120ad776f..f72cf62702 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1140,6 +1140,8 @@ fn pipeline_enabled() { assert_eq!(s(lines.next().unwrap()), "HTTP/1.1 200 OK\r"); assert_eq!(s(lines.next().unwrap()), "content-length: 12\r"); + // close because the last request said to close + assert_eq!(s(lines.next().unwrap()), "connection: close\r"); lines.next().unwrap(); // Date assert_eq!(s(lines.next().unwrap()), "\r"); assert_eq!(s(lines.next().unwrap()), "Hello World"); @@ -1181,7 +1183,7 @@ fn http_11_uri_too_long() { let mut req = connect(server.addr()); req.write_all(request_line.as_bytes()).unwrap(); - let expected = "HTTP/1.1 414 URI Too Long\r\ncontent-length: 0\r\n"; + let expected = "HTTP/1.1 414 URI Too Long\r\nconnection: close\r\ncontent-length: 0\r\n"; let mut buf = [0; 256]; let n = req.read(&mut buf).unwrap(); assert!(n >= expected.len(), "read: {:?} >= {:?}", n, expected.len()); @@ -1208,6 +1210,12 @@ async fn disable_keep_alive_mid_request() { "should receive OK response, but buf: {:?}", buf, ); + let sbuf = s(&buf); + assert!( + sbuf.contains("connection: close\r\n"), + "response should have sent close: {:?}", + sbuf, + ); }); let (socket, _) = listener.accept().await.unwrap(); @@ -2366,7 +2374,7 @@ fn streaming_body() { buf.starts_with(b"HTTP/1.1 200 OK\r\n"), "response is 200 OK" ); - assert_eq!(buf.len(), 100_789, "full streamed body read"); + assert_eq!(buf.len(), 100_808, "full streamed body read"); } #[test]