diff --git a/russh/examples/sftp_server.rs b/russh/examples/sftp_server.rs index 99ff695..7e60307 100644 --- a/russh/examples/sftp_server.rs +++ b/russh/examples/sftp_server.rs @@ -76,19 +76,22 @@ impl russh::server::Handler for SshSession { channel_id: ChannelId, name: &str, session: &mut Session, - ) -> Result<(), Self::Error> { + ) -> Result { info!("subsystem: {}", name); - if name == "sftp" { - let channel = self.get_channel(channel_id).await; - let sftp = SftpSession::default(); - session.channel_success(channel_id); - russh_sftp::server::run(channel.into_stream(), sftp).await; - } else { - session.channel_failure(channel_id); + if name != "sftp" { + return Ok(false); } - Ok(()) + let channel = self.get_channel(channel_id).await; + let sftp = SftpSession::default(); + session.channel_success(channel_id); + + tokio::spawn(async move { + russh_sftp::server::run(channel.into_stream(), sftp).await; + }); + + Ok(true) } } diff --git a/russh/examples/test.rs b/russh/examples/test.rs index be5effc..58c4fd5 100644 --- a/russh/examples/test.rs +++ b/russh/examples/test.rs @@ -69,14 +69,14 @@ impl server::Handler for Server { &mut self, channel: ChannelId, session: &mut Session, - ) -> Result<(), Self::Error> { - session.request_success(); - Ok(()) + ) -> Result { + Ok(true) } async fn auth_publickey(&mut self, _: &str, _: &key::PublicKey) -> Result { Ok(server::Auth::Accept) } + async fn data( &mut self, _channel: ChannelId, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index 9f79061..ea593fa 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -822,7 +822,7 @@ impl Session { debug!("handler.pty_request {:?}", channel_num); #[allow(clippy::indexing_slicing)] // `modes` length checked - handler + let response = handler .pty_request( channel_num, term, @@ -833,7 +833,13 @@ impl Session { &modes[0..i], self, ) - .await + .await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"x11-req" => { let single_connection = r.read_byte().map_err(crate::Error::from)? != 0; @@ -855,7 +861,7 @@ impl Session { }); } debug!("handler.x11_request {:?}", channel_num); - handler + let response = handler .x11_request( channel_num, single_connection, @@ -864,7 +870,13 @@ impl Session { x11_screen_number, self, ) - .await + .await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"env" => { let env_variable = @@ -883,23 +895,34 @@ impl Session { } debug!("handler.env_request {:?}", channel_num); - handler + let response = handler .env_request(channel_num, env_variable, env_value, self) - .await + .await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"shell" => { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::RequestShell { want_reply: true }); } debug!("handler.shell_request {:?}", channel_num); - handler.shell_request(channel_num, self).await + let response = handler.shell_request(channel_num, self).await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"auth-agent-req@openssh.com" => { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); } debug!("handler.agent_request {:?}", channel_num); - let response = handler.agent_request(channel_num, self).await?; if response { self.request_success() @@ -917,7 +940,13 @@ impl Session { }); } debug!("handler.exec_request {:?}", channel_num); - handler.exec_request(channel_num, req, self).await + let response = handler.exec_request(channel_num, req, self).await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"subsystem" => { let name = @@ -931,7 +960,13 @@ impl Session { }); } debug!("handler.subsystem_request {:?}", channel_num); - handler.subsystem_request(channel_num, name, self).await + let response = handler.subsystem_request(channel_num, name, self).await?; + if response { + self.channel_success(channel_num); + } else { + self.channel_failure(channel_num); + } + Ok(()) } b"window-change" => { let col_width = r.read_u32().map_err(crate::Error::from)?; diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index a5e2150..0656de8 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -397,8 +397,8 @@ pub trait Handler: Sized { pix_height: u32, modes: &[(Pty, u32)], session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client requests an X11 connection. @@ -411,8 +411,8 @@ pub trait Handler: Sized { x11_auth_cookie: &str, x11_screen_number: u32, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client wants to set the given environment variable. Check @@ -425,8 +425,8 @@ pub trait Handler: Sized { variable_name: &str, variable_value: &str, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client requests a shell. @@ -435,8 +435,8 @@ pub trait Handler: Sized { &mut self, channel: ChannelId, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client sends a command to execute, to be passed to a @@ -447,8 +447,8 @@ pub trait Handler: Sized { channel: ChannelId, data: &[u8], session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client asks to start the subsystem with the given name @@ -459,8 +459,8 @@ pub trait Handler: Sized { channel: ChannelId, name: &str, session: &mut Session, - ) -> Result<(), Self::Error> { - Ok(()) + ) -> Result { + Ok(false) } /// The client's pseudo-terminal window size has changed. diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 26d8da6..2943985 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -780,7 +780,8 @@ impl Session { } } - /// Send a "failure" reply to a global request. + /// Send a "failure" reply to a channel request. Always call this function + /// if the request failed (it checks whether the client expects an answer). pub fn channel_failure(&mut self, channel: ChannelId) { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel) { @@ -825,9 +826,8 @@ impl Session { self.common.byte(channel, msg::CHANNEL_EOF); } - /// Send data to a channel. On session channels, `extended` can be - /// used to encode standard error by passing `Some(1)`, and stdout - /// by passing `None`. + /// Send data to a channel. On session channels, this generally + /// refers to stdout. /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned. @@ -840,8 +840,7 @@ impl Session { } /// Send data to a channel. On session channels, `extended` can be - /// used to encode standard error by passing `Some(1)`, and stdout - /// by passing `None`. + /// used to encode standard error by passing `1`. /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned.