diff --git a/msticpy/data/core/data_providers.py b/msticpy/data/core/data_providers.py index f9ec3455e..c11f0aa9c 100644 --- a/msticpy/data/core/data_providers.py +++ b/msticpy/data/core/data_providers.py @@ -93,7 +93,7 @@ def __init__( # noqa: MC0001 # pylint: enable=import-outside-toplevel setattr(self.__class__, "_add_pivots", add_data_queries_to_entities) - data_environment, self.environment_name = self._check_environment( + data_environment, self.environment_name = QueryProvider._check_environment( data_environment ) @@ -139,8 +139,9 @@ def __init__( # noqa: MC0001 self._query_time = QueryTime(units="day") logger.info("Initialization complete.") + @classmethod def _check_environment( - self, data_environment + cls, data_environment ) -> Tuple[Union[str, DataEnvironment], str]: """Check environment against known names.""" if isinstance(data_environment, str): @@ -212,42 +213,6 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): logger.info("Adding query pivot functions") self._add_pivots(lambda: self._query_time.timespan) - def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]: - """ - Execute simple query string. - - Parameters - ---------- - query : str - [description] - use_connections : Union[str, List[str]] - - Other Parameters - ---------------- - query_options : Dict[str, Any] - Additional options passed to query driver. - kwargs : Dict[str, Any] - Additional options passed to query driver. - - Returns - ------- - Union[pd.DataFrame, Any] - Query results - a DataFrame if successful - or a KqlResult if unsuccessful. - - """ - query_options = kwargs.pop("query_options", {}) or kwargs - query_source = kwargs.pop("query_source", None) - - logger.info("Executing query '%s...'", query[:40]) - logger.debug("Full query: %s", query) - logger.debug("Query options: %s", query_options) - if not self._additional_connections: - return self._query_provider.query( - query, query_source=query_source, **query_options - ) - return self._exec_additional_connections(query, **kwargs) - @property def query_time(self): """Return the default QueryTime control for queries.""" diff --git a/msticpy/data/core/query_provider_connections_mixin.py b/msticpy/data/core/query_provider_connections_mixin.py index e88ce2c92..de00f5563 100644 --- a/msticpy/data/core/query_provider_connections_mixin.py +++ b/msticpy/data/core/query_provider_connections_mixin.py @@ -6,6 +6,7 @@ """Query Provider additional connection methods.""" import asyncio import logging +from abc import abstractmethod from concurrent.futures import ThreadPoolExecutor from datetime import datetime from functools import partial @@ -37,23 +38,61 @@ class QueryProviderProtocol(Protocol): _additional_connections: Dict[str, Any] _query_provider: DriverBase - def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]: - """Execute a query against the provider.""" - ... - - # fmt: off @staticmethod + @abstractmethod def _get_query_options( params: Dict[str, Any], kwargs: Dict[str, Any] ) -> Dict[str, Any]: - ... - # fmt: on + """Return any kwargs not already in params.""" # pylint: disable=super-init-not-called class QueryProviderConnectionsMixin(QueryProviderProtocol): """Mixin additional connection handling QueryProvider class.""" + @staticmethod + @abstractmethod + def _get_query_options( + params: Dict[str, Any], kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + """Return any kwargs not already in params.""" + + def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]: + """ + Execute simple query string. + + Parameters + ---------- + query : str + [description] + use_connections : Union[str, List[str]] + + Other Parameters + ---------------- + query_options : Dict[str, Any] + Additional options passed to query driver. + kwargs : Dict[str, Any] + Additional options passed to query driver. + + Returns + ------- + Union[pd.DataFrame, Any] + Query results - a DataFrame if successful + or a KqlResult if unsuccessful. + + """ + query_options = kwargs.pop("query_options", {}) or kwargs + query_source = kwargs.pop("query_source", None) + + logger.info("Executing query '%s...'", query[:40]) + logger.debug("Full query: %s", query) + logger.debug("Query options: %s", query_options) + if not self._additional_connections: + return self._query_provider.query( + query, query_source=query_source, **query_options + ) + return self._exec_additional_connections(query, **kwargs) + def add_connection( self, connection_str: Optional[str] = None, @@ -159,8 +198,16 @@ def _exec_additional_connections(self, query, **kwargs) -> pd.DataFrame: if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING): logger.info("Running threaded queries.") event_loop = _get_event_loop() + max_workers: int = self._query_provider.get_driver_property( + DriverProps.MAX_PARALLEL + ) return event_loop.run_until_complete( - self._exec_queries_threaded(query_tasks, progress, retry) + self._exec_queries_threaded( + query_tasks, + progress, + retry, + max_workers, + ) ) # standard synchronous execution @@ -238,8 +285,16 @@ def _exec_split_query( if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING): logger.info("Running threaded queries.") event_loop = _get_event_loop() + max_workers: int = self._query_provider.get_driver_property( + DriverProps.MAX_PARALLEL + ) return event_loop.run_until_complete( - self._exec_queries_threaded(query_tasks, progress, retry) + self._exec_queries_threaded( + query_tasks, + progress, + retry, + max_workers, + ) ) # or revert to standard synchronous execution @@ -285,7 +340,11 @@ def _exec_synchronous_queries( results.append(query_task()) except MsticpyDataQueryError: print(f"Query {con_name} failed.") - return pd.concat(results) + if results: + return pd.concat(results) + + logger.warning("All queries failed.") + return pd.DataFrame() def _create_split_queries( self, @@ -318,29 +377,26 @@ def _create_split_queries( logger.info("Split query into %s chunks", len(split_queries)) return split_queries + @staticmethod async def _exec_queries_threaded( - self, query_tasks: Dict[str, partial], progress: bool = True, retry: bool = False, + max_workers: int = 4, ) -> pd.DataFrame: """Return results of multiple queries run as threaded tasks.""" logger.info("Running threaded queries for %d connections.", len(query_tasks)) event_loop = _get_event_loop() - with ThreadPoolExecutor( - max_workers=self._query_provider.get_driver_property( - DriverProps.MAX_PARALLEL - ) - ) as executor: + with ThreadPoolExecutor(max_workers=max_workers) as executor: # add the additional connections thread_tasks = { query_id: event_loop.run_in_executor(executor, query_func) for query_id, query_func in query_tasks.items() } results: List[pd.DataFrame] = [] - failed_tasks: Dict[str, asyncio.Future] = {} + failed_tasks_ids: List[str] = [] if progress: task_iter = tqdm( asyncio.as_completed(thread_tasks.values()), @@ -360,24 +416,33 @@ async def _exec_queries_threaded( "Query task '%s' failed with exception", query_id, ) - failed_tasks[query_id] = thread_task - - if retry and failed_tasks: - for query_id, thread_task in failed_tasks.items(): - try: - logger.info("Retrying query task '%s'", query_id) - result = await thread_task - results.append(result) - except Exception: # pylint: disable=broad-except - logger.warning( - "Retried query task '%s' failed with exception", - query_id, - exc_info=True, - ) - # Sort the results by the order of the tasks - results = [result for _, result in sorted(zip(thread_tasks, results))] - - return pd.concat(results, ignore_index=True) + # Reusing thread task would result in: + # RuntimeError: cannot reuse already awaited coroutine + # A new task should be queued + failed_tasks_ids.append(query_id) + + # Sort the results by the order of the tasks + results = [result for _, result in sorted(zip(thread_tasks, results))] + + if retry and failed_tasks_ids: + failed_results: pd.DataFrame = ( + await QueryProviderConnectionsMixin._exec_queries_threaded( + { + failed_tasks_id: query_tasks[failed_tasks_id] + for failed_tasks_id in failed_tasks_ids + }, + progress=progress, + retry=False, + max_workers=max_workers, + ) + ) + if not failed_results.empty: + results.append(failed_results) + if results: + return pd.concat(results, ignore_index=True) + + logger.warning("All queries failed.") + return pd.DataFrame() def _get_event_loop() -> asyncio.AbstractEventLoop: diff --git a/msticpy/data/drivers/driver_base.py b/msticpy/data/drivers/driver_base.py index bfb3d79de..9adde3269 100644 --- a/msticpy/data/drivers/driver_base.py +++ b/msticpy/data/drivers/driver_base.py @@ -182,7 +182,6 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): Connect to a data source """ - return None @abc.abstractmethod def query( diff --git a/tests/data/test_async_queries.py b/tests/data/test_async_queries.py index 23d9fc489..18b8703f8 100644 --- a/tests/data/test_async_queries.py +++ b/tests/data/test_async_queries.py @@ -6,11 +6,14 @@ """Test async connections and split queries.""" from datetime import datetime, timedelta, timezone +from unittest.mock import patch import pandas as pd import pytest_check as check +from msticpy.common.exceptions import MsticpyDataQueryError from msticpy.data.core.data_providers import QueryProvider +from msticpy.data.drivers.local_data_driver import LocalDataDriver from msticpy.data.core.query_provider_connections_mixin import _calc_split_ranges from msticpy.data.drivers.driver_base import DriverProps @@ -49,6 +52,31 @@ def test_multiple_connections_sync(): # verify columns/schema is the same. check.equal(list(single_results.columns), list(multi_results.columns)) + # Check that with parameter progress = False, the result is still the same. + multi_results_no_progress = local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + progress=False, + ) + check.is_true(multi_results_no_progress.equals(multi_results)) + + # Check that even if only exceptions are returned, the result will be an empty dataframe. + with patch.object( + LocalDataDriver, "query", side_effect=MsticpyDataQueryError + ) as patched_query_exception: + multi_results_exception_raised: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + ) + ) + check.is_true(patched_query_exception.called) + check.equal(patched_query_exception.call_count, len(connections)) + check.is_instance(multi_results_exception_raised, pd.DataFrame) + check.is_true(multi_results_exception_raised.empty) + def test_multiple_connections_threaded(): """Test adding connection instance to provider.""" @@ -86,6 +114,181 @@ def test_multiple_connections_threaded(): # verify columns/schema is the same. check.equal(list(single_results.columns), list(multi_results.columns)) + # Check that with parameter progress = False, the result is still the same. + multi_results_no_progress = local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + progress=False, + ) + check.is_true(multi_results_no_progress.equals(multi_results)) + + # Check that even if the query returns an empty dataframe, the result will be ok. + with patch.object( + LocalDataDriver, "query", return_value=pd.DataFrame() + ) as patched_query_empty_df: + multi_results_no_result: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + ) + ) + check.is_true(patched_query_empty_df.called) + check.equal(patched_query_empty_df.call_count, len(connections)) + check.is_instance(multi_results_no_result, pd.DataFrame) + check.is_true(multi_results_no_result.empty) + + # Check that even if only exceptions are returned, the result will be an empty dataframe. + with patch.object( + LocalDataDriver, "query", side_effect=MsticpyDataQueryError + ) as patched_query_exception: + multi_results_exception_raised: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + ) + ) + check.is_true(patched_query_exception.called) + check.equal(patched_query_exception.call_count, len(connections)) + check.is_instance(multi_results_exception_raised, pd.DataFrame) + check.is_true(multi_results_exception_raised.empty) + + # Check if retry parameter works as expected when only exceptions are raised. + with patch.object( + LocalDataDriver, + "query", + side_effect=MsticpyDataQueryError, + ) as patched_query_exception_then_exception: + multi_results_exception_raised_retried: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + retry_on_error=True, + ) + ) + check.is_true(patched_query_exception_then_exception.called) + check.equal( + patched_query_exception_then_exception.call_count, len(connections) * 2 + ) + check.is_instance(multi_results_exception_raised_retried, pd.DataFrame) + check.is_true(multi_results_exception_raised_retried.empty) + + # Check if retry parameter works as expected. + # Exceptions will be raised for the first executions, then returns a dummy dataframe + with patch.object( + LocalDataDriver, + "query", + side_effect=[MsticpyDataQueryError for _ in connections] + + [single_results for _ in connections], + ) as patched_query_exception_then_ok: + multi_results_exception_raised_retried_success: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + retry_on_error=True, + ) + ) + check.is_true(patched_query_exception_then_ok.called) + check.equal(patched_query_exception_then_ok.call_count, len(connections) * 2) + check.is_instance(multi_results_exception_raised_retried_success, pd.DataFrame) + check.is_false(multi_results_exception_raised_retried_success.empty) + check.is_true( + multi_results_exception_raised_retried_success.equals(multi_results) + ) + + # Check if retry parameter works as expected. + # Exceptions will be raised for the one driver for all executions + # but for the other drivers, only during the first execution + with patch.object( + LocalDataDriver, + "query", + side_effect=[MsticpyDataQueryError] + + [MsticpyDataQueryError for _ in range(len(connections) - 1)] + + [MsticpyDataQueryError] + + [single_results for _ in range(len(connections) - 1)], + ) as patched_query_exception_partial_success: + multi_results_exception_raised_retried_partial_success: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + retry_on_error=True, + ) + ) + check.is_true(patched_query_exception_partial_success.called) + check.equal( + patched_query_exception_partial_success.call_count, len(connections) * 2 + ) + check.is_instance( + multi_results_exception_raised_retried_partial_success, pd.DataFrame + ) + check.is_true( + multi_results_exception_raised_retried_partial_success.equals( + pd.concat( + [single_results for _ in range(len(connections) - 1)], + ignore_index=True, + ) + ) + ) + + # Check if retry parameter works as expected. + # Exceptions will be raised for the one driver for all executions + # but for the other drivers, no exceptions will occur + with patch.object( + LocalDataDriver, + "query", + side_effect=[MsticpyDataQueryError] + + [single_results for _ in range(len(connections) - 1)] + + [MsticpyDataQueryError], + ) as patched_query_exception_partial_sucess_v2: + multi_results_exception_raised_retried_partial_success_v2: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + retry_on_error=True, + ) + ) + check.is_true(patched_query_exception_partial_sucess_v2.called) + check.equal( + patched_query_exception_partial_sucess_v2.call_count, len(connections) + 1 + ) + check.is_instance( + multi_results_exception_raised_retried_partial_success_v2, pd.DataFrame + ) + check.is_true( + multi_results_exception_raised_retried_partial_success_v2.equals( + pd.concat( + [single_results for _ in range(len(connections) - 1)], + ignore_index=True, + ) + ) + ) + + # Check if running in ipython yields the same result. + with patch( + "msticpy.data.core.query_provider_connections_mixin.is_ipython", + return_value=True, + ) as patched_is_ipython: + multi_results_exception_raised_retried_success: pd.DataFrame = ( + local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", + start=start, + end=end, + ) + ) + check.is_true(patched_is_ipython.called) + check.equal(patched_is_ipython.call_count, 2) + check.is_instance(multi_results_exception_raised_retried_success, pd.DataFrame) + check.is_false(multi_results_exception_raised_retried_success.empty) + check.is_true( + multi_results_exception_raised_retried_success.equals(multi_results) + ) + def test_split_queries_sync(): """Test queries split into time segments.""" diff --git a/tests/data/test_dataqueries.py b/tests/data/test_dataqueries.py index cb9e3d0ce..a6a43331d 100644 --- a/tests/data/test_dataqueries.py +++ b/tests/data/test_dataqueries.py @@ -6,13 +6,14 @@ """dataprovider query test class.""" import contextlib import io -import unittest import warnings from copy import deepcopy -from datetime import datetime +from datetime import datetime, timezone from functools import partial from pathlib import Path -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from unittest import TestCase +from unittest.mock import patch import pandas as pd import pytest @@ -25,6 +26,7 @@ from msticpy.data.core.query_provider_connections_mixin import _calc_split_ranges from msticpy.data.core.query_source import QuerySource from msticpy.data.drivers.driver_base import DriverBase, DriverProps +from msticpy.data.query_defns import DataEnvironment from ..unit_test_lib import get_test_data_path @@ -91,7 +93,7 @@ def driver_queries(self) -> Iterable[Dict[str, str]]: ] -class TestDataQuery(unittest.TestCase): +class TestDataQuery(TestCase): """Unit test class.""" provider = None @@ -352,33 +354,55 @@ def test_connect_queries_dotted(self): q_src = q_store.get_query("Saved.Searches.test.query3") self.assertEqual(q_src.query, dotted_container_qs[2]["query"]) + def validate_time_ranges( + self, + ranges: List[Tuple[datetime, datetime]], + delta: pd.Timedelta, + ) -> None: + self.assertIsInstance(ranges, list) + + for index, range in enumerate(ranges): + self.assertIsInstance(range, tuple) + start_range = range[0] + self.assertIsInstance(start_range, datetime) + end_range = range[1] + self.assertIsInstance(end_range, datetime) + + self.assertGreater(end_range, start_range) + if index > 0: + previous_end_range = ranges[index - 1][1] + # Ensure the new start range starts after the former end_range + self.assertGreater(start_range, previous_end_range) + # Ensure we don't have more than 1ns between 2 ranges + self.assertEqual(start_range - previous_end_range, pd.Timedelta("1ns")) + # Ensure a given time range is not longer than the configured delta + 10% + self.assertLess(end_range - start_range, delta + (delta / 10)) + def test_split_ranges(self): """Test time range split logic.""" - start = datetime.utcnow() - pd.Timedelta("5h") - end = datetime.utcnow() + pd.Timedelta("5min") + start = datetime.now(tz=timezone.utc) - pd.Timedelta("5h") delta = pd.Timedelta("1h") + # Case where the last range has less than 10% of delta of difference + end = datetime.now(tz=timezone.utc) + delta / 1001 ranges = _calc_split_ranges(start, end, delta) - self.assertEqual(len(ranges), 6) self.assertEqual(ranges[0][0], start) self.assertEqual(ranges[-1][1], end) + self.validate_time_ranges(ranges=ranges, delta=delta) - st_times = [start_tm[0] for start_tm in ranges] - for end_time in (end_tm[1] for end_tm in ranges): - self.assertNotIn(end_time, st_times) - - end = end + pd.Timedelta("20min") + # Case where the last range has more than 10% of delta of difference + end = datetime.now(tz=timezone.utc) + delta / 999 ranges = _calc_split_ranges(start, end, delta) - self.assertEqual(len(ranges), 6) self.assertEqual(ranges[0][0], start) self.assertEqual(ranges[-1][1], end) + self.validate_time_ranges(ranges=ranges, delta=delta) def test_split_queries(self): """Test queries split into time segments.""" la_provider = self.la_provider - start = datetime.utcnow() - pd.Timedelta("5h") - end = datetime.utcnow() + pd.Timedelta("5min") + start = datetime.now(tz=timezone.utc) - pd.Timedelta("5h") + end = datetime.now(tz=timezone.utc) + pd.Timedelta("5min") delta = pd.Timedelta("1h") ranges = _calc_split_ranges(start, end, delta) @@ -409,8 +433,8 @@ def test_split_queries_err(self): self.assertIn("Cannot split a query", mssg.getvalue()) # With invalid split_query_by value it will default to 1D - start = datetime.utcnow() - pd.Timedelta("5D") - end = datetime.utcnow() + pd.Timedelta("5min") + start = datetime.now(tz=timezone.utc) - pd.Timedelta("5D") + end = datetime.now(tz=timezone.utc) + pd.Timedelta("5min") result_queries = la_provider.all_queries.list_alerts( "print", start=start, end=end, split_query_by="Invalid" @@ -418,6 +442,86 @@ def test_split_queries_err(self): queries = result_queries.split("\n\n") self.assertEqual(len(queries), 6) + def test_getattr_invalid_attribute(self) -> None: + """Test method get_attr when attribute is not a supported attribute.""" + with pytest.raises( + AttributeError, match=f"UTDataDriver has no attribute 'test'" + ): + self.provider.test + + def test_schema(self) -> None: + """Test default implementation of property schema.""" + schema: dict = self.provider.schema + self.assertIsInstance(schema, dict) + self.assertFalse(schema) + + def test_service_queries(self) -> None: + """Test default implementation of property service_queries.""" + service_queries: Tuple[dict, str] = self.provider.service_queries + self.assertIsInstance(service_queries, tuple) + self.assertIsInstance(service_queries[0], dict) + self.assertIsInstance(service_queries[1], str) + self.assertFalse(service_queries[0]) + self.assertFalse(service_queries[1]) + + def test_add_query_filter_invalid_parameter(self) -> None: + """Test default implementation of method add_query_filter with invalid name.""" + with pytest.raises(ValueError, match="'name' test must be one of:.*"): + self.provider.add_query_filter(name="test", query_filter="") + + def test_add_query_filter_as_str(self) -> None: + """Test default implementation of method add_query_filter with invalid name.""" + my_filter: str = "test_filter" + filter_name: str = "data_sources" + self.assertNotIn(filter_name, self.provider._query_filter) + self.provider.add_query_filter(name=filter_name, query_filter=my_filter) + self.assertIn(filter_name, self.provider._query_filter) + self.assertIn(my_filter, self.provider._query_filter[filter_name]) + + def test_set_driver_property(self) -> None: + """Test default implementation of method set_driver_property with invalid property.""" + with pytest.raises( + TypeError, + match="Property 'supports_threading' is not the correct type.", + ): + self.provider.set_driver_property( + name=DriverProps.SUPPORTS_THREADING, value=42 + ) + + def test_query_usable(self) -> None: + """Test default implementation of method query_usable.""" + self.assertTrue(self.provider.query_usable(query_source=None)) + + def test_execute_query_provider_not_loaded(self) -> None: + """Test method _execute_query when driver is not loaded.""" + self.la_provider._query_provider._loaded = False + with pytest.raises(ValueError, match="Provider is not loaded."): + self.la_provider._execute_query() + + def test_execute_query_provider_not_connected(self) -> None: + """Test method _execute_query when driver is not connected.""" + self.la_provider._query_provider._connected = False + with pytest.raises(ValueError, match="No connection to a data source."): + self.la_provider._execute_query() + + def test_check_for_time_params_missing_start(self) -> None: + """Test method _check_for_time_params when start is missing.""" + missing: List[str] = ["start"] + params: dict = {} + changes: bool = self.la_provider._check_for_time_params(params, missing) + self.assertTrue(changes) + self.assertIn("start", params) + self.assertFalse(missing) + + def test_check_for_time_params_missing_end(self) -> None: + """Test method _check_for_time_params when end is missing.""" + missing: List[str] = ["end"] + params: dict = {} + changes: bool = self.la_provider._check_for_time_params(params, missing) + self.assertTrue(changes) + self.assertIn("end", params) + self.assertFalse(missing) + _LOCAL_DATA_PATHS = [str(get_test_data_path().joinpath("localdata"))] @@ -581,3 +685,96 @@ def test_query_paths(mode): ): check.is_true(hasattr(qry_prov, data_family)) pkg_config._settings["QueryDefinitions"] = current_settings + + +def test_driver_props_valid_type_invalid_property_name() -> None: + """Test method valid_type when input property is not in the predefined properties.""" + valid: bool = DriverProps.valid_type( + property_name="random_property", + value=0, + ) + check.is_true(valid) + + +def test_driver_queries() -> None: + """Test default implementation of property driver_queries.""" + + class MinimalDriver(DriverBase): + def connect(): + pass + + def query(): + pass + + def query_with_results(): + pass + + driver = MinimalDriver() + driver_queries: List[dict] = driver.driver_queries + check.is_instance(driver_queries, list) + check.equal(len(driver_queries), 1) + check.is_instance(driver_queries[0], dict) + check.is_false(driver_queries[0]) + + +def test_init_invalid_driver() -> None: + """Test QueryProvider method __init__ with invalid driver.""" + data_environment: str = "Kusto" + with patch( + "msticpy.data.drivers.import_driver", return_value=QueryProvider + ) as mocked_import_driver, pytest.raises( + LookupError, + match=f"Could not find suitable data provider for", + ): + QueryProvider( + data_environment=data_environment, + driver=None, + ) + check.is_true(mocked_import_driver.called) + check.equal(mocked_import_driver.call_count, 1) + + +def test_check_environment_unknown_str_env() -> None: + """Test method _check_environment with unknown str environment.""" + invalid_data_environment: str = "invalid_env" + with pytest.raises( + TypeError, match=f"Unknown data environment {invalid_data_environment}" + ): + QueryProvider._check_environment(invalid_data_environment) + + +def test_check_environment_unknown_env_type() -> None: + """Test method _check_environment with invalid environment type.""" + invalid_data_environment: int = 42 + with pytest.raises( + TypeError, + match=f"Unknown data environment type {invalid_data_environment} \({type(invalid_data_environment)}\)", + ): + QueryProvider._check_environment(invalid_data_environment) + + +def test_check_environment_str_custom_provider() -> None: + """Test method _check_environment with a custom provider as a string.""" + data_environment: str = "Custom" + with patch( + "msticpy.data.drivers.CUSTOM_PROVIDERS", + ) as mocked_custom_providers: + # We only need to overwrite the __contains__ method to ensure the check + # value in drivers.CUSTOM_PROVIDERS + # always returns True. + mocked_custom_providers.__contains__.return_value = True + data_env, env_name = QueryProvider._check_environment(data_environment) + check.is_true(mocked_custom_providers.__contains__.called) + check.equal(mocked_custom_providers.__contains__.call_count, 1) + check.is_instance(data_env, str) + check.equal(data_env, data_environment) + check.equal(env_name, data_environment) + + +def test_check_environment_as_data_environment() -> None: + """Test method _check_environment with a DataEnvironment object.""" + data_environment: DataEnvironment = DataEnvironment.MSSentinel + data_env, env_name = QueryProvider._check_environment(data_environment) + check.is_instance(data_env, DataEnvironment) + check.equal(data_env, data_environment) + check.equal(env_name, data_environment.name) diff --git a/tests/data/test_query_source.py b/tests/data/test_query_source.py index 69134aa42..7edd213d9 100644 --- a/tests/data/test_query_source.py +++ b/tests/data/test_query_source.py @@ -7,7 +7,7 @@ import os import unittest import warnings -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Optional, Tuple, Union import pandas as pd @@ -86,7 +86,7 @@ def setUp(self): def test_date_formatters_datetime(self): """Test date formatting standard date.""" # standard date - test_end = datetime.utcnow() + test_end = datetime.now(tz=timezone.utc) test_start = test_end - timedelta(days=1) check_dt_str = test_start.isoformat(sep="T") + "Z" q_src = self.query_sources["SecurityAlert"]["list_related_alerts"] @@ -95,7 +95,7 @@ def test_date_formatters_datetime(self): def test_date_formatters_datestring(self): """Test date formatting ISO date string.""" - test_end = datetime.utcnow() + test_end = datetime.now(tz=timezone.utc) test_start = test_end - timedelta(days=1) check_dt_str = test_start.isoformat(sep="T") + "Z" start = test_start.isoformat() @@ -109,7 +109,7 @@ def test_date_formatters_datestring(self): def test_date_formatters_off_1day(self): """Test date formatting Offset -1 day.""" - test_end = datetime.utcnow() + test_end = datetime.now(tz=timezone.utc) q_src = self.query_sources["SecurityAlert"]["list_related_alerts"] query = q_src.create_query(start=-1, end=0) check_date = test_end - timedelta(1) @@ -120,7 +120,7 @@ def test_date_formatters_off_1day(self): def test_date_formatters_off_1day_str(self): """Test date formatting Offset -1 day as string.""" - test_dt = datetime.utcnow() + test_dt = datetime.now(tz=timezone.utc) q_src = self.query_sources["SecurityAlert"]["list_related_alerts"] query = q_src.create_query(start="-1d", end=test_dt) check_date = test_dt - timedelta(1) @@ -130,7 +130,7 @@ def test_date_formatters_off_1day_str(self): def test_date_formatters_off_1week_str(self): """Test date formatting Offset -1 week.""" - test_dt = datetime.utcnow() + test_dt = datetime.now(tz=timezone.utc) q_src = self.query_sources["SecurityAlert"]["list_related_alerts"] query = q_src.create_query(start="-1w", end=test_dt) check_date = test_dt - timedelta(7) @@ -140,7 +140,7 @@ def test_date_formatters_off_1week_str(self): def test_date_formatters_off_1wk_rnd_dn(self): """Test date formatting Offset -1 week rounded to day.""" - test_dt = datetime.utcnow() + test_dt = datetime.now(tz=timezone.utc) q_src = self.query_sources["SecurityAlert"]["list_related_alerts"] query = q_src.create_query(start="-1w@d", end=test_dt) check_date = test_dt - timedelta(7) @@ -150,7 +150,7 @@ def test_date_formatters_off_1wk_rnd_dn(self): def test_date_formatters_off_1wk_rnd_up(self): """Test date formatting Offset +1 week rounded to day.""" - test_dt = datetime.utcnow() + test_dt = datetime.now(tz=timezone.utc) q_src = self.query_sources["SecurityAlert"]["list_related_alerts"] query = q_src.create_query(start="1w@d", end=test_dt) check_date = test_dt + timedelta(7 + 1) @@ -160,7 +160,7 @@ def test_date_formatters_off_1wk_rnd_up(self): def test_list_formatter(self): """Test for default list formatting.""" - test_end = datetime.utcnow() + test_end = datetime.now(tz=timezone.utc) test_start = test_end - timedelta(days=1) q_src = self.query_sources["Azure"]["list_azure_activity_for_ip"] ip_address_list = ["192.168.0.1", "192.168.0.2", "192.168.0.3"] @@ -201,7 +201,7 @@ def test_cust_formatters_splunk(): "list": splunk_driver.SplunkDriver._format_list, } - test_end = datetime.utcnow() + test_end = datetime.now(tz=timezone.utc) test_start = test_end - timedelta(days=1) ip_address_list = "192.168.0.1, 192.168.0.2, 192.168.0.3"