Skip to content

Commit

Permalink
Add some basic supporting elements
Browse files Browse the repository at this point in the history
Added in timeouts for establishing connections. Will need to look into
ways to implement timeouts for command executions.
Added SSHResult and AuthenticationError classes to the hussh module.
  • Loading branch information
JacobCallahan committed Mar 26, 2024
1 parent ac3383a commit 6216c32
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "hussh"
version = "0.1.0"
version = "0.1.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
86 changes: 55 additions & 31 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,21 @@
//! ```
//!
//! Note: The `read` method sends an EOF to the shell, so you won't be able to send more commands after calling `read`. If you want to send more commands, you would need to create a new `InteractiveShell` instance.
use pyo3::create_exception;
use pyo3::exceptions::PyTimeoutError;
use pyo3::prelude::*;
use ssh2::{Channel, Session};
use std::io::prelude::*;
use std::io::{BufReader, BufWriter, Read, Write};
use std::net::TcpStream;
use std::path::Path;
// use ssh2::FileStat;

const MAX_BUFF_SIZE: usize = 65536;
create_exception!(
connection,
AuthenticationError,
pyo3::exceptions::PyException
);

fn read_from_channel(channel: &mut Channel) -> SSHResult {
let mut stdout = String::new();
Expand Down Expand Up @@ -197,6 +204,7 @@ pub struct Connection {
private_key: String,
#[pyo3(get)]
timeout: u32,
sftp_conn: Option<ssh2::Sftp>,
}

#[pymethods]
Expand All @@ -209,18 +217,21 @@ impl Connection {
password: Option<String>,
private_key: Option<String>,
timeout: Option<u32>,
) -> Self {
) -> PyResult<Connection> {
// if port isn't set, use the default ssh port 22
let port = port.unwrap_or(22);
// combine the host and port into a single string
let conn_str = format!("{}:{}", host, port);
let tcp_conn = TcpStream::connect(&conn_str).unwrap();
let tcp_conn = TcpStream::connect(&conn_str)
.map_err(|e| PyErr::new::<PyTimeoutError, _>(format!("{}", e)))?;
let mut session = Session::new().unwrap();
// if a timeout is set, use it
let timeout = timeout.unwrap_or(0);
session.set_timeout(timeout);
session.set_tcp_stream(tcp_conn);
session.handshake().unwrap();
session
.handshake()
.map_err(|e| PyErr::new::<PyTimeoutError, _>(format!("{}", e)))?;
// if username isn't set, try using root
let username = username.unwrap_or("root".to_string());
let password = password.unwrap_or("".to_string());
Expand All @@ -231,41 +242,44 @@ impl Connection {
if password != "" {
session
.userauth_pubkey_file(&username, None, Path::new(&private_key), Some(&password))
.unwrap();
.map_err(|e| PyErr::new::<AuthenticationError, _>(format!("{}", e)))?;
} else {
// otherwise, try using the private key without a passphrase
session
.userauth_pubkey_file(&username, None, Path::new(&private_key), None)
.unwrap();
.map_err(|e| PyErr::new::<AuthenticationError, _>(format!("{}", e)))?;
}
} else if password != "" {
session.userauth_password(&username, &password).unwrap();
session
.userauth_password(&username, &password)
.map_err(|e| PyErr::new::<AuthenticationError, _>(format!("{}", e)))?;
} else {
// if password isn't set, try using the default ssh-agent
if session.userauth_agent(&username).is_err() {
panic!("Failed to authenticate with ssh-agent");
return Err(PyErr::new::<AuthenticationError, _>(
"Failed to authenticate with ssh-agent",
));
}
}
Connection {
Ok(Connection {
session,
port,
host,
username,
password,
private_key,
timeout,
}
sftp_conn: None,
})
}

/// Executes a command over the SSH connection and returns the result.
fn execute(&self, command: String, timeout: Option<u32>) -> SSHResult {
if let Some(timeout) = timeout {
self.session.set_timeout(timeout); // set the timeout to the provided value
}
fn execute(&self, command: String) -> PyResult<SSHResult> {
let mut channel = self.session.channel_session().unwrap();
channel.exec(&command).unwrap();
self.session.set_timeout(self.timeout); // reset the timeout to the default
read_from_channel(&mut channel)
if let Err(e) = channel.exec(&command) {
return Err(PyErr::new::<PyTimeoutError, _>(format!("{}", e)));
}
Ok(read_from_channel(&mut channel))
}

/// Reads a file over SCP and returns the contents.
Expand Down Expand Up @@ -351,41 +365,51 @@ impl Connection {
/// Reads a file over SFTP and returns the contents.
/// If `local_path` is provided, the file is saved to the local system.
/// Otherwise, the contents of the file are returned as a string.
fn sftp_read(&self, remote_path: String, local_path: Option<String>) -> PyResult<String> {
let mut remote_file = self
.session
.sftp()
.unwrap()
.open(Path::new(&remote_path))
.unwrap();
fn sftp_read(&mut self, remote_path: String, local_path: Option<String>) -> PyResult<String> {
if self.sftp_conn.is_none() {
self.sftp_conn = Some(self.session.sftp().unwrap());
}
let mut remote_file = BufReader::new(
self.sftp_conn
.as_ref()
.unwrap()
.open(Path::new(&remote_path))
.unwrap(),
);
match local_path {
Some(local_path) => {
let mut local_file = std::fs::File::create(local_path).unwrap();
let local_file = std::fs::File::create(local_path)?;
let mut writer = BufWriter::new(local_file);
let mut buffer = vec![0; MAX_BUFF_SIZE];
loop {
let len = remote_file.read(&mut buffer).unwrap();
let len = remote_file.read(&mut buffer)?;
if len == 0 {
break;
}
local_file.write_all(&buffer[..len]).unwrap();
writer.write_all(&buffer[..len])?;
}
writer.flush()?;
Ok("Ok".to_string())
}
None => {
let mut contents = String::new();
remote_file.read_to_string(&mut contents).unwrap();
remote_file.read_to_string(&mut contents)?;
Ok(contents)
}
}
}

/// Writes a file over SFTP.
fn sftp_write(&self, local_path: String, remote_path: String) -> PyResult<()> {
fn sftp_write(&mut self, local_path: String, remote_path: String) -> PyResult<()> {
let mut local_file = std::fs::File::open(&local_path).unwrap();
let metadata = local_file.metadata().unwrap();
// If we don't already have an SFTP connection, create one
if self.sftp_conn.is_none() {
self.sftp_conn = Some(self.session.sftp().unwrap());
}
let mut remote_file = self
.session
.sftp()
.sftp_conn
.as_ref()
.unwrap()
.create(Path::new(&remote_path))
.unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use connection::AuthenticationError;
use pyo3::prelude::*;

mod connection;
Expand All @@ -6,5 +7,8 @@ mod connection;
#[pymodule]
fn hussh(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<connection::Connection>()?; // Add the Connection class
m.add_class::<connection::SSHResult>()?;
// m.add_class::<connection::InteractiveShell>()?;
m.add("AuthenticationError", _py.get_type::<AuthenticationError>())?;
Ok(())
}
14 changes: 3 additions & 11 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from pathlib import Path
from hussh import Connection
from hussh import Connection, SSHResult


TEXT_FILE = Path("tests/data/hp.txt").resolve()
Expand Down Expand Up @@ -44,6 +44,7 @@ def test_agent_auth(setup_agent_auth):
def test_basic_command(conn):
"""Test that we can run a basic command."""
result = conn.execute("echo hello")
assert isinstance(result, SSHResult)
assert result.status == 0
assert result.stdout == "hello\n"

Expand Down Expand Up @@ -147,16 +148,7 @@ def test_shell_context(conn):
assert sh.exit_result.status != 0


@pytest.mark.skip("Skipping until exceptions are implemented.")
def test_connection_timeout():
"""Test that we can trigger a timeout on connect."""
with pytest.raises(TimeoutError):
Connection(host="localhost", port=8022, password="toor", timeout=2000)


@pytest.mark.skip("Skipping until exceptions are implemented.")
def test_execute_timeout():
"""Test that we can trigger a timeout on execute."""
conn = Connection(host="localhost", port=8022, password="toor")
with pytest.raises(TimeoutError):
conn.execute("sleep 3", timeout=2000)
Connection(host="localhost", port=8022, password="toor", timeout=10)

0 comments on commit 6216c32

Please sign in to comment.