diff --git a/sssd_test_framework/topology_controllers.py b/sssd_test_framework/topology_controllers.py
index a0c1829..911037f 100644
--- a/sssd_test_framework/topology_controllers.py
+++ b/sssd_test_framework/topology_controllers.py
@@ -1,9 +1,6 @@
 from __future__ import annotations
 
-from functools import partial, wraps
-from typing import Any
-
-from pytest_mh import MultihostBackupHost, TopologyController
+from pytest_mh import BackupTopologyController
 from pytest_mh.conn import ProcessResult
 
 from .config import SSSDMultihostConfig
@@ -24,29 +21,7 @@
 ]
 
 
-def restore_vanilla_on_error(method):
-    """
-    Restore or hosts to its original state if an exception occurs
-    during method execution.
-
-    :param method: Method to decorate.
-    :type method: _type_
-    :return: _description_
-    :rtype: _type_
-    """
-
-    @wraps(method)
-    def wrapper(self: BackupTopologyController, *args, **kwargs):
-        try:
-            return self._invoke_with_args(partial(method, self))
-        except Exception:
-            self.restore_vanilla()
-            raise
-
-    return wrapper
-
-
-class BackupTopologyController(TopologyController[SSSDMultihostConfig]):
+class ProvisionedBackupTopologyController(BackupTopologyController[SSSDMultihostConfig]):
     """
     Provide basic restore functionality for topologies.
     """
@@ -54,92 +29,48 @@ class BackupTopologyController(TopologyController[SSSDMultihostConfig]):
     def __init__(self) -> None:
         super().__init__()
 
-        self.backup_data: dict[MultihostBackupHost, Any | None] = {}
         self.provisioned: bool = False
 
     def init(self, *args, **kwargs):
         super().init(*args, **kwargs)
         self.provisioned = self.name in self.multihost.provisioned_topologies
 
-    def restore(self, hosts: dict[MultihostBackupHost, Any | None]) -> None:
-        errors = []
-        for host, backup_data in hosts.items():
-            if not isinstance(host, MultihostBackupHost):
-                continue
-
-            try:
-                host.restore(backup_data)
-            except Exception as e:
-                errors.append(e)
-
-        if errors:
-            raise ExceptionGroup("Some hosts failed to restore to original state", errors)
-
-    def restore_vanilla(self) -> None:
-        restore_data: dict[MultihostBackupHost, Any | None] = {}
-
-        for host in self.hosts:
-            if not isinstance(host, MultihostBackupHost):
-                continue
-
-            restore_data[host] = host.backup_data
-
-        self.restore(restore_data)
-
     def topology_teardown(self) -> None:
         if self.provisioned:
             return
 
-        try:
-            for host, backup_data in self.backup_data.items():
-                if not isinstance(host, MultihostBackupHost):
-                    continue
-
-                host.remove_backup(backup_data)
-        except Exception:
-            # This is not that important, we can just ignore
-            pass
-
-        self.restore_vanilla()
+        super().topology_teardown()
 
     def teardown(self) -> None:
         if self.provisioned:
             self.restore_vanilla()
             return
 
-        self.restore(self.backup_data)
+        super().teardown()
 
 
-class ClientTopologyController(BackupTopologyController):
+class ClientTopologyController(ProvisionedBackupTopologyController):
     """
     Client Topology Controller.
     """
 
-    def topology_teardown(self) -> None:
-        pass
-
-    def teardown(self) -> None:
-        self.restore_vanilla()
+    pass
 
 
-class LDAPTopologyController(BackupTopologyController):
+class LDAPTopologyController(ProvisionedBackupTopologyController):
    """
     LDAP Topology Controller.
     """
 
-    def topology_teardown(self) -> None:
-        pass
-
-    def teardown(self) -> None:
-        self.restore_vanilla()
+    pass
 
 
-class IPATopologyController(BackupTopologyController):
+class IPATopologyController(ProvisionedBackupTopologyController):
     """
     IPA Topology Controller.
     """
 
-    @restore_vanilla_on_error
+    @BackupTopologyController.restore_vanilla_on_error
     def topology_setup(self, client: ClientHost, ipa: IPAHost, nfs: NFSHost) -> None:
         if self.provisioned:
             self.logger.info(f"Topology '{self.name}' is already provisioned")
@@ -159,17 +90,15 @@ def topology_setup(self, client: ClientHost, ipa: IPAHost, nfs: NFSHost) -> None
         client.conn.exec(["realm", "join", ipa.domain], input=ipa.adminpw)
 
         # Backup so we can restore to this state after each test
-        self.backup_data[ipa] = ipa.backup()
-        self.backup_data[client] = client.backup()
-        self.backup_data[nfs] = nfs.backup()
+        super().topology_setup()
 
 
-class ADTopologyController(BackupTopologyController):
+class ADTopologyController(ProvisionedBackupTopologyController):
     """
     AD Topology Controller.
     """
 
-    @restore_vanilla_on_error
+    @BackupTopologyController.restore_vanilla_on_error
     def topology_setup(self, client: ClientHost, provider: ADHost | SambaHost, nfs: NFSHost) -> None:
         if self.provisioned:
             self.logger.info(f"Topology '{self.name}' is already provisioned")
@@ -185,9 +114,7 @@ def topology_setup(self, client: ClientHost, provider: ADHost | SambaHost, nfs:
         client.conn.exec(["realm", "join", provider.domain], input=provider.adminpw)
 
         # Backup so we can restore to this state after each test
-        self.backup_data[provider] = provider.backup()
-        self.backup_data[client] = client.backup()
-        self.backup_data[nfs] = nfs.backup()
+        super().topology_setup()
 
 
 class SambaTopologyController(ADTopologyController):
@@ -198,12 +125,12 @@ class SambaTopologyController(ADTopologyController):
     pass
 
 
-class IPATrustADTopologyController(BackupTopologyController):
+class IPATrustADTopologyController(ProvisionedBackupTopologyController):
     """
     IPA trust AD Topology Controller.
     """
 
-    @restore_vanilla_on_error
+    @BackupTopologyController.restore_vanilla_on_error
     def topology_setup(self, client: ClientHost, ipa: IPAHost, trusted: ADHost | SambaHost) -> None:
         if self.provisioned:
             self.logger.info(f"Topology '{self.name}' is already provisioned")
@@ -230,9 +157,7 @@ def topology_setup(self, client: ClientHost, ipa: IPAHost, trusted: ADHost | Sam
         client.conn.exec(["realm", "join", ipa.domain], input=ipa.adminpw)
 
         # Backup so we can restore to this state after each test
-        self.backup_data[ipa] = ipa.backup()
-        self.backup_data[trusted] = trusted.backup()
-        self.backup_data[client] = client.backup()
+        super().topology_setup()
 
         # If this command is run on freshly started containers, it is possible the IPA is not yet
         # fully ready to create the trust. It takes a while for it to start working.