From 144f58adaf654a0f98176a5fdb4b5821226f4ff5 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Thu, 19 Sep 2024 07:44:13 -0700 Subject: [PATCH] Update relevant channel request callbacks to return a bool This is a breaking change, but it tweaks channel request callbacks to return a bool rather than requiring the user to manually call `session.channel_success` or `session.channel_failure`. This has the added advantage of changing the defaults of a number of request callbacks to more-secure defaults (deny), and makes it impossible for a user to miss responding to callbacks which require responses. Note that this does *not* handle sending responses for all requests, only channel requests listed in RFC4254 as having a "want reply" param rather than just "false", even though it may be more correct to respond to malformed requests which have improperly set that byte to "true" even though the RFC specifies "false". --- russh/examples/sftp_server.rs | 21 +++++++------ russh/examples/test.rs | 6 ++-- russh/src/server/encrypted.rs | 55 ++++++++++++++++++++++++++++------- russh/src/server/mod.rs | 24 +++++++-------- russh/src/server/session.rs | 11 ++++--- 5 files changed, 77 insertions(+), 40 deletions(-) diff --git a/russh/examples/sftp_server.rs b/russh/examples/sftp_server.rs index 99ff695..7e60307 100644 --- a/russh/examples/sftp_server.rs +++ b/russh/examples/sftp_server.rs @@ -76,19 +76,22 @@ impl russh::server::Handler for SshSession { channel_id: ChannelId, name: &str, session: &mut Session, - ) -> Result<(), Self::Error> { + ) -> Result { info!("subsystem: {}", name); - if name == "sftp" { - let channel = self.get_channel(channel_id).await; - let sftp = SftpSession::default(); - session.channel_success(channel_id); - russh_sftp::server::run(channel.into_stream(), sftp).await; - } else { - session.channel_failure(channel_id); + if name != "sftp" { + return Ok(false); } - Ok(()) + let channel = self.get_channel(channel_id).await; + let sftp = SftpSession::default(); + session.channel_success(channel_id); + + tokio::spawn(async move { + russh_sftp::server::run(channel.into_stream(), sftp).await; + }); + + Ok(true) } } diff --git a/russh/examples/test.rs b/russh/examples/test.rs index be5effc..58c4fd5 100644 --- a/russh/examples/test.rs +++ b/russh/examples/test.rs @@ -69,14 +69,14 @@ impl server::Handler for Server { &mut self, channel: ChannelId, session: &mut Session, - ) -> Result<(), Self::Error> { - session.request_success(); - Ok(()) + ) -> Result { + Ok(true) } async fn auth_publickey(&mut self, _: &str, _: &key::PublicKey) -> Result { Ok(server::Auth::Accept) } + async fn data( &mut self, _channel: ChannelId, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index 9f79061..ea593fa 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -822,7 +822,7 @@ impl Session { debug!("handler.pty_request {:?}", channel_num); #[allow(clippy::indexing_slicing)] // `modes` length checked - handler + let response = handler .pty_request( channel_num, term, @@ -833,7 +833,13 @@ impl Session { &modes[0..i], self, ) - .await + .await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"x11-req" => { let single_connection = r.read_byte().map_err(crate::Error::from)? != 0; @@ -855,7 +861,7 @@ impl Session { }); } debug!("handler.x11_request {:?}", channel_num); - handler + let response = handler .x11_request( channel_num, single_connection, @@ -864,7 +870,13 @@ impl Session { x11_screen_number, self, ) - .await + .await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"env" => { let env_variable = @@ -883,23 +895,34 @@ impl Session { } debug!("handler.env_request {:?}", channel_num); - handler + let response = handler .env_request(channel_num, env_variable, env_value, self) - .await + .await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"shell" => { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::RequestShell { want_reply: true }); } debug!("handler.shell_request {:?}", channel_num); - handler.shell_request(channel_num, self).await + let response = handler.shell_request(channel_num, self).await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"auth-agent-req@openssh.com" => { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); } debug!("handler.agent_request {:?}", channel_num); - let response = handler.agent_request(channel_num, self).await?; if response { self.request_success() @@ -917,7 +940,13 @@ impl Session { }); } debug!("handler.exec_request {:?}", channel_num); - handler.exec_request(channel_num, req, self).await + let response = handler.exec_request(channel_num, req, self).await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"subsystem" => { let name = @@ -931,7 +960,13 @@ impl Session { }); } debug!("handler.subsystem_request {:?}", channel_num); - handler.subsystem_request(channel_num, name, self).await + let response = handler.subsystem_request(channel_num, name, self).await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"window-change" => { let col_width = r.read_u32().map_err(crate::Error::from)?; diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index a5e2150..0656de8 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -397,8 +397,8 @@ pub trait Handler: Sized { pix_height: u32, modes: &[(Pty, u32)], session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client requests an X11 connection. @@ -411,8 +411,8 @@ pub trait Handler: Sized { x11_auth_cookie: &str, x11_screen_number: u32, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client wants to set the given environment variable. Check @@ -425,8 +425,8 @@ pub trait Handler: Sized { variable_name: &str, variable_value: &str, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client requests a shell. @@ -435,8 +435,8 @@ pub trait Handler: Sized { &mut self, channel: ChannelId, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client sends a command to execute, to be passed to a @@ -447,8 +447,8 @@ pub trait Handler: Sized { channel: ChannelId, data: &[u8], session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client asks to start the subsystem with the given name @@ -459,8 +459,8 @@ pub trait Handler: Sized { channel: ChannelId, name: &str, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client's pseudo-terminal window size has changed. diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 26d8da6..2943985 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -780,7 +780,8 @@ impl Session { } } - /// Send a "failure" reply to a global request. + /// Send a "failure" reply to a channel request. Always call this function + /// if the request failed (it checks whether the client expects an answer). pub fn channel_failure(&mut self, channel: ChannelId) { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel) { @@ -825,9 +826,8 @@ impl Session { self.common.byte(channel, msg::CHANNEL_EOF); } - /// Send data to a channel. On session channels, `extended` can be - /// used to encode standard error by passing `Some(1)`, and stdout - /// by passing `None`. + /// Send data to a channel. On session channels, this generally + /// refers to stdout. /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned. @@ -840,8 +840,7 @@ impl Session { } /// Send data to a channel. On session channels, `extended` can be - /// used to encode standard error by passing `Some(1)`, and stdout - /// by passing `None`. + /// used to encode standard error by passing `1`. /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned.