Skip to content

Commit

Permalink
Explode kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianBracq committed Jan 8, 2025
1 parent 9d5d45f commit c90799f
Showing 1 changed file with 59 additions and 40 deletions.
99 changes: 59 additions & 40 deletions msticpy/data/drivers/azure_kusto_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
"https://msticpy.readthedocs.io/en/latest/DataProviders/DataProv-Kusto.html"
)

# pylint:disable=too-many-lines

logger: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -146,6 +148,13 @@ class AzureKustoDriver(DriverBase):
def __init__(
self: AzureKustoDriver,
connection_str: str | None = None,
*,
debug: bool = False,
data_environment: DataEnvironment = DataEnvironment.Kusto,
strict_query_match: bool = False,
timeout: int = _DEFAULT_TIMEOUT,
proxies: dict[str, str] | None = None,
max_threads: int = 4,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -176,10 +185,10 @@ def __init__(
"""
super().__init__(**kwargs)
if kwargs.get("debug", False):
if debug:
logger.setLevel(logging.DEBUG)
self.environment: str = kwargs.get("data_environment", DataEnvironment.Kusto)
self._strict_query_match: bool = kwargs.get("strict_query_match", False)
self.environment: DataEnvironment = data_environment
self._strict_query_match: bool = strict_query_match
self._kusto_settings: dict[str, dict[str, KustoConfig]] = _get_kusto_settings()
self._default_database: str | None = None
self._current_connection: str | None = connection_str
Expand All @@ -188,13 +197,10 @@ def __init__(
self._az_auth_types: list[str] | None = None
self._az_tenant_id: str | None = None
self._def_timeout: int = min(
kwargs.pop("timeout", _DEFAULT_TIMEOUT),
timeout,
_MAX_TIMEOUT,
)
self._def_proxies: dict[str, str] | None = kwargs.get(
"proxies",
get_http_proxies(),
)
self._def_proxies: dict[str, str] | None = proxies or get_http_proxies()

self.add_query_filter("data_environments", "Kusto")
self.set_driver_property(DriverProps.PUBLIC_ATTRS, self._set_public_attribs())
Expand All @@ -203,7 +209,7 @@ def __init__(
self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True)
self.set_driver_property(
DriverProps.MAX_PARALLEL,
value=kwargs.get("max_threads", 4),
value=max_threads,
)
self._loaded = True

Expand Down Expand Up @@ -276,7 +282,20 @@ def set_database(self: Self, database: str) -> None:
"""Set the default database to `database`."""
self._default_database = database

def connect(self: Self, connection_str: str | None = None, **kwargs) -> None:
def connect(
self: Self,
connection_str: str | None = None,
*,
database: str | None = None,
timeout: int | None = None,
auth_types: str | list[str] | None = None,
mp_az_auth: bool | str | list[str] | None = None,
tenant_id: str | None = None,
mp_az_tenant_id: str | None = None,
cluster: str | None = None,
proxies: dict[str, str] | None = None,
**kwargs,
) -> None:
"""
Connect to data source.
Expand Down Expand Up @@ -328,32 +347,16 @@ def connect(self: Self, connection_str: str | None = None, **kwargs) -> None:
connection_str,
kwargs,
)
self._default_database = kwargs.pop("database", None)
self._def_timeout = min(kwargs.pop("timeout", self._def_timeout), _MAX_TIMEOUT)
az_auth_types: bool | str | list[str] = kwargs.pop(
"auth_types",
kwargs.pop("mp_az_auth", None),
)
self._default_database = database
self._def_timeout = min(timeout or self._def_timeout, _MAX_TIMEOUT)
az_auth_types: bool | str | list[str] | None = auth_types or mp_az_auth
if isinstance(az_auth_types, bool):
self._az_auth_types = None
elif isinstance(az_auth_types, str):
self._az_auth_types = [az_auth_types]
else:
self._az_auth_types = az_auth_types
self._az_tenant_id = kwargs.pop(
"tenant_id",
kwargs.pop("mp_az_tenant_id", None),
)

cluster: str = kwargs.pop("cluster", None)
if connection_str:
self.current_connection = connection_str
if not connection_str and not cluster:
err_msg: str = "Must specify either a connection string or a cluster name"
raise MsticpyParameterError(
err_msg,
parameter=["connection_str", "cluster"],
)
self._az_tenant_id = tenant_id or mp_az_tenant_id

kusto_cs: KustoConnectionStringBuilder | str | None = None
if cluster:
Expand All @@ -367,15 +370,21 @@ def connect(self: Self, connection_str: str | None = None, **kwargs) -> None:
)
kusto_cs = self._get_connection_string_for_cluster(self._current_config)
self.current_connection = cluster
else:
elif connection_str:
logger.info("Using connection string %s", connection_str)
self.current_connection = connection_str
kusto_cs = connection_str
else:
err_msg: str = "Must specify either a connection string or a cluster name"
raise MsticpyParameterError(
err_msg,
parameter=["connection_str", "cluster"],
)
if not kusto_cs:
err_msg = "Kusto connection string required"
raise MsticpyParameterError(err_msg)
self.client = KustoClient(kusto_cs)
proxies: dict[str, str] | None = kwargs.get("proxies", self._def_proxies)
proxies = proxies or self._def_proxies
proxy_url: str | None = proxies.get("https") if proxies else None
if proxy_url:
logger.info(
Expand All @@ -393,6 +402,9 @@ def query(
self: Self,
query: str,
query_source: QuerySource | None = None,
*,
timeout: int | None = None,
database: str | None = None,
**kwargs,
) -> pd.DataFrame | dict[str, Any] | None:
"""
Expand Down Expand Up @@ -423,13 +435,19 @@ def query(
data, result = self.query_with_results(
query,
query_source=query_source,
timeout=timeout,
database=database,
**kwargs,
)
return data if data is not None else result

def query_with_results(
self: Self,
query: str,
*,
query_source: QuerySource | None = None,
timeout: int | None = None,
database: str | None = None,
**kwargs,
) -> tuple[pd.DataFrame | None, dict[str, Any]]:
"""
Expand All @@ -454,9 +472,9 @@ def query_with_results(
and there is no default database.
"""
del kwargs
if not self._connected:
_raise_not_connected_error()
query_source: QuerySource | None = kwargs.pop("query_source", None)

if query_source and not self.query_usable(query_source):
query_spec: dict[str, str] = self._get_cluster_spec_from_query_source(
Expand All @@ -474,16 +492,16 @@ def query_with_results(
help_uri=_HELP_URL,
)

database: str = self._get_query_database_name(
database = self._get_query_database_name(
query_source=query_source,
**kwargs,
database=database,
)
data: pd.DataFrame | None = None
status: dict[str, bool] = {"success": False}
connection_props = ClientRequestProperties()
connection_props.set_option(
ClientRequestProperties.request_timeout_option_name,
timedelta(seconds=kwargs.get("timeout", self._def_timeout)),
timedelta(seconds=timeout or self._def_timeout),
)
if self.client is None:
_raise_not_connected_error()
Expand Down Expand Up @@ -729,7 +747,7 @@ def _get_auth_params_from_config(

def _lookup_cluster_settings(self: Self, cluster: str) -> KustoConfig:
"""Return cluster URI from config if cluster name is passed."""
cluster_key = cluster.casefold().strip()
cluster_key: str = cluster.casefold().strip()
if cluster_key in self._kusto_settings["url"]:
return self._kusto_settings["url"][cluster_key]
if cluster_key in self._kusto_settings["name"]:
Expand Down Expand Up @@ -761,10 +779,11 @@ def _lookup_cluster_settings(self: Self, cluster: str) -> KustoConfig:
def _get_query_database_name(
self: Self,
query_source: QuerySource | None = None,
**kwargs,
*,
database: str | None = None,
) -> str:
"""Get the database name from query source or kwargs."""
if database := kwargs.get("database"):
"""Get the database name from query source."""
if database:
logger.info("Using database %s from parameter.", database)
return database
if query_source:
Expand Down

0 comments on commit c90799f

Please sign in to comment.