Skip to content

Commit

Permalink
Test policy
Browse files Browse the repository at this point in the history
  • Loading branch information
igorski-r7 committed Jan 22, 2025
1 parent 38a4cbe commit d9cb702
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 38 deletions.
4 changes: 3 additions & 1 deletion plugins/ssh/komand_ssh/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import paramiko
from typing import Dict, Any
from komand_ssh.util.strategies import ConnectUsingPasswordStrategy, ConnectUsingRSAKeyStrategy
from komand_ssh.util.policies import CustomMissingKeyPolicy


class Connection(insightconnect_plugin_runtime.Connection):
Expand All @@ -29,9 +30,10 @@ def connect(self, params={}) -> None:
self.use_key = params.get(Input.USE_KEY, False)
self.key = params.get(Input.KEY, {}).get("secretKey", "")
self.ssh_client = paramiko.SSHClient()
self.ssh_client.set_missing_host_key_policy(CustomMissingKeyPolicy())

def client(self, host: str = None) -> paramiko.SSHClient:
# Update host only if different from host in connection
# Update host only if entered and different from host in connection
if host and host != self.host:
self.host = host

Expand Down
5 changes: 5 additions & 0 deletions plugins/ssh/komand_ssh/util/policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from paramiko import MissingHostKeyPolicy, SSHClient, PKey

class CustomMissingKeyPolicy(MissingHostKeyPolicy):
def missing_host_key(self, client: SSHClient, hostname: str, key: PKey) -> None:
client.get_host_keys().add(hostname, key.get_name(), key)
7 changes: 3 additions & 4 deletions plugins/ssh/komand_ssh/util/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .constants import DEFAULT_ENCODING, DEFAULT_SLEEP_TIME, DEFAULT_SSH_KEY_ALGORITHMS



class SSHConnectionStrategy(ABC):
def __init__(self, client: SSHClient, logger: Logger) -> None:
self.client = client
Expand Down Expand Up @@ -58,17 +59,15 @@ def _extend_host_keys(

class ConnectUsingPasswordStrategy(SSHConnectionStrategy):
def connect(self, host: str, port: int, username: str, password: str, key: str = None) -> SSHClient:
self.logger.info("Connecting to the SSH server via password...")
self._extend_host_keys(host, port, self.client)
self.logger.info("Connecting to SSH server via password...")
self.client.connect(host, port, username, password)
return self.client


class ConnectUsingRSAKeyStrategy(SSHConnectionStrategy):
def connect(self, host: str, port: int, username: str, password: str, key: str = None) -> SSHClient:
self.logger.info("Connecting to the SSH server via RSA key...")
self.logger.info("Connecting to SSH server via RSA key...")
key = b64decode(key).decode(DEFAULT_ENCODING)
rsa_key = RSAKey.from_private_key(StringIO(key), password=password)
self._extend_host_keys(host, port, self.client)
self.client.connect(host, port, username, password, pkey=rsa_key)
return self.client
File renamed without changes.
33 changes: 13 additions & 20 deletions plugins/ssh/unit_test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,25 @@
sys.path.append(os.path.abspath("../"))

from unittest import TestCase
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from komand_ssh.actions.run import Run
from komand_ssh.actions.run.schema import Input, Output

from util import Util

from komand_ssh.actions.run import Run
from komand_ssh.actions.run.schema import Output
STUB_PARAMETERS = {Input.HOST: "example.com", Input.COMMAND: "ls -l"}


class TestRun(TestCase):
def setUp(self):
self.action = Run()
self.params = {
"host": "example.com",
"command": "ls -l",
}
self.action.connection = Util.default_connector()

def mock_execute_command(self):
file1 = open("./ssh/unit_test/results", "r")
return file1, file1, file1

@patch("paramiko.SSHClient.set_missing_host_key_policy", return_value=None)
@patch("paramiko.SSHClient.load_system_host_keys", return_value=None)
self.action = Util.default_connector(Run())

@patch("paramiko.SSHClient.connect", return_value=None)
@patch("paramiko.SSHClient.exec_command", side_effect=mock_execute_command)
def test_run(self, mock_key_policy, mock_host_keys, mock_connect, mock_exec):
@patch("paramiko.SSHClient.exec_command", side_effect=Util.mock_execute_command)
def test_run(self, mock_connect: MagicMock, mock_exec: MagicMock) -> None:
response = self.action.run(STUB_PARAMETERS)
expected = {Output.RESULTS: {"stdout": "/home/vagrant", "stderr": "", "all_output": "/home/vagrant"}}
actual = self.action.run(self.params)
self.assertEqual(actual, expected)
self.assertEqual(response, expected)
mock_connect.assert_called()
mock_exec.assert_called()
38 changes: 25 additions & 13 deletions plugins/ssh/unit_test/util.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
import logging
import os
import sys
from typing import TextIO, Tuple

sys.path.append(os.path.abspath("../"))

from pathlib import Path

from insightconnect_plugin_runtime.action import Action
from komand_ssh.connection.connection import Connection
from komand_ssh.connection.schema import Input

STUB_CONNECTION = {
Input.HOST: "0.0.0.0",
Input.PORT: "22",
Input.KEY: {},
Input.USE_KEY: False,
Input.PASSWORD: {"secretKey": "ABC"},
Input.USERNAME: "username",
}


class Util:
@staticmethod
def default_connector():
connection = Connection()
params = {
Input.HOST: "0.0.0.0",
Input.PORT: "22",
Input.KEY: {},
Input.USE_KEY: False,
Input.PASSWORD: {"secretKey": "ABC"},
Input.USERNAME: "username",
}
connection.logger = logging.getLogger("action logger")
connection.parameters = params
return connection
def default_connector(action: Action) -> Action:
default_connection = Connection()
default_connection.logger = logging.getLogger("connection logger")
default_connection.connect(STUB_CONNECTION)
action.connection = default_connection
action.logger = logging.getLogger("action logger")
return action

@staticmethod
def mock_execute_command() -> Tuple[TextIO, TextIO, TextIO]:
file_ = open(Path(__file__).parent / "responses" / "results.txt", "r")
return file_, file_, file_

0 comments on commit d9cb702

Please sign in to comment.