From caba5e5a650d19883fb6a1568a66798b208418d3 Mon Sep 17 00:00:00 2001 From: mlataza <1750947+mlataza@users.noreply.github.com> Date: Fri, 17 Jan 2025 04:11:20 +0800 Subject: [PATCH] add websocket connection as backup --- config_module/fetch_config.py | 229 ++++---- config_module/host_info.py | 400 ++++++------- iot_hub_module/connection_management.py | 735 ++++++++++++------------ tests/config_module/test_host_info.py | 4 +- 4 files changed, 697 insertions(+), 671 deletions(-) diff --git a/config_module/fetch_config.py b/config_module/fetch_config.py index ac3c8e7..94cde19 100644 --- a/config_module/fetch_config.py +++ b/config_module/fetch_config.py @@ -1,111 +1,118 @@ -""" Module for fetching configuration """ - -from typing import Dict, Any, List, Tuple - -import logging -import asyncio -import httpx -from .host_info import build_host_tags - -# Put Timestamps on logging entries -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) - -REQUIRED_KEYS = [ - "azure_iot_hub_host", - "device_id", - "shared_access_key", - "rewst_engine_host", - "rewst_org_id", -] - - -async def fetch_configuration( - config_url: str, - secret: str = None, - org_id: str = None, - retry_intervals: Tuple[Tuple[int, int | float]] = ( - (5, 12), - (60, 60), - (300, float("inf")), - ), -) -> Dict[str, Any] | None: - """ - Fetch configuration from configuration link. - - Args: - config_url (str): Configuration url. - secret (str, optional): Client secret with the Rewst platform. Defaults to None. - org_id (str, optional): Organization identifier in Rewst platform. Defaults to None. - retry_intervals (Tuple[Tuple[int, int | float]], optional): List of tuples of intervals - and maximum retries to do in case the response is an error. Defaults to - ((5, 12), (60, 60), (300, float("inf"))). - Returns: - Dict[str, Any]|None: Configuration data if successful, otherwise None. - """ - - host_info = build_host_tags(org_id) - - headers = {} - if secret: - headers["x-rewst-secret"] = secret - - logging.debug(f"Sending host information to {config_url}: {str(host_info)}") - - for interval, max_retries in retry_intervals: - retries = 0 - while retries < max_retries: - retries += 1 - async with httpx.AsyncClient( - timeout=None - ) as client: # Set timeout to None to wait indefinitely - try: - response = await client.post( - config_url, - json=host_info, - headers=headers, - follow_redirects=True, - ) - except httpx.TimeoutException: - logging.warning( - f"Attempt {retries}: Request timed out. Retrying..." - ) - await asyncio.sleep(interval) - continue # Skip the rest of the loop and retry - - except httpx.RequestError as e: - logging.warning( - f"Attempt {retries}: Network error: {e}. Retrying..." - ) - await asyncio.sleep(interval) - continue - - if response.status_code == 303: - logging.info( - "Waiting while Rewst processes Agent Registration..." - ) # Custom message for 303 - elif response.status_code == 200: - data = response.json() - config_data = data.get("configuration") - if config_data and all(key in config_data for key in REQUIRED_KEYS): - return config_data - else: - logging.warning( - f"Attempt {retries}: Missing required keys in configuration data. Retrying..." - ) - elif response.status_code == 400 or response.status_code == 401: - logging.error( - f"Attempt {retries}: Not authorized. Check your config secret." - ) - else: - logging.warning( - f"Attempt {retries}: Received status code {response.status_code}. Retrying..." - ) - - logging.info(f"Attempt {retries}: Waiting {interval}s before retrying...") - await asyncio.sleep(interval) - - logging.info("This process will end when the service is installed.") +""" Module for fetching configuration """ + +from typing import Dict, Any, Tuple + +import logging +import asyncio +import httpx +from .host_info import build_host_tags + +# Put Timestamps on logging entries +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +REQUIRED_KEYS = [ + "azure_iot_hub_host", + "device_id", + "shared_access_key", + "rewst_engine_host", + "rewst_org_id", +] + + +async def fetch_configuration( + config_url: str, + secret: str = None, + org_id: str = None, + retry_intervals: Tuple[Tuple[int, int | float]] = ( + (5, 12), + (60, 60), + (300, float("inf")), + ), +) -> Dict[str, Any] | None: + """ + Fetch configuration from configuration link. + + Args: + config_url (str): Configuration url. + secret (str, optional): Client secret with the Rewst platform. Defaults to None. + org_id (str, optional): Organization identifier in Rewst platform. Defaults to None. + retry_intervals (Tuple[Tuple[int, int | float]], optional): List of tuples of intervals + and maximum retries to do in case the response is an error. Defaults to + ((5, 12), (60, 60), (300, float("inf"))). + Returns: + Dict[str, Any]|None: Configuration data if successful, otherwise None. + """ + + host_info = build_host_tags(org_id) + + headers = {} + if secret: + headers["x-rewst-secret"] = secret + + logging.info("Sending host information to %s: %s", config_url, host_info) + + for interval, max_retries in retry_intervals: + retries = 0 + while retries < max_retries: + retries += 1 + async with httpx.AsyncClient( + timeout=None + ) as client: # Set timeout to None to wait indefinitely + try: + response = await client.post( + config_url, + json=host_info, + headers=headers, + follow_redirects=True, + ) + except httpx.TimeoutException: + logging.warning( + "Attempt %d: Request timed out. Retrying...", retries + ) + await asyncio.sleep(interval) + continue # Skip the rest of the loop and retry + + except httpx.RequestError as e: + logging.warning( + "Attempt %d: Network error: %s. Retrying...", retries, e + ) + await asyncio.sleep(interval) + continue + + if response.status_code == 303: + logging.info( + "Waiting while Rewst processes Agent Registration..." + ) # Custom message for 303 + elif response.status_code == 200: + data = response.json() + config_data = data.get("configuration") + logging.info(config_data) + + if config_data and all(key in config_data for key in REQUIRED_KEYS): + return config_data + + logging.warning( + "Attempt %d: Missing required keys in configuration data. Retrying...", + retries, + ) + elif response.status_code in (400, 401): + logging.error( + "Attempt %d: Not authorized. Check your config secret.", retries + ) + else: + logging.warning( + "Attempt %d: Received status code %d. Retrying...", + retries, + response.status_code, + ) + + logging.info( + "Attempt %d: Waiting %ds before retrying...", retries, interval + ) + await asyncio.sleep(interval) + + logging.info("This process will end when the service is installed.") diff --git a/config_module/host_info.py b/config_module/host_info.py index 7d9f261..d0faff2 100644 --- a/config_module/host_info.py +++ b/config_module/host_info.py @@ -1,197 +1,203 @@ -""" Module to get host info details """ - -from typing import Dict, Any - -import platform -import uuid -import socket -import subprocess -import sys -import logging -import psutil -import __version__ -from config_module.config_io import ( - get_service_executable_path, - get_agent_executable_path, -) - - -def get_mac_address() -> str: - """ - Get MAC Address of the host machine. - - Returns: - str: MAC address in hexadecimal format iwhtout the columns. - """ - # Returns the MAC address of the host without colons - mac_num = hex(uuid.UUID(int=uuid.getnode()).int)[2:] - mac_address = ":".join(mac_num[i : i + 2] for i in range(0, 11, 2)) - return mac_address.replace(":", "") - - -def run_powershell_command(powershell_command: str) -> str | None: - """ - Execute a powershell command and return the output - - Args: - powershell_command (str): Powershell command to execute. - - Returns: - str|None: Execution output if successful, otherwise None. - """ - try: - result = subprocess.run( - ["powershell", "-Command", powershell_command], - capture_output=True, - text=True, - check=True, - ) - return result.stdout.strip() - except subprocess.CalledProcessError as e: - print(f"Error executing PowerShell script: {e}", file=sys.stderr) - return None - - -def is_domain_controller() -> bool: - """ - Checks if the current computer is a domain controller. - - Returns: - bool: True if domain controller, otherwise False. - """ - powershell_command = """ - $domainStatus = (Get-WmiObject Win32_ComputerSystem).DomainRole - if ($domainStatus -eq 4 -or $domainStatus -eq 5) { - return $true - } else { - return $false - } - """ - output = run_powershell_command(powershell_command) - logging.info(f"Is domain controller?: {output}") - return "True" in output - - -def get_ad_domain_name() -> str | None: - """ - Gets the Active Directory domain name if the PC is joined to AD. - - Returns: - str|None: Active directory domain name of the host if it exists, otherwise None. - """ - - powershell_command = """ - $domainInfo = (Get-WmiObject Win32_ComputerSystem).Domain - if ($domainInfo -and $domainInfo -ne 'WORKGROUP') { - return $domainInfo - } else { - return $null - } - """ - ad_domain_name = run_powershell_command(powershell_command) - logging.info(f"AD domain name: {ad_domain_name}") - return ad_domain_name - - -def get_entra_domain() -> str | None: - """ - Get the Entra domain of the host machine. - - Returns: - str|None: Entra domain if it exists, otherwise None. - """ - if platform.system().lower() != "windows": - return None - else: - try: - result = subprocess.run( - ["dsregcmd", "/status"], text=True, capture_output=True - ) - output = result.stdout - for line in output.splitlines(): - if "AzureAdJoined" in line and "YES" in line: - for line in output.splitlines(): - if "DomainName" in line: - domain_name = line.split(":")[1].strip() - return domain_name - except Exception as e: - logging.warning(f"Unexpected issue querying for Entra Domain: {str(e)}") - pass # Handle exception if necessary - return None - - -def is_entra_connect_server() -> bool: - """ - Checks whether the host machine is an Entra connect server. - - Returns: - bool: True if it is an Entra connect server, otherwise False. - """ - if platform.system().lower() != "windows": - return False - else: - potential_service_names = [ - "ADSync", - "Azure AD Sync", - "EntraConnectSync", - "OtherFutureName", - ] - for service_name in potential_service_names: - if is_service_running(service_name): - return True - return False - - -def is_service_running(service_name: str) -> bool: - """ - Checks whether the service is running. - - Args: - service_name (str): Service name to check. - - Returns: - bool: True if the service is running, otherwise False. - """ - for service in ( - psutil.win_service_iter() - if platform.system() == "Windows" - else psutil.process_iter(["name"]) - ): - if service.name().lower() == service_name.lower(): - return True - return False - - -def build_host_tags(org_id: str = None) -> Dict[str, Any]: - """ - Build host tags for the organization. - - Args: - org_id (str, optional): Organization identifier in Rewst platform. Defaults to None. - - Returns: - Dict[str, Any]: Host tags detail. - """ - # Collect host information - ad_domain = get_ad_domain_name() - if ad_domain: - is_dc = is_domain_controller() - else: - is_dc = False - - host_info = { - "agent_version": (__version__.__version__ or None), - "agent_executable_path": get_agent_executable_path(org_id), - "service_executable_path": get_service_executable_path(org_id), - "hostname": socket.gethostname(), - "mac_address": get_mac_address(), - "operating_system": platform.platform(), - "cpu_model": platform.processor(), - "ram_gb": round(psutil.virtual_memory().total / (1024**3), 1), - "ad_domain": ad_domain, - "is_ad_domain_controller": is_dc, - "is_entra_connect_server": is_entra_connect_server(), - "entra_domain": get_entra_domain(), - "org_id": org_id, - } - return host_info +""" Module to get host info details """ + +from typing import Dict, Any + +import platform +import uuid +import socket +import subprocess +import sys +import logging +import psutil +import __version__ +from config_module.config_io import ( + get_service_executable_path, + get_agent_executable_path, +) + + +def get_mac_address() -> str: + """ + Get MAC Address of the host machine. + + Returns: + str: MAC address in hexadecimal format iwhtout the columns. + """ + # Use the psutil for hardware mac address + for _, addrs in psutil.net_if_addrs().items(): + for addr in addrs: + if addr.family == psutil.AF_LINK: + return str(addr.address).lower().replace("-", "") + + # Returns the MAC address of the host without colons + mac_num = hex(uuid.UUID(int=uuid.getnode()).int)[2:] + mac_address = ":".join(mac_num[i : i + 2] for i in range(0, 11, 2)) + return mac_address.replace(":", "") + + +def run_powershell_command(powershell_command: str) -> str | None: + """ + Execute a powershell command and return the output + + Args: + powershell_command (str): Powershell command to execute. + + Returns: + str|None: Execution output if successful, otherwise None. + """ + try: + result = subprocess.run( + ["powershell", "-Command", powershell_command], + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() + except subprocess.CalledProcessError as e: + print(f"Error executing PowerShell script: {e}", file=sys.stderr) + return None + + +def is_domain_controller() -> bool: + """ + Checks if the current computer is a domain controller. + + Returns: + bool: True if domain controller, otherwise False. + """ + powershell_command = """ + $domainStatus = (Get-WmiObject Win32_ComputerSystem).DomainRole + if ($domainStatus -eq 4 -or $domainStatus -eq 5) { + return $true + } else { + return $false + } + """ + output = run_powershell_command(powershell_command) + logging.info(f"Is domain controller?: {output}") + return "True" in output + + +def get_ad_domain_name() -> str | None: + """ + Gets the Active Directory domain name if the PC is joined to AD. + + Returns: + str|None: Active directory domain name of the host if it exists, otherwise None. + """ + + powershell_command = """ + $domainInfo = (Get-WmiObject Win32_ComputerSystem).Domain + if ($domainInfo -and $domainInfo -ne 'WORKGROUP') { + return $domainInfo + } else { + return $null + } + """ + ad_domain_name = run_powershell_command(powershell_command) + logging.info(f"AD domain name: {ad_domain_name}") + return ad_domain_name + + +def get_entra_domain() -> str | None: + """ + Get the Entra domain of the host machine. + + Returns: + str|None: Entra domain if it exists, otherwise None. + """ + if platform.system().lower() != "windows": + return None + else: + try: + result = subprocess.run( + ["dsregcmd", "/status"], text=True, capture_output=True + ) + output = result.stdout + for line in output.splitlines(): + if "AzureAdJoined" in line and "YES" in line: + for line in output.splitlines(): + if "DomainName" in line: + domain_name = line.split(":")[1].strip() + return domain_name + except Exception as e: + logging.warning(f"Unexpected issue querying for Entra Domain: {str(e)}") + pass # Handle exception if necessary + return None + + +def is_entra_connect_server() -> bool: + """ + Checks whether the host machine is an Entra connect server. + + Returns: + bool: True if it is an Entra connect server, otherwise False. + """ + if platform.system().lower() != "windows": + return False + else: + potential_service_names = [ + "ADSync", + "Azure AD Sync", + "EntraConnectSync", + "OtherFutureName", + ] + for service_name in potential_service_names: + if is_service_running(service_name): + return True + return False + + +def is_service_running(service_name: str) -> bool: + """ + Checks whether the service is running. + + Args: + service_name (str): Service name to check. + + Returns: + bool: True if the service is running, otherwise False. + """ + for service in ( + psutil.win_service_iter() + if platform.system() == "Windows" + else psutil.process_iter(["name"]) + ): + if service.name().lower() == service_name.lower(): + return True + return False + + +def build_host_tags(org_id: str = None) -> Dict[str, Any]: + """ + Build host tags for the organization. + + Args: + org_id (str, optional): Organization identifier in Rewst platform. Defaults to None. + + Returns: + Dict[str, Any]: Host tags detail. + """ + # Collect host information + ad_domain = get_ad_domain_name() + if ad_domain: + is_dc = is_domain_controller() + else: + is_dc = False + + host_info = { + "agent_version": (__version__.__version__ or None), + "agent_executable_path": get_agent_executable_path(org_id), + "service_executable_path": get_service_executable_path(org_id), + "hostname": socket.gethostname(), + "mac_address": get_mac_address(), + "operating_system": platform.platform(), + "cpu_model": platform.processor(), + "ram_gb": round(psutil.virtual_memory().total / (1024**3), 1), + "ad_domain": ad_domain, + "is_ad_domain_controller": is_dc, + "is_entra_connect_server": is_entra_connect_server(), + "entra_domain": get_entra_domain(), + "org_id": org_id, + } + return host_info diff --git a/iot_hub_module/connection_management.py b/iot_hub_module/connection_management.py index e175392..ad15cc5 100644 --- a/iot_hub_module/connection_management.py +++ b/iot_hub_module/connection_management.py @@ -1,362 +1,373 @@ -""" Module for defining class and functions to manage connections. """ - -from typing import Dict, Any - -import asyncio -import base64 -import json -import os -import subprocess -import logging -import platform -import signal -import tempfile -import httpx - -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device.iothub.models import Message - -from platformdirs import ( - site_config_dir -) -from config_module.config_io import ( - get_config_file_path, - get_agent_executable_path, - get_service_executable_path, - get_service_manager_path -) -from config_module.host_info import build_host_tags -import service_module.service_management - -# Set up logging -logging.basicConfig(level=logging.INFO) - -os_type = platform.system().lower() - - -class ConnectionManager: - """ - Manages the connection between the agent and IoT Hub. - """ - - def __init__(self, config_data: Dict[str, Any]) -> None: - """Construcs a new connection manager instance - - Args: - config_data (Dict[str, Any]): Configuration data of the connection. - """ - self.config_data = config_data - self.connection_string = self.get_connection_string() - self.os_type = platform.system().lower() - self.client = IoTHubDeviceClient.create_from_connection_string( - self.connection_string) - - def get_connection_string(self) -> str: - """ - Get the connection string used to connect to the IoT Hub. - - Returns: - str: Connection string to the IoT Hub. - """ - conn_str = ( - f"HostName={self.config_data['azure_iot_hub_host']};" - f"DeviceId={self.config_data['device_id']};" - f"SharedAccessKey={self.config_data['shared_access_key']}" - ) - return conn_str - - async def connect(self) -> None: - """ - Connect the agent service to the IoT Hub. - """ - try: - await self.client.connect() - except Exception as e: - logging.exception(f"Exception in connection to the IoT Hub: {e}") - - async def disconnect(self) -> None: - """ - Disconnect the agent service from the IoT Hub. - """ - try: - await self.client.disconnect() - except Exception as e: - logging.exception( - f"Exception in disconnecting from the IoT Hub: {e}") - - async def send_message(self, message_data: Dict[str, Any]) -> None: - """ - Send a message to the IoT Hub. - - Args: - message_data (Dict[str, Any]): Message data in JSON format. - """ - message_json = json.dumps(message_data) - await self.client.send_message(message_json) - - async def set_message_handler(self) -> None: - """ - Sets the event handler for income messages from the Iot Hub. - """ - self.client.on_message_received = self.handle_message - - async def execute_commands(self, commands: bytes, post_url: str = None, interpreter_override: str = None) -> Dict[str, str]: - """ - Execute commands on the machine using the specified interpreter and send back result via post_url. - - Args: - commands (str): Base64 encoded list of commands. - post_url (str, optional): Post back URL to send the stdout and stderr results of the commands after execution. Defaults to None. - interpreter_override (str, optional): Interpreter name to use in executing the commands. Defaults to None. - - Returns: - Dict[str, str]: Output message in JSON format sent to the post_url. - """ - interpreter = interpreter_override or self.get_default_interpreter() - logging.info(f"Using interpreter: {interpreter}") - output_message_data = None - - # Write commands to a temporary file - script_suffix = ".ps1" if "powershell" in interpreter.lower() else ".sh" - tmp_dir = None - if os_type == "windows": - config_dir = site_config_dir() - scripts_dir = os.path.join( - config_dir, "\\RewstRemoteAgent\\scripts") - # scripts_dir = "C:\\Scripts" - if not os.path.exists(scripts_dir): - os.makedirs(scripts_dir) - tmp_dir = scripts_dir - with tempfile.NamedTemporaryFile(delete=False, suffix=script_suffix, - mode="w", dir=tmp_dir) as temp_file: - if "powershell" in interpreter.lower(): - # If PowerShell is used, decode the commands - decoded_commands = base64.b64decode( - commands).decode('utf-16-le') - # Ensure TLS 1.2 configuration is set at the beginning of the command - tls_command = "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12" - if tls_command not in decoded_commands: - decoded_commands = tls_command + "\n" + decoded_commands - else: - # For other interpreters, you might want to handle encoding differently - decoded_commands = base64.b64decode(commands).decode('utf-8') - - # logging.info(f"Decoded Commands:\n{decoded_commands}") - temp_file.write(decoded_commands) - temp_file.flush() # Explicitly flush the file buffer - os.fsync(temp_file.fileno()) # Ensures all data is written to disk - temp_file_path = temp_file.name - - logging.info(f"Wrote commands to temp file {temp_file_path}") - - # Construct the shell command to execute the temp file - if "powershell" in interpreter.lower() or "pwsh" in interpreter.lower(): - shell_command = f'{interpreter} -File "{temp_file_path}"' - else: - shell_command = f'{interpreter} "{temp_file_path}"' - - try: - # Execute the command - logging.info(f"Running process via commandline: {shell_command}") - process = subprocess.Popen( - shell_command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, - text=True - ) - stdout, stderr = process.communicate() - exit_code = process.returncode - logging.info(f"Command completed with exit code {exit_code}") - - if exit_code != 0 or stderr: - # Log and print error details - error_message = f"Script execution failed with exit code { - exit_code}. Error: {stderr}" - logging.error(error_message) - print(error_message) # Print to console - output_message_data = { - 'output': stdout, - 'error': error_message - } - else: - output_message_data = { - 'output': stdout, - 'error': '' - } - - except subprocess.CalledProcessError as e: - logging.error( - f"Command '{shell_command}' failed with error code {e.returncode}") - logging.error(f"Error output: {e.output}") - output_message_data = { - 'output': '', - 'error': f"Command failed with error code {e.returncode}: {e.output}" - } - - except Exception as e: - logging.error(f"An unexpected error occurred: {e}") - output_message_data = { - 'output': '', - 'error': f"An unexpected error occurred: {e}" - } - - finally: - # Loop to wait until the file can be deleted - while True: - try: - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - break # If successful, break out of the loop - except PermissionError: - await asyncio.sleep(1) - except Exception as e: - logging.error(f"Error deleting temporary file: {e}") - break # If a different error occurs, break out of the loop - - if post_url and output_message_data: - logging.info("Sending Results to Rewst via httpx.") - async with httpx.AsyncClient() as client: - response = await client.post(post_url, json=output_message_data) - logging.info(f"POST request status: {response.status_code}") - if response.status_code != 200: - logging.error(f"Error response: {response.text}") - - return output_message_data - - async def handle_message(self, message: Message) -> None: - """Handle incoming message event from the IoT Hub. - - Args: - message (Message): Message instance from the IoT Hub. - """ - logging.info("Received IoT Hub message in handle_message.") - try: - message_data = json.loads(message.data) - get_installation_info = message_data.get("get_installation") - commands = message_data.get("commands") - post_id = message_data.get("post_id") - interpreter_override = message_data.get("interpreter_override") - - if post_id: - post_path = post_id.replace(":", "/") - rewst_engine_host = self.config_data["rewst_engine_host"] - post_url = f"https://{ - rewst_engine_host}/webhooks/custom/action/{post_path}" - logging.info(f"Will POST results to {post_url}") - else: - post_url = None - - if commands: - logging.info("Received commands in message") - try: - await self.execute_commands(commands, post_url, interpreter_override) - except Exception as e: - logging.exception(f"Exception running commands: {e}") - - if get_installation_info: - logging.info("Received request for installation paths") - try: - await self.get_installation(post_url) - except Exception as e: - logging.exception( - f"Exception getting installation info: {e}") - except json.JSONDecodeError as e: - logging.error(f"Error decoding message data as JSON: {e}") - except Exception as e: - logging.exception(f"An unexpected error occurred: {e}") - - async def get_installation(self, post_url: str) -> None: - """Send installation data of the service to the Rewst platform. The post_url - is an ephemeral link generated by the Rewst platform. - - Args: - post_url (str): Post back link to send the installation data to. - """ - org_id = self.config_data['rewst_org_id'] - service_executable_path = get_service_executable_path(org_id) - agent_executable_path = get_agent_executable_path(org_id) - service_manager_path = get_service_manager_path(org_id) - config_file_path = get_config_file_path(org_id) - - paths_data = { - "service_executable_path": service_executable_path, - "agent_executable_path": agent_executable_path, - "config_file_path": config_file_path, - "service_manager_path": service_manager_path, - "tags": build_host_tags(org_id) - } - - try: - async with httpx.AsyncClient() as client: - response = await client.post(post_url, json=paths_data) - response.raise_for_status() - - except httpx.RequestError as e: - logging.error(f"Request to {post_url} failed: {e}") - except httpx.HTTPStatusError as e: - logging.error(f"Error response {e.response.status_code} while posting to { - post_url}: {e.response.text}") - except Exception as e: - logging.error(f"An unexpected error occurred while posting to { - post_url}: {e}") - - def get_default_interpreter(self) -> str: - """Get the default interpreter depending on the platform's OS type. - - Returns: - str: Interpreter executable path. - """ - if self.os_type == 'windows': - return 'powershell' - elif self.os_type == 'darwin': - return '/bin/zsh' - else: - return '/bin/bash' - - -async def iot_hub_connection_loop(config_data: Dict[str, Any], stop_event: asyncio.Event = asyncio.Event()) -> None: - """Connect to the IoT Hub and wait for a stop event to close the loop. - - Args: - config_data (Dict[str, Any]): Configuration data of the agent service. - stop_event (asyncio.Event): Stop event instance. - """ - def signal_handler(signum, frame): - logging.info(f"Received signal { - signum}. Initiating graceful shutdown.") - stop_event.set() - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - - try: - # Instantiate ConnectionManager - connection_manager = ConnectionManager(config_data) - - # Connect to IoT Hub - logging.info("Connecting to IoT Hub...") - await connection_manager.connect() - - # Update Device Twin reported properties to 'online' - logging.info("Updating device status to online...") - twin_patch = {"connectivity": {"status": "online"}} - await connection_manager.client.patch_twin_reported_properties(twin_patch) - - # Set Message Handler - logging.info("Setting up message handler...") - await connection_manager.set_message_handler() - - # Use an asyncio.Event to exit the loop when the service stops - while not stop_event.is_set(): - await asyncio.sleep(1) - - # Before disconnecting, update Device Twin reported properties to 'offline' - logging.info("Updating device status to offline...") - twin_patch = {"connectivity": {"status": "offline"}} - await connection_manager.client.patch_twin_reported_properties(twin_patch) - - await connection_manager.disconnect() - - except Exception as e: - logging.exception(f"Exception Caught during IoT Hub Loop: {str(e)}") +""" Module for defining class and functions to manage connections. """ + +from typing import Dict, Any + +import asyncio +import base64 +import json +import os +import subprocess +import logging +import platform +import signal +import tempfile +import httpx + +from azure.iot.device.aio import IoTHubDeviceClient +from azure.iot.device.iothub.models import Message +from azure.iot.device.exceptions import ConnectionFailedError, ConnectionDroppedError + +from platformdirs import ( + site_config_dir +) +from config_module.config_io import ( + get_config_file_path, + get_agent_executable_path, + get_service_executable_path, + get_service_manager_path +) +from config_module.host_info import build_host_tags + +# Set up logging +logging.basicConfig(level=logging.INFO) + +os_type = platform.system().lower() + + +class ConnectionManager: + """ + Manages the connection between the agent and IoT Hub. + """ + + def __init__(self, config_data: Dict[str, Any]) -> None: + """Construcs a new connection manager instance + + Args: + config_data (Dict[str, Any]): Configuration data of the connection. + """ + self.config_data = config_data + self.connection_string = self.get_connection_string() + self.os_type = platform.system().lower() + self.client = IoTHubDeviceClient.create_from_connection_string( + self.connection_string) + + def get_connection_string(self) -> str: + """ + Get the connection string used to connect to the IoT Hub. + + Returns: + str: Connection string to the IoT Hub. + """ + conn_str = ( + f"HostName={self.config_data['azure_iot_hub_host']};" + f"DeviceId={self.config_data['device_id']};" + f"SharedAccessKey={self.config_data['shared_access_key']}" + ) + return conn_str + + async def connect_using_websockets(self) -> None: + """ + Modify the connection to use websockets and reconnect + """ + try: + logging.info("Connecting over websockets...") + self.client = IoTHubDeviceClient.create_from_connection_string( + self.connection_string, websockets=True) + await self.client.connect() + except Exception as e: + logging.exception("Exception in connection to the IoT Hub: %s", e) + + async def connect(self) -> None: + """ + Connect the agent service to the IoT Hub. + """ + try: + await self.client.connect() + except (ConnectionFailedError, ConnectionDroppedError): + await self.connect_using_websockets() + except Exception as e: + logging.exception("Exception in connection to the IoT Hub: %s", e) + + async def disconnect(self) -> None: + """ + Disconnect the agent service from the IoT Hub. + """ + try: + await self.client.disconnect() + except Exception as e: + logging.exception( + "Exception in disconnecting from the IoT Hub: %s", e) + + async def send_message(self, message_data: Dict[str, Any]) -> None: + """ + Send a message to the IoT Hub. + + Args: + message_data (Dict[str, Any]): Message data in JSON format. + """ + message_json = json.dumps(message_data) + await self.client.send_message(message_json) + + async def set_message_handler(self) -> None: + """ + Sets the event handler for income messages from the Iot Hub. + """ + self.client.on_message_received = self.handle_message + + async def execute_commands(self, commands: bytes, post_url: str = None, interpreter_override: str = None) -> Dict[str, str]: + """ + Execute commands on the machine using the specified interpreter and send back result via post_url. + + Args: + commands (str): Base64 encoded list of commands. + post_url (str, optional): Post back URL to send the stdout and stderr results of the commands after execution. Defaults to None. + interpreter_override (str, optional): Interpreter name to use in executing the commands. Defaults to None. + + Returns: + Dict[str, str]: Output message in JSON format sent to the post_url. + """ + interpreter = interpreter_override or self.get_default_interpreter() + logging.info("Using interpreter: %s", interpreter) + output_message_data = None + + # Write commands to a temporary file + script_suffix = ".ps1" if "powershell" in interpreter.lower() else ".sh" + tmp_dir = None + if os_type == "windows": + config_dir = site_config_dir() + scripts_dir = os.path.join( + config_dir, "\\RewstRemoteAgent\\scripts") + # scripts_dir = "C:\\Scripts" + if not os.path.exists(scripts_dir): + os.makedirs(scripts_dir) + tmp_dir = scripts_dir + with tempfile.NamedTemporaryFile(delete=False, suffix=script_suffix, + mode="w", dir=tmp_dir) as temp_file: + if "powershell" in interpreter.lower(): + # If PowerShell is used, decode the commands + decoded_commands = base64.b64decode( + commands).decode('utf-16-le') + # Ensure TLS 1.2 configuration is set at the beginning of the command + tls_command = "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12" + if tls_command not in decoded_commands: + decoded_commands = tls_command + "\n" + decoded_commands + else: + # For other interpreters, you might want to handle encoding differently + decoded_commands = base64.b64decode(commands).decode('utf-8') + + # logging.info(f"Decoded Commands:\n{decoded_commands}") + temp_file.write(decoded_commands) + temp_file.flush() # Explicitly flush the file buffer + os.fsync(temp_file.fileno()) # Ensures all data is written to disk + temp_file_path = temp_file.name + + logging.info("Wrote commands to temp file %s", temp_file_path) + + # Construct the shell command to execute the temp file + if "powershell" in interpreter.lower() or "pwsh" in interpreter.lower(): + shell_command = f'{interpreter} -File "{temp_file_path}"' + else: + shell_command = f'{interpreter} "{temp_file_path}"' + + try: + # Execute the command + logging.info("Running process via commandline: %s", shell_command) + process = subprocess.Popen( + shell_command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + text=True + ) + stdout, stderr = process.communicate() + exit_code = process.returncode + logging.info("Command completed with exit code %d", exit_code) + + if exit_code != 0 or stderr: + # Log and print error details + error_message = f"Script execution failed with exit code { + exit_code}. Error: {stderr}" + logging.error(error_message) + print(error_message) # Print to console + output_message_data = { + 'output': stdout, + 'error': error_message + } + else: + output_message_data = { + 'output': stdout, + 'error': '' + } + + except subprocess.CalledProcessError as e: + logging.error( + "Command '%s' failed with error code %d", shell_command, e.returncode) + logging.error("Error output: %s", e.output) + output_message_data = { + 'output': '', + 'error': f"Command failed with error code {e.returncode}: {e.output}" + } + + except Exception as e: + logging.error("An unexpected error occurred: %s", e) + output_message_data = { + 'output': '', + 'error': f"An unexpected error occurred: {e}" + } + + finally: + # Loop to wait until the file can be deleted + while True: + try: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + break # If successful, break out of the loop + except PermissionError: + await asyncio.sleep(1) + except Exception as e: + logging.error("Error deleting temporary file: %s", e) + break # If a different error occurs, break out of the loop + + if post_url and output_message_data: + logging.info("Sending Results to Rewst via httpx.") + async with httpx.AsyncClient() as client: + response = await client.post(post_url, json=output_message_data) + logging.info("POST request status: %d", response.status_code) + if response.status_code != 200: + logging.error("Error response: %s", response.text) + + return output_message_data + + async def handle_message(self, message: Message) -> None: + """Handle incoming message event from the IoT Hub. + + Args: + message (Message): Message instance from the IoT Hub. + """ + logging.info("Received IoT Hub message in handle_message.") + try: + message_data = json.loads(message.data) + get_installation_info = message_data.get("get_installation") + commands = message_data.get("commands") + post_id = message_data.get("post_id") + interpreter_override = message_data.get("interpreter_override") + + if post_id: + post_path = post_id.replace(":", "/") + rewst_engine_host = self.config_data["rewst_engine_host"] + post_url = f"https://{ + rewst_engine_host}/webhooks/custom/action/{post_path}" + logging.info("Will POST results to %s", post_url) + else: + post_url = None + + if commands: + logging.info("Received commands in message") + try: + await self.execute_commands(commands, post_url, interpreter_override) + except Exception as e: + logging.exception("Exception running commands: %s", e) + + if get_installation_info: + logging.info("Received request for installation paths") + try: + await self.get_installation(post_url) + except Exception as e: + logging.exception( + "Exception getting installation info: %s", e) + except json.JSONDecodeError as e: + logging.error("Error decoding message data as JSON: %s", e) + except Exception as e: + logging.exception("An unexpected error occurred: %s", e) + + async def get_installation(self, post_url: str) -> None: + """Send installation data of the service to the Rewst platform. The post_url + is an ephemeral link generated by the Rewst platform. + + Args: + post_url (str): Post back link to send the installation data to. + """ + org_id = self.config_data['rewst_org_id'] + service_executable_path = get_service_executable_path(org_id) + agent_executable_path = get_agent_executable_path(org_id) + service_manager_path = get_service_manager_path(org_id) + config_file_path = get_config_file_path(org_id) + + paths_data = { + "service_executable_path": service_executable_path, + "agent_executable_path": agent_executable_path, + "config_file_path": config_file_path, + "service_manager_path": service_manager_path, + "tags": build_host_tags(org_id) + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post(post_url, json=paths_data) + response.raise_for_status() + + except httpx.RequestError as e: + logging.error(f"Request to {post_url} failed: {e}") + except httpx.HTTPStatusError as e: + logging.error(f"Error response {e.response.status_code} while posting to {post_url}: {e.response.text}") + except Exception as e: + logging.error(f"An unexpected error occurred while posting to {post_url}: {e}") + + def get_default_interpreter(self) -> str: + """Get the default interpreter depending on the platform's OS type. + + Returns: + str: Interpreter executable path. + """ + if self.os_type == 'windows': + return 'powershell' + elif self.os_type == 'darwin': + return '/bin/zsh' + else: + return '/bin/bash' + + +async def iot_hub_connection_loop(config_data: Dict[str, Any], stop_event: asyncio.Event = asyncio.Event()) -> None: + """Connect to the IoT Hub and wait for a stop event to close the loop. + + Args: + config_data (Dict[str, Any]): Configuration data of the agent service. + stop_event (asyncio.Event): Stop event instance. + """ + def signal_handler(signum, frame): + logging.info(f"Received signal {signum}. Initiating graceful shutdown.") + stop_event.set() + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + try: + # Instantiate ConnectionManager + connection_manager = ConnectionManager(config_data) + + # Connect to IoT Hub + logging.info("Connecting to IoT Hub...") + await connection_manager.connect() + + # Update Device Twin reported properties to 'online' + logging.info("Updating device status to online...") + twin_patch = {"connectivity": {"status": "online"}} + await connection_manager.client.patch_twin_reported_properties(twin_patch) + + # Set Message Handler + logging.info("Setting up message handler...") + await connection_manager.set_message_handler() + + # Use an asyncio.Event to exit the loop when the service stops + while not stop_event.is_set(): + await asyncio.sleep(1) + + # Before disconnecting, update Device Twin reported properties to 'offline' + logging.info("Updating device status to offline...") + twin_patch = {"connectivity": {"status": "offline"}} + await connection_manager.client.patch_twin_reported_properties(twin_patch) + + await connection_manager.disconnect() + + except Exception as e: + logging.exception(f"Exception Caught during IoT Hub Loop: {str(e)}") diff --git a/tests/config_module/test_host_info.py b/tests/config_module/test_host_info.py index 50aaca7..a783f53 100644 --- a/tests/config_module/test_host_info.py +++ b/tests/config_module/test_host_info.py @@ -17,8 +17,10 @@ class TestHostInfo(unittest.TestCase): """Test class for host_info unit tests""" + @patch("psutil.net_if_addrs") @patch("uuid.getnode") - def test_get_mac_address(self, mock_getnode): + def test_get_mac_address(self, mock_getnode, mock_net): + mock_net.items.return_value = [] # Mocking the UUID return value to a fixed value for testing mock_getnode.return_value = 123456789012345 mac_address = get_mac_address()