Skip to content

Commit

Permalink
Update relevant channel request callbacks to return a bool
Browse files Browse the repository at this point in the history
This is a breaking change, but it tweaks channel request callbacks to
return a bool rather than requiring the user to manually call
`session.channel_success` or `session.channel_failure`.

This has the added advantage of changing the defaults of a number of
request callbacks to more-secure defaults (deny), and makes it
impossible for a user to miss responding to callbacks which require
responses.

Note that this does *not* handle sending responses for all requests,
only channel requests listed in RFC4254 as having a "want reply" param
rather than just "false", even though it may be more correct to respond
to malformed requests which have improperly set that byte to "true" even
though the RFC specifies "false".
  • Loading branch information
belak committed Sep 19, 2024
1 parent 451e74b commit 144f58a
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 40 deletions.
21 changes: 12 additions & 9 deletions russh/examples/sftp_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,22 @@ impl russh::server::Handler for SshSession {
channel_id: ChannelId,
name: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
) -> Result<bool, Self::Error> {
info!("subsystem: {}", name);

if name == "sftp" {
let channel = self.get_channel(channel_id).await;
let sftp = SftpSession::default();
session.channel_success(channel_id);
russh_sftp::server::run(channel.into_stream(), sftp).await;
} else {
session.channel_failure(channel_id);
if name != "sftp" {
return Ok(false);
}

Ok(())
let channel = self.get_channel(channel_id).await;
let sftp = SftpSession::default();
session.channel_success(channel_id);

tokio::spawn(async move {
russh_sftp::server::run(channel.into_stream(), sftp).await;
});

Ok(true)
}
}

Expand Down
6 changes: 3 additions & 3 deletions russh/examples/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ impl server::Handler for Server {
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
session.request_success();
Ok(())
) -> Result<bool, Self::Error> {
Ok(true)
}

async fn auth_publickey(&mut self, _: &str, _: &key::PublicKey) -> Result<Auth, Self::Error> {
Ok(server::Auth::Accept)
}

async fn data(
&mut self,
_channel: ChannelId,
Expand Down
55 changes: 45 additions & 10 deletions russh/src/server/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ impl Session {

debug!("handler.pty_request {:?}", channel_num);
#[allow(clippy::indexing_slicing)] // `modes` length checked
handler
let response = handler
.pty_request(
channel_num,
term,
Expand All @@ -833,7 +833,13 @@ impl Session {
&modes[0..i],
self,
)
.await
.await?;
if response {
self.channel_success(channel_num);
} else {
self.channel_failure(channel_num);
}
Ok(())
}
b"x11-req" => {
let single_connection = r.read_byte().map_err(crate::Error::from)? != 0;
Expand All @@ -855,7 +861,7 @@ impl Session {
});
}
debug!("handler.x11_request {:?}", channel_num);
handler
let response = handler
.x11_request(
channel_num,
single_connection,
Expand All @@ -864,7 +870,13 @@ impl Session {
x11_screen_number,
self,
)
.await
.await?;
if response {
self.channel_success(channel_num);
} else {
self.channel_failure(channel_num);
}
Ok(())
}
b"env" => {
let env_variable =
Expand All @@ -883,23 +895,34 @@ impl Session {
}

debug!("handler.env_request {:?}", channel_num);
handler
let response = handler
.env_request(channel_num, env_variable, env_value, self)
.await
.await?;
if response {
self.channel_success(channel_num);
} else {
self.channel_failure(channel_num);
}
Ok(())
}
b"shell" => {
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::RequestShell { want_reply: true });
}
debug!("handler.shell_request {:?}", channel_num);
handler.shell_request(channel_num, self).await
let response = handler.shell_request(channel_num, self).await?;
if response {
self.channel_success(channel_num);
} else {
self.channel_failure(channel_num);
}
Ok(())
}
b"[email protected]" => {
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::AgentForward { want_reply: true });
}
debug!("handler.agent_request {:?}", channel_num);

let response = handler.agent_request(channel_num, self).await?;
if response {
self.request_success()
Expand All @@ -917,7 +940,13 @@ impl Session {
});
}
debug!("handler.exec_request {:?}", channel_num);
handler.exec_request(channel_num, req, self).await
let response = handler.exec_request(channel_num, req, self).await?;
if response {
self.channel_success(channel_num);
} else {
self.channel_failure(channel_num);
}
Ok(())
}
b"subsystem" => {
let name =
Expand All @@ -931,7 +960,13 @@ impl Session {
});
}
debug!("handler.subsystem_request {:?}", channel_num);
handler.subsystem_request(channel_num, name, self).await
let response = handler.subsystem_request(channel_num, name, self).await?;
if response {
self.channel_success(channel_num);
} else {
self.channel_failure(channel_num);
}
Ok(())
}
b"window-change" => {
let col_width = r.read_u32().map_err(crate::Error::from)?;
Expand Down
24 changes: 12 additions & 12 deletions russh/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,8 @@ pub trait Handler: Sized {
pix_height: u32,
modes: &[(Pty, u32)],
session: &mut Session,
) -> Result<(), Self::Error> {
Ok(())
) -> Result<bool, Self::Error> {
Ok(false)
}

/// The client requests an X11 connection.
Expand All @@ -411,8 +411,8 @@ pub trait Handler: Sized {
x11_auth_cookie: &str,
x11_screen_number: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
Ok(())
) -> Result<bool, Self::Error> {
Ok(false)
}

/// The client wants to set the given environment variable. Check
Expand All @@ -425,8 +425,8 @@ pub trait Handler: Sized {
variable_name: &str,
variable_value: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
Ok(())
) -> Result<bool, Self::Error> {
Ok(false)
}

/// The client requests a shell.
Expand All @@ -435,8 +435,8 @@ pub trait Handler: Sized {
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
Ok(())
) -> Result<bool, Self::Error> {
Ok(false)
}

/// The client sends a command to execute, to be passed to a
Expand All @@ -447,8 +447,8 @@ pub trait Handler: Sized {
channel: ChannelId,
data: &[u8],
session: &mut Session,
) -> Result<(), Self::Error> {
Ok(())
) -> Result<bool, Self::Error> {
Ok(false)
}

/// The client asks to start the subsystem with the given name
Expand All @@ -459,8 +459,8 @@ pub trait Handler: Sized {
channel: ChannelId,
name: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
Ok(())
) -> Result<bool, Self::Error> {
Ok(false)
}

/// The client's pseudo-terminal window size has changed.
Expand Down
11 changes: 5 additions & 6 deletions russh/src/server/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,8 @@ impl Session {
}
}

/// Send a "failure" reply to a global request.
/// Send a "failure" reply to a channel request. Always call this function
/// if the request failed (it checks whether the client expects an answer).
pub fn channel_failure(&mut self, channel: ChannelId) {
if let Some(ref mut enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get_mut(&channel) {
Expand Down Expand Up @@ -825,9 +826,8 @@ impl Session {
self.common.byte(channel, msg::CHANNEL_EOF);
}

/// Send data to a channel. On session channels, `extended` can be
/// used to encode standard error by passing `Some(1)`, and stdout
/// by passing `None`.
/// Send data to a channel. On session channels, this generally
/// refers to stdout.
///
/// The number of bytes added to the "sending pipeline" (to be
/// processed by the event loop) is returned.
Expand All @@ -840,8 +840,7 @@ impl Session {
}

/// Send data to a channel. On session channels, `extended` can be
/// used to encode standard error by passing `Some(1)`, and stdout
/// by passing `None`.
/// used to encode standard error by passing `1`.
///
/// The number of bytes added to the "sending pipeline" (to be
/// processed by the event loop) is returned.
Expand Down

0 comments on commit 144f58a

Please sign in to comment.