From c90799fc3c6b2b5456886bb1e16cfadd9cc8cd5f Mon Sep 17 00:00:00 2001 From: Florian BRACQ Date: Wed, 8 Jan 2025 12:10:09 +0000 Subject: [PATCH] Explode kwargs --- msticpy/data/drivers/azure_kusto_driver.py | 99 +++++++++++++--------- 1 file changed, 59 insertions(+), 40 deletions(-) diff --git a/msticpy/data/drivers/azure_kusto_driver.py b/msticpy/data/drivers/azure_kusto_driver.py index 8cca0bfd..8ee54020 100644 --- a/msticpy/data/drivers/azure_kusto_driver.py +++ b/msticpy/data/drivers/azure_kusto_driver.py @@ -66,6 +66,8 @@ "https://msticpy.readthedocs.io/en/latest/DataProviders/DataProv-Kusto.html" ) +# pylint:disable=too-many-lines + logger: logging.Logger = logging.getLogger(__name__) @@ -146,6 +148,13 @@ class AzureKustoDriver(DriverBase): def __init__( self: AzureKustoDriver, connection_str: str | None = None, + *, + debug: bool = False, + data_environment: DataEnvironment = DataEnvironment.Kusto, + strict_query_match: bool = False, + timeout: int = _DEFAULT_TIMEOUT, + proxies: dict[str, str] | None = None, + max_threads: int = 4, **kwargs, ) -> None: """ @@ -176,10 +185,10 @@ def __init__( """ super().__init__(**kwargs) - if kwargs.get("debug", False): + if debug: logger.setLevel(logging.DEBUG) - self.environment: str = kwargs.get("data_environment", DataEnvironment.Kusto) - self._strict_query_match: bool = kwargs.get("strict_query_match", False) + self.environment: DataEnvironment = data_environment + self._strict_query_match: bool = strict_query_match self._kusto_settings: dict[str, dict[str, KustoConfig]] = _get_kusto_settings() self._default_database: str | None = None self._current_connection: str | None = connection_str @@ -188,13 +197,10 @@ def __init__( self._az_auth_types: list[str] | None = None self._az_tenant_id: str | None = None self._def_timeout: int = min( - kwargs.pop("timeout", _DEFAULT_TIMEOUT), + timeout, _MAX_TIMEOUT, ) - self._def_proxies: dict[str, str] | None = kwargs.get( - "proxies", - get_http_proxies(), - ) + self._def_proxies: dict[str, str] | None = proxies or get_http_proxies() self.add_query_filter("data_environments", "Kusto") self.set_driver_property(DriverProps.PUBLIC_ATTRS, self._set_public_attribs()) @@ -203,7 +209,7 @@ def __init__( self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True) self.set_driver_property( DriverProps.MAX_PARALLEL, - value=kwargs.get("max_threads", 4), + value=max_threads, ) self._loaded = True @@ -276,7 +282,20 @@ def set_database(self: Self, database: str) -> None: """Set the default database to `database`.""" self._default_database = database - def connect(self: Self, connection_str: str | None = None, **kwargs) -> None: + def connect( + self: Self, + connection_str: str | None = None, + *, + database: str | None = None, + timeout: int | None = None, + auth_types: str | list[str] | None = None, + mp_az_auth: bool | str | list[str] | None = None, + tenant_id: str | None = None, + mp_az_tenant_id: str | None = None, + cluster: str | None = None, + proxies: dict[str, str] | None = None, + **kwargs, + ) -> None: """ Connect to data source. @@ -328,32 +347,16 @@ def connect(self: Self, connection_str: str | None = None, **kwargs) -> None: connection_str, kwargs, ) - self._default_database = kwargs.pop("database", None) - self._def_timeout = min(kwargs.pop("timeout", self._def_timeout), _MAX_TIMEOUT) - az_auth_types: bool | str | list[str] = kwargs.pop( - "auth_types", - kwargs.pop("mp_az_auth", None), - ) + self._default_database = database + self._def_timeout = min(timeout or self._def_timeout, _MAX_TIMEOUT) + az_auth_types: bool | str | list[str] | None = auth_types or mp_az_auth if isinstance(az_auth_types, bool): self._az_auth_types = None elif isinstance(az_auth_types, str): self._az_auth_types = [az_auth_types] else: self._az_auth_types = az_auth_types - self._az_tenant_id = kwargs.pop( - "tenant_id", - kwargs.pop("mp_az_tenant_id", None), - ) - - cluster: str = kwargs.pop("cluster", None) - if connection_str: - self.current_connection = connection_str - if not connection_str and not cluster: - err_msg: str = "Must specify either a connection string or a cluster name" - raise MsticpyParameterError( - err_msg, - parameter=["connection_str", "cluster"], - ) + self._az_tenant_id = tenant_id or mp_az_tenant_id kusto_cs: KustoConnectionStringBuilder | str | None = None if cluster: @@ -367,15 +370,21 @@ def connect(self: Self, connection_str: str | None = None, **kwargs) -> None: ) kusto_cs = self._get_connection_string_for_cluster(self._current_config) self.current_connection = cluster - else: + elif connection_str: logger.info("Using connection string %s", connection_str) self.current_connection = connection_str kusto_cs = connection_str + else: + err_msg: str = "Must specify either a connection string or a cluster name" + raise MsticpyParameterError( + err_msg, + parameter=["connection_str", "cluster"], + ) if not kusto_cs: err_msg = "Kusto connection string required" raise MsticpyParameterError(err_msg) self.client = KustoClient(kusto_cs) - proxies: dict[str, str] | None = kwargs.get("proxies", self._def_proxies) + proxies = proxies or self._def_proxies proxy_url: str | None = proxies.get("https") if proxies else None if proxy_url: logger.info( @@ -393,6 +402,9 @@ def query( self: Self, query: str, query_source: QuerySource | None = None, + *, + timeout: int | None = None, + database: str | None = None, **kwargs, ) -> pd.DataFrame | dict[str, Any] | None: """ @@ -423,6 +435,8 @@ def query( data, result = self.query_with_results( query, query_source=query_source, + timeout=timeout, + database=database, **kwargs, ) return data if data is not None else result @@ -430,6 +444,10 @@ def query( def query_with_results( self: Self, query: str, + *, + query_source: QuerySource | None = None, + timeout: int | None = None, + database: str | None = None, **kwargs, ) -> tuple[pd.DataFrame | None, dict[str, Any]]: """ @@ -454,9 +472,9 @@ def query_with_results( and there is no default database. """ + del kwargs if not self._connected: _raise_not_connected_error() - query_source: QuerySource | None = kwargs.pop("query_source", None) if query_source and not self.query_usable(query_source): query_spec: dict[str, str] = self._get_cluster_spec_from_query_source( @@ -474,16 +492,16 @@ def query_with_results( help_uri=_HELP_URL, ) - database: str = self._get_query_database_name( + database = self._get_query_database_name( query_source=query_source, - **kwargs, + database=database, ) data: pd.DataFrame | None = None status: dict[str, bool] = {"success": False} connection_props = ClientRequestProperties() connection_props.set_option( ClientRequestProperties.request_timeout_option_name, - timedelta(seconds=kwargs.get("timeout", self._def_timeout)), + timedelta(seconds=timeout or self._def_timeout), ) if self.client is None: _raise_not_connected_error() @@ -729,7 +747,7 @@ def _get_auth_params_from_config( def _lookup_cluster_settings(self: Self, cluster: str) -> KustoConfig: """Return cluster URI from config if cluster name is passed.""" - cluster_key = cluster.casefold().strip() + cluster_key: str = cluster.casefold().strip() if cluster_key in self._kusto_settings["url"]: return self._kusto_settings["url"][cluster_key] if cluster_key in self._kusto_settings["name"]: @@ -761,10 +779,11 @@ def _lookup_cluster_settings(self: Self, cluster: str) -> KustoConfig: def _get_query_database_name( self: Self, query_source: QuerySource | None = None, - **kwargs, + *, + database: str | None = None, ) -> str: - """Get the database name from query source or kwargs.""" - if database := kwargs.get("database"): + """Get the database name from query source.""" + if database: logger.info("Using database %s from parameter.", database) return database if query_source: