From 14a7cd1218e8699cc5eb58dd65acf4e0643f5bd5 Mon Sep 17 00:00:00 2001
From: "gabrielle.gm"
Date: Fri, 27 Jan 2023 10:36:17 -0500
Subject: [PATCH] Remove config from task manager and stop killing it

---
 azimuth/app.py                                 | 38 ++++-----
 azimuth/routers/app.py                         | 18 ++--
 azimuth/routers/class_overlap.py               |  4 +
 azimuth/routers/config.py                      |  6 +-
 azimuth/routers/custom_utterances.py           |  4 +-
 azimuth/routers/dataset_warnings.py            |  5 +-
 azimuth/routers/export.py                      | 12 +--
 .../model_performance/confidence_histogram.py  |  5 +-
 .../model_performance/confusion_matrix.py      |  5 +-
 azimuth/routers/model_performance/metrics.py   |  8 +-
 .../model_performance/outcome_count.py         |  7 +-
 azimuth/routers/top_words.py                   |  5 +-
 azimuth/routers/utterances.py                  | 16 ++--
 azimuth/startup.py                             | 19 +++--
 azimuth/task_manager.py                        | 19 +++--
 azimuth/utils/routers.py                       | 10 ++-
 tests/README.md                                |  3 +-
 tests/conftest.py                              |  4 +-
 tests/test_routers/conftest.py                 | 10 +--
 tests/test_startup.py                          | 18 ++--
 tests/test_task_manager.py                     | 84 +++++++++++--------
 21 files changed, 173 insertions(+), 127 deletions(-)
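The recurring pattern in the router diffs below is dependency injection: every endpoint now receives the AzimuthConfig through FastAPI's Depends(get_config) and forwards it explicitly, instead of reaching into task_manager.config. A minimal sketch of that pattern, using only helpers named in this patch (the route path and chosen module are illustrative, not part of the patch):

    from fastapi import APIRouter, Depends

    from azimuth.app import get_config, get_task_manager
    from azimuth.config import AzimuthConfig
    from azimuth.task_manager import TaskManager
    from azimuth.types import DatasetSplitName, ModuleOptions, SupportedModule
    from azimuth.utils.routers import get_standard_task_result

    router = APIRouter()


    @router.get("/example")  # illustrative route, not part of this patch
    def get_example(
        dataset_split_name: DatasetSplitName,
        task_manager: TaskManager = Depends(get_task_manager),
        config: AzimuthConfig = Depends(get_config),  # injected; no task_manager.config
    ):
        # The config is forwarded explicitly, so the TaskManager
        # no longer needs to own a copy of it.
        return get_standard_task_result(
            SupportedModule.Metrics,
            dataset_split_name,
            task_manager=task_manager,
            config=config,
            mod_options=ModuleOptions(pipeline_index=0),
        )[0]
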
diff --git a/azimuth/app.py b/azimuth/app.py
index 63f689a6..7892ef37 100644
--- a/azimuth/app.py
+++ b/azimuth/app.py
@@ -18,7 +18,6 @@
 from azimuth.startup import startup_tasks
 from azimuth.task_manager import TaskManager
 from azimuth.types import DatasetSplitName, ModuleOptions
-from azimuth.utils.cluster import default_cluster
 from azimuth.utils.conversion import JSONResponseIgnoreNan
 from azimuth.utils.logs import set_logger_config
 from azimuth.utils.project import load_dataset_split_managers_from_config, save_config
@@ -101,9 +100,7 @@ def start_app(config_path: Optional[str], load_config_history: bool, debug: bool
     if azimuth_config.dataset is None:
         raise ValueError("No dataset has been specified in the config.")

-    local_cluster = default_cluster(large=azimuth_config.large_dask_cluster)
-
-    run_startup_tasks(azimuth_config, local_cluster)
+    run_startup_tasks(azimuth_config)

     assert_not_none(_task_manager).client.run(set_logger_config, level)

     app = create_app()
@@ -240,25 +237,23 @@ def create_app() -> FastAPI:
     return app


-def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
-    """Initialize DatasetSplitManagers and TaskManagers.
-
+def initialize_managers_and_config(
+    azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
+):
+    """Initialize DatasetSplitManagers and Config.
+
     Args:
-        azimuth_config: Configuration
-        cluster: Dask cluster to use.
+        azimuth_config: Config
+        cluster: Dask cluster to use, if different from the default.
     """
     global _task_manager, _dataset_split_managers, _azimuth_config
-    _azimuth_config = azimuth_config

-    if _task_manager is not None:
-        task_history = _task_manager.current_tasks
+    if _task_manager:
+        _task_manager.clear_worker_cache()
+        _task_manager.restart()
     else:
-        task_history = {}
-
-    _task_manager = TaskManager(azimuth_config, cluster=cluster)
-
-    _task_manager.current_tasks = task_history
+        _task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)

+    _azimuth_config = azimuth_config
     _dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)
@@ -295,15 +290,14 @@ def run_validation_module(pipeline_index=None):
         task_manager.restart()


-def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
+def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
     """Initialize managers, run validation and startup tasks.

     Args:
         azimuth_config: Config
-        cluster: Cluster
-
+        cluster: Dask cluster to use, if different from the default.
     """
-    initialize_managers(azimuth_config, cluster)
+    initialize_managers_and_config(azimuth_config, cluster)
     task_manager = assert_not_none(get_task_manager())

     # Validate that everything is in order **before** the startup tasks.
@@ -315,5 +309,5 @@
     save_config(azimuth_config)  # Save only after the validation modules ran successfully

     global _startup_tasks, _ready_flag
-    _startup_tasks = startup_tasks(_dataset_split_managers, task_manager)
+    _startup_tasks = startup_tasks(_dataset_split_managers, task_manager, azimuth_config)
     _ready_flag = Event()
diff --git a/azimuth/routers/app.py b/azimuth/routers/app.py
index 9a3977ad..70fa2561 100644
--- a/azimuth/routers/app.py
+++ b/azimuth/routers/app.py
@@ -84,15 +84,12 @@
 def get_dataset_info(
     dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
         get_dataset_split_manager_mapping
     ),
     startup_tasks: Dict[str, Module] = Depends(get_startup_tasks),
-    task_manager: TaskManager = Depends(get_task_manager),
     config: AzimuthConfig = Depends(get_config),
 ):
     eval_dm = dataset_split_managers.get(DatasetSplitName.eval)
     training_dm = dataset_split_managers.get(DatasetSplitName.train)
     dm = assert_not_none(eval_dm or training_dm)

-    model_contract = task_manager.config.model_contract
-
     return DatasetInfoResponse(
         project_name=config.name,
         class_names=dm.get_class_names(),
@@ -105,19 +102,16 @@
         if training_dm is not None
         else [],
         startup_tasks={k: v.status() for k, v in startup_tasks.items()},
-        model_contract=model_contract,
-        prediction_available=predictions_available(task_manager.config),
-        perturbation_testing_available=perturbation_testing_available(task_manager.config),
+        model_contract=config.model_contract,
+        prediction_available=predictions_available(config),
+        perturbation_testing_available=perturbation_testing_available(config),
         available_dataset_splits=AvailableDatasetSplits(
             eval=eval_dm is not None, train=training_dm is not None
         ),
-        similarity_available=similarity_available(task_manager.config),
+        similarity_available=similarity_available(config),
         postprocessing_editable=None
         if config.pipelines is None
-        else [
-            postprocessing_editable(task_manager.config, idx)
-            for idx in range(len(config.pipelines))
-        ],
+        else [postprocessing_editable(config, idx) for idx in range(len(config.pipelines))],
     )
@@ -171,6 +165,7 @@ def get_perturbation_testing_summary(
         SupportedModule.PerturbationTestingMerged,
         dataset_split_name=DatasetSplitName.all,
         task_manager=task_manager,
+        config=config,
         last_update=last_update,
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )[0]
@@ -186,6 +181,7 @@
         SupportedModule.PerturbationTestingSummary,
         dataset_split_name=DatasetSplitName.all,
         task_manager=task_manager,
+        config=config,
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )[0]
     return PerturbationTestingSummary(
diff --git a/azimuth/routers/class_overlap.py b/azimuth/routers/class_overlap.py
index cefcb8af..eda94c4a 100644
--- a/azimuth/routers/class_overlap.py
+++ b/azimuth/routers/class_overlap.py
@@ -72,6 +72,7 @@ def get_class_overlap_plot(
         SupportedModule.ClassOverlap,
         dataset_split_name=dataset_split_name,
         task_manager=task_manager,
+        config=config,
         last_update=-1,
     )[0]
     class_overlap_plot_response: ClassOverlapPlotResponse = make_sankey_plot(
@@ -94,6 +95,7 @@
 def get_class_overlap(
     dataset_split_name: DatasetSplitName,
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
         get_dataset_split_manager_mapping
@@ -106,6 +108,7 @@
         SupportedModule.ClassOverlap,
         dataset_split_name=dataset_split_name,
         task_manager=task_manager,
+        config=config,
         last_update=-1,
     )[0]
     dataset_class_count = class_overlap_result.s_matrix.shape[0]
@@ -121,6 +124,7 @@
         SupportedModule.ConfusionMatrix,
         DatasetSplitName.eval,
         task_manager=task_manager,
+        config=config,
         mod_options=ModuleOptions(
             pipeline_index=pipeline_index, cf_normalize=False, cf_reorder_classes=False
         ),
diff --git a/azimuth/routers/config.py b/azimuth/routers/config.py
index 1dd7302c..ba8ab872 100644
--- a/azimuth/routers/config.py
+++ b/azimuth/routers/config.py
@@ -11,7 +11,7 @@
 from azimuth.app import (
     get_config,
     get_task_manager,
-    initialize_managers,
+    initialize_managers_and_config,
     require_editable_config,
     run_startup_tasks,
 )
@@ -81,11 +81,11 @@ def patch_config(
     try:
         new_config = update_config(old_config=config, partial_config=partial_config)
-        run_startup_tasks(new_config, task_manager.cluster)
+        run_startup_tasks(new_config)
     except Exception as e:
         log.error("Rollback config update due to error", exc_info=e)
         new_config = config
-        initialize_managers(new_config, task_manager.cluster)
+        initialize_managers_and_config(new_config)
         if isinstance(e, AzimuthValidationError):
             raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
         elif isinstance(e, ValidationError):
diff --git a/azimuth/routers/custom_utterances.py b/azimuth/routers/custom_utterances.py
index 3da6f4d8..f61eb51a 100644
--- a/azimuth/routers/custom_utterances.py
+++ b/azimuth/routers/custom_utterances.py
@@ -95,11 +95,13 @@
 def get_saliency(
     utterances: List[str] = Query([], title="Utterances"),
     pipeline_index: int = Depends(require_pipeline_index),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
 ) -> List[SaliencyResponse]:
     task_result: List[SaliencyResponse] = get_custom_task_result(
         SupportedMethod.Saliency,
         task_manager=task_manager,
-        custom_query={task_manager.config.columns.text_input: utterances},
+        config=config,
+        custom_query={config.columns.text_input: utterances},
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )
diff --git a/azimuth/routers/dataset_warnings.py b/azimuth/routers/dataset_warnings.py
index 8ba6dc6c..da872399 100644
--- a/azimuth/routers/dataset_warnings.py
+++ b/azimuth/routers/dataset_warnings.py
@@ -6,7 +6,8 @@

 from fastapi import APIRouter, Depends

-from azimuth.app import get_dataset_split_manager_mapping, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager_mapping, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import DatasetSplitName, SupportedModule
@@ -25,6 +26,7 @@
 )
 def get_dataset_warnings(
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
         get_dataset_split_manager_mapping
     ),
@@ -35,6 +37,7 @@
         dataset_split_name=DatasetSplitName.all,
         task_manager=task_manager,
         last_update=get_last_update(list(dataset_split_managers.values())),
+        config=config,
     )[0]

     return task_result.warning_groups
diff --git a/azimuth/routers/export.py b/azimuth/routers/export.py
index 86dac95d..2767c0f4 100644
--- a/azimuth/routers/export.py
+++ b/azimuth/routers/export.py
@@ -108,18 +108,18 @@ def get_export_perturbation_testing_summary(
         SupportedModule.PerturbationTestingSummary,
         DatasetSplitName.all,
         task_manager=task_manager,
+        config=config,
         last_update=last_update,
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )[0].all_tests_summary

-    cfg = task_manager.config
     df = pd.DataFrame.from_records([t.dict() for t in task_result])
     df["example"] = df["example"].apply(lambda i: i["perturbedUtterance"])

     file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())
-    filename = f"azimuth_export_behavioral_testing_summary_{cfg.name}_{file_label}.csv"
+    filename = f"azimuth_export_behavioral_testing_summary_{config.name}_{file_label}.csv"

-    path = pjoin(cfg.get_artifact_path(), filename)
+    path = pjoin(config.get_artifact_path(), filename)

     df.to_csv(path, index=False)
@@ -143,15 +143,15 @@ def get_export_perturbed_set(
 ) -> FileResponse:
     pipeline_index_not_null = assert_not_none(pipeline_index)
     file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())
-    cfg = task_manager.config

-    filename = f"azimuth_export_modified_set_{cfg.name}_{dataset_split_name}_{file_label}.json"
-    path = pjoin(cfg.get_artifact_path(), filename)
+    filename = f"azimuth_export_modified_set_{config.name}_{dataset_split_name}_{file_label}.json"
+    path = pjoin(config.get_artifact_path(), filename)

     task_result: List[List[PerturbedUtteranceResult]] = get_standard_task_result(
         SupportedModule.PerturbationTesting,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=ModuleOptions(pipeline_index=pipeline_index_not_null),
     )
diff --git a/azimuth/routers/model_performance/confidence_histogram.py b/azimuth/routers/model_performance/confidence_histogram.py
index e8b561b3..459b435f 100644
--- a/azimuth/routers/model_performance/confidence_histogram.py
+++ b/azimuth/routers/model_performance/confidence_histogram.py
@@ -4,7 +4,8 @@

 from fastapi import APIRouter, Depends, Query

-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import (
@@ -33,6 +34,7 @@
 def get_confidence_histogram(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -47,6 +49,7 @@
         task_name=SupportedModule.ConfidenceHistogram,
         dataset_split_name=dataset_split_name,
         task_manager=task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
diff --git a/azimuth/routers/model_performance/confusion_matrix.py b/azimuth/routers/model_performance/confusion_matrix.py
index 76424e79..f23d67e3 100644
--- a/azimuth/routers/model_performance/confusion_matrix.py
+++ b/azimuth/routers/model_performance/confusion_matrix.py
@@ -4,7 +4,8 @@

 from fastapi import APIRouter, Depends, Query

-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import (
@@ -33,6 +34,7 @@
 def get_confusion_matrix(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -51,6 +53,7 @@
         SupportedModule.ConfusionMatrix,
         dataset_split_name,
         task_manager=task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
diff --git a/azimuth/routers/model_performance/metrics.py b/azimuth/routers/model_performance/metrics.py
index c712a636..7ccb7de9 100644
--- a/azimuth/routers/model_performance/metrics.py
+++ b/azimuth/routers/model_performance/metrics.py
@@ -5,7 +5,8 @@

 from fastapi import APIRouter, Depends, Query

-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.modules.model_performance.metrics import MetricsModule
 from azimuth.task_manager import TaskManager
@@ -41,6 +42,7 @@
 def get_metrics(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -55,6 +57,7 @@
         SupportedModule.Metrics,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )
@@ -73,6 +76,7 @@
 def get_metrics_per_filter(
     dataset_split_name: DatasetSplitName,
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
 ) -> MetricsPerFilterAPIResponse:
@@ -81,6 +85,7 @@
         SupportedModule.MetricsPerFilter,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
@@ -89,6 +94,7 @@
         SupportedModule.Metrics,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
diff --git a/azimuth/routers/model_performance/outcome_count.py b/azimuth/routers/model_performance/outcome_count.py
index a490416e..aaf5ce58 100644
--- a/azimuth/routers/model_performance/outcome_count.py
+++ b/azimuth/routers/model_performance/outcome_count.py
@@ -7,7 +7,8 @@

 from fastapi import APIRouter, Depends, HTTPException, Query
 from starlette.status import HTTP_400_BAD_REQUEST

-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import (
@@ -38,6 +39,7 @@
 def get_outcome_count_per_threshold(
     dataset_split_name: DatasetSplitName,
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
 ) -> OutcomeCountPerThresholdResponse:
@@ -50,6 +52,7 @@
         SupportedModule.OutcomeCountPerThreshold,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )
@@ -69,6 +72,7 @@
 def get_outcome_count_per_filter(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -83,6 +87,7 @@
         SupportedModule.OutcomeCountPerFilter,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
diff --git a/azimuth/routers/top_words.py b/azimuth/routers/top_words.py
index 0162998f..2e9ab08a 100644
--- a/azimuth/routers/top_words.py
+++ b/azimuth/routers/top_words.py
@@ -4,7 +4,8 @@

 from fastapi import APIRouter, Depends, Query

-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import (
@@ -33,6 +34,7 @@
 def get_top_words(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -49,6 +51,7 @@
         SupportedModule.TopWords,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
diff --git a/azimuth/routers/utterances.py b/azimuth/routers/utterances.py
index 8abc3ff4..30ddad28 100644
--- a/azimuth/routers/utterances.py
+++ b/azimuth/routers/utterances.py
@@ -34,13 +34,8 @@
     SimilarUtterance,
     SimilarUtterancesResponse,
 )
-from azimuth.types.tag import (
-    ALL_DATA_ACTIONS,
-    DATASET_SMART_TAG_FAMILIES,
-    PIPELINE_SMART_TAG_FAMILIES,
-    SMART_TAGS_FAMILY_MAPPING,
-    DataAction,
-)
+from azimuth.types.tag import (ALL_DATA_ACTIONS, DATASET_SMART_TAG_FAMILIES, DataAction,
+                               PIPELINE_SMART_TAG_FAMILIES, SMART_TAGS_FAMILY_MAPPING)
 from azimuth.types.utterance import (
     GetUtterancesResponse,
     ModelPrediction,
@@ -101,7 +96,7 @@ def get_utterances(
     if predictions_available(config) and pipeline_index is not None:
         threshold = (
             assert_not_none(config.pipelines)[pipeline_index].threshold
-            if postprocessing_known(task_manager.config, pipeline_index)
+            if postprocessing_known(config, pipeline_index)
             else None
         )
         table_key = PredictionTableKey.from_pipeline_index(pipeline_index, config)
@@ -164,6 +159,7 @@
         SupportedMethod.Saliency,
         dataset_split_name,
         task_manager,
+        config=config,
         last_update=dataset_split_manager.last_update,
         mod_options=ModuleOptions(
             pipeline_index=pipeline_index, indices=ds[DatasetColumn.row_idx]
@@ -276,6 +272,7 @@
 def get_perturbed_utterances(
     dataset_split_name: DatasetSplitName,
     index: int,
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
 ) -> List[PerturbedUtteranceWithClassNames]:
@@ -286,13 +283,14 @@ def get_perturbed_utterances(
     For endpoints that support per index request, we will not return a list of result.
     """
-    if not perturbation_testing_available(task_manager.config):
+    if not perturbation_testing_available(config):
         return []

     response: List[PerturbedUtteranceResult] = get_standard_task_result(
         SupportedModule.PerturbationTesting,
         dataset_split_name,
         task_manager,
+        config=config,
         last_update=dataset_split_manager.last_update,
         mod_options=ModuleOptions(pipeline_index=pipeline_index, indices=[index]),
     )[0]
diff --git a/azimuth/startup.py b/azimuth/startup.py
index a8c3fadc..7319e20e 100644
--- a/azimuth/startup.py
+++ b/azimuth/startup.py
@@ -159,6 +159,7 @@ def on_end(fut: Future, module: DaskModule, dm: DatasetSplitManager, task_manage
 def make_startup_tasks(
     dataset_split_managers: Dict[DatasetSplitName, Optional[DatasetSplitManager]],
     task_manager: TaskManager,
+    config: AzimuthConfig,
     supported_module: SupportedTask,
     mod_options: Dict,
     dependencies: List[DaskModule],
@@ -171,6 +172,7 @@
     Args:
         dataset_split_managers: loaded dataset_split_managers.
         task_manager: Initialized task managers.
+        config: Config.
         supported_module: A Module to instantiate.
         mod_options: Special kwargs for the Module.
         dependencies: Which modules to include as dependency.
@@ -195,6 +197,7 @@
         _, maybe_task = task_manager.get_task(
             task_name=supported_module,
             dataset_split_name=dataset_split_name,
+            config=config,
             dependencies=dependencies,
             mod_options=ModuleOptions(pipeline_index=pipeline_index, **mod_options),
         )
@@ -218,35 +221,36 @@ def get_modules(module_objects: Dict[str, DaskModule], deps_name: List[str]):
 def startup_tasks(
     dataset_split_managers: Dict[DatasetSplitName, Optional[DatasetSplitManager]],
     task_manager: TaskManager,
+    config: AzimuthConfig,
 ) -> Dict[str, DaskModule]:
     """Create and launch all startup tasks.

     Args:
         dataset_split_managers: Dataset Managers.
         task_manager: Task Manager.
+        config: Config.

     Returns:
         Modules with their names.
""" - config = task_manager.config start_up_tasks = [ Startup("syntax_tags", SupportedModule.SyntaxTagging), ] if predictions_available(config): start_up_tasks += BASE_PREDICTION_TASKS - if perturbation_testing_available(task_manager.config): + if perturbation_testing_available(config): start_up_tasks += PERTURBATION_TESTING_TASKS - if task_manager.config.uncertainty.iterations > 1: + if config.uncertainty.iterations > 1: start_up_tasks += BMA_PREDICTION_TASKS if config.pipelines is not None and len(config.pipelines) > 1: start_up_tasks += PIPELINE_COMPARISON_TASKS # TODO We only check pipeline_index=0, but we should check all pipelines. - if postprocessing_editable(task_manager.config, 0): + if postprocessing_editable(config, 0): start_up_tasks += POSTPROCESSING_TASKS - if saliency_available(task_manager.config): + if saliency_available(config): start_up_tasks += SALIENCY_TASKS - if similarity_available(task_manager.config): + if similarity_available(config): start_up_tasks += SIMILARITY_TASKS mods = start_tasks_for_dms(config, dataset_split_managers, task_manager, start_up_tasks) @@ -296,7 +300,8 @@ def start_tasks_for_dms( for k, v in make_startup_tasks( dataset_split_managers, task_manager, - startup.module, + config=config, + supported_module=startup.module, mod_options=startup.mod_options, dependencies=dep_mods, pipeline_index=pipeline_index, diff --git a/azimuth/task_manager.py b/azimuth/task_manager.py index 28f9b1d2..823f1909 100644 --- a/azimuth/task_manager.py +++ b/azimuth/task_manager.py @@ -28,14 +28,12 @@ class TaskManager: """The Task Manager responsibility is to start tasks and scale the Cluster as needed. Args: - config: The application configuration. - cluster: Dask cluster to use, we will spawn a default one if not provided. - + cluster: Dask cluster to use, if different than default. + large_dask_cluster: Specify a large default dask cluster. """ - def __init__(self, config: AzimuthConfig, cluster: Optional[SpecCluster] = None): - self.config = config - self.cluster = cluster or default_cluster(large=config.large_dask_cluster) + def __init__(self, cluster: Optional[SpecCluster] = None, large_dask_cluster: bool = False): + self.cluster = cluster or default_cluster(large=large_dask_cluster) self.client = Client(cluster) self.tasks: Dict[str, type] = {} self.current_tasks: Dict[str, DaskModule] = {} @@ -102,6 +100,7 @@ def get_task( self, task_name: SupportedTask, dataset_split_name: DatasetSplitName, + config: AzimuthConfig, mod_options: Optional[ModuleOptions] = None, last_update: int = -1, dependencies: Optional[List[DaskModule]] = None, @@ -113,6 +112,7 @@ def get_task( Args: task_name: Name of the task. dataset_split_name: Which dataset split to use. + config: Config. mod_options: Options for the module. last_update: Last known update of the dataset_split. dependencies: Which Modules should complete before this one. @@ -139,7 +139,7 @@ def get_task( task: DaskModule = task_cls( dataset_split_name=dataset_split_name, - config=self.config.copy(deep=True), + config=config.copy(deep=True), mod_options=mod_options, ) # Check if this task already exist. @@ -164,6 +164,7 @@ def get_custom_task( self, task_name: SupportedTask, custom_query: Dict[str, Any], + config: AzimuthConfig, mod_options: Optional[ModuleOptions] = None, ) -> Tuple[str, Optional[DaskModule]]: """Get the task `name` run on a custom query. @@ -173,6 +174,7 @@ def get_custom_task( Args: task_name: Name of the task. custom_query: Query fed to the Module. + config: Config. mod_options: Options for the module. 
diff --git a/azimuth/utils/routers.py b/azimuth/utils/routers.py
index 9626f442..0922735f 100644
--- a/azimuth/utils/routers.py
+++ b/azimuth/utils/routers.py
@@ -96,6 +96,7 @@
 def get_standard_task_result(
     task_name: SupportedTask,
     dataset_split_name: DatasetSplitName,
     task_manager: TaskManager,
+    config: AzimuthConfig,
     mod_options: Optional[ModuleOptions] = None,
     last_update: int = -1,
 ):
@@ -105,6 +106,7 @@
         task_name: The task name e.g. ConfidenceHistogram
         dataset_split_name: e.g. DatasetSplitName.eval
         task_manager: The task manager
+        config: Config
         mod_options: Module options to pass to the task launcher
         last_update: The last known update for this dataset_split, to know if we need to
             recompute.
@@ -118,6 +120,7 @@
     _, task = task_manager.get_task(
         task_name=task_name,
         dataset_split_name=dataset_split_name,
+        config=config,
         mod_options=mod_options,
         last_update=last_update,
     )
@@ -135,6 +138,7 @@
 def get_custom_task_result(
     task_name: SupportedTask,
     task_manager: TaskManager,
+    config: AzimuthConfig,
     custom_query: Dict[str, Any],
     mod_options: Optional[ModuleOptions] = None,
 ):
@@ -143,6 +147,7 @@
     Args:
         task_name: The task name e.g. ConfidenceHistogram
         task_manager: The task manager
+        config: Config
         custom_query: A dictionary specifying the custom query
         mod_options: Options to pass to the task launcher
@@ -153,7 +158,10 @@
     _, task = task_manager.get_custom_task(
-        task_name=task_name, custom_query=custom_query, mod_options=mod_options
+        task_name=task_name,
+        config=config,
+        custom_query=custom_query,
+        mod_options=mod_options,
     )
     if not task:
         raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Aggregation not found")
diff --git a/tests/README.md b/tests/README.md
index cb4ba694..6e75edd5 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -72,5 +72,4 @@ Some tests need to test other model contracts than `hf_text_classification`, or
 ### Task Manager and Dask Client

 * If your test needs a Dask Client, you can add the fixture `dask_client`.
-* If your test needs a TaskManager (most don't), `tiny_text_task_manager` is set up to work with
-  the `tiny_text_config`. Use both for the tests.
+* If your test needs a TaskManager (most don't), you can use `task_manager`.
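To make the README guidance concrete, here is a hypothetical test combining the renamed `task_manager` fixture with `tiny_text_config` (the test name is illustrative; both fixtures come from tests/conftest.py as modified below):

    from azimuth.types import DatasetSplitName, ModuleOptions, SupportedMethod


    def test_inputs_smoke(task_manager, tiny_text_config):
        # One shared, config-free TaskManager; the config is supplied per call.
        _, mod = task_manager.get_task(
            SupportedMethod.Inputs,
            dataset_split_name=DatasetSplitName.eval,
            config=tiny_text_config,
            mod_options=ModuleOptions(pipeline_index=0, indices=[0]),
        )
        assert mod is not None
        mod.result()
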
diff --git a/tests/conftest.py b/tests/conftest.py
index d1b81004..11a59551 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -134,8 +134,8 @@ def tiny_text_config_no_postprocessor(tiny_text_config) -> AzimuthConfig:

 @pytest.fixture
-def tiny_text_task_manager(tiny_text_config, dask_client):
-    task_manager = TaskManager(tiny_text_config, cluster=dask_client.cluster)
+def task_manager(dask_client):
+    task_manager = TaskManager(cluster=dask_client.cluster)

     yield task_manager
diff --git a/tests/test_routers/conftest.py b/tests/test_routers/conftest.py
index f3a82f17..2aa4efda 100644
--- a/tests/test_routers/conftest.py
+++ b/tests/test_routers/conftest.py
@@ -8,7 +8,7 @@
 from fastapi import FastAPI
 from starlette.testclient import TestClient

-import azimuth.app as me_app
+import azimuth.app as azimuth_app
 from azimuth.app import get_ready_flag
 from azimuth.config import AzimuthConfig
 from tests.utils import DATASET_CFG, SIMPLE_PERTURBATION_TESTING_CONFIG
@@ -24,7 +24,7 @@ def is_set(self):

 def create_test_app(config) -> FastAPI:
     json.dump(config.dict(by_alias=True), open("/tmp/config.json", "w"))
-    return me_app.start_app("/tmp/config.json", load_config_history=False, debug=False)
+    return azimuth_app.start_app("/tmp/config.json", load_config_history=False, debug=False)

 FAST_TEST_CFG = {
@@ -44,7 +44,7 @@ def wait_for_startup_after(app):
     while resp.json()["startupTasksReady"] is not True:
         time.sleep(1)
         resp = client.get("/status")
-    task_manager = me_app.get_task_manager()
+    task_manager = azimuth_app.get_task_manager()
     while task_manager.is_locked:
         time.sleep(1)
@@ -70,7 +70,7 @@ def app() -> FastAPI:
     while resp.json()["startupTasksReady"] is not True:
         time.sleep(1)
         resp = client.get("/status")
-    task_manager = me_app.get_task_manager()
+    task_manager = azimuth_app.get_task_manager()
     while task_manager.is_locked:
         time.sleep(1)
     yield _app
@@ -79,7 +79,7 @@

 @pytest.fixture(scope="function")
 def app_not_started(app) -> FastAPI:
-    startup_tasks = me_app.get_startup_tasks()
+    startup_tasks = azimuth_app.get_startup_tasks()

     class ModuleThatWillNeverEnd:
         def status(self):
diff --git a/tests/test_startup.py b/tests/test_startup.py
index 9f7f01d3..8e2412d2 100644
--- a/tests/test_startup.py
+++ b/tests/test_startup.py
@@ -20,12 +20,12 @@
 from tests.utils import get_table_key, get_tiny_text_config_one_ds_name

-def test_startup_task(tiny_text_config, tiny_text_task_manager):
+def test_startup_task(tiny_text_config, task_manager):
     dms = load_dataset_split_managers_from_config(tiny_text_config)
-    mods = startup_tasks(dms, tiny_text_task_manager)
+    mods = startup_tasks(dms, task_manager, tiny_text_config)
     one_mod = mods["syntax_tags_eval"]
     # We lock the task manager
-    assert tiny_text_task_manager.is_locked
+    assert task_manager.is_locked
     assert not one_mod.done()
     assert all("train" in k or "eval" in k or "all" in k for k in mods.keys())
     assert all(
@@ -34,11 +34,11 @@
     assert len(mods) == 19

-def test_startup_task_fast(tiny_text_config, tiny_text_task_manager):
+def test_startup_task_fast(tiny_text_config, task_manager):
     tiny_text_config.behavioral_testing = None
     tiny_text_config.similarity = None
     dms = load_dataset_split_managers_from_config(tiny_text_config)
-    mods = startup_tasks(dms, tiny_text_task_manager)
+    mods = startup_tasks(dms, task_manager, tiny_text_config)

     assert not any(
         mod.task_name in (SupportedModule.PerturbationTesting, SupportedModule.NeighborsTagging)
@@ -69,11 +69,11 @@ def test_on_end(tiny_text_config):
     task_manager.clear_worker_cache.assert_called_once()

-def test_startup_task_one_ds(tiny_text_config_one_ds, tiny_text_task_manager):
+def test_startup_task_one_ds(tiny_text_config_one_ds, task_manager):
     dms = load_dataset_split_managers_from_config(tiny_text_config_one_ds)
     assert DatasetSplitName.eval in dms and DatasetSplitName.train in dms

-    mods = startup_tasks(dms, tiny_text_task_manager)
+    mods = startup_tasks(dms, task_manager, tiny_text_config_one_ds)
     ds_name, other_ds_name = get_tiny_text_config_one_ds_name(tiny_text_config_one_ds)
     assert all(
         (DatasetSplitName.all in k or ds_name in k) and other_ds_name not in k for k in mods.keys()
@@ -84,10 +84,10 @@

 @pytest.mark.parametrize("iterations", [20, 1])
-def test_startup_task_bma(tiny_text_config, tiny_text_task_manager, iterations):
+def test_startup_task_bma(tiny_text_config, task_manager, iterations):
     tiny_text_config.uncertainty.iterations = iterations
     dms = load_dataset_split_managers_from_config(tiny_text_config)
-    mods = startup_tasks(dms, tiny_text_task_manager)
+    mods = startup_tasks(dms, task_manager, tiny_text_config)

     # Check that we have BMA at some point if iterations > 1
     if iterations == 1:
diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py
index 6207b726..bc87d65f 100644
--- a/tests/test_task_manager.py
+++ b/tests/test_task_manager.py
@@ -17,11 +17,12 @@
 )

-def test_get_all_task(tiny_text_task_manager):
+def test_get_all_task(task_manager, tiny_text_config):
     # We can find a task
-    key, mod = tiny_text_task_manager.get_task(
+    key, mod = task_manager.get_task(
         SupportedMethod.Inputs,
         dataset_split_name=DatasetSplitName.eval,
+        config=tiny_text_config,
         mod_options=ModuleOptions(pipeline_index=0, indices=[0]),
     )
     assert mod is not None
@@ -29,28 +30,29 @@
     mod.result()

     # The task history is logged
-    tasks = tiny_text_task_manager.get_all_tasks_status("Inputs")
+    tasks = task_manager.get_all_tasks_status("Inputs")
     assert len(tasks) == 1

     # We can make a new task!
-    key, mod = tiny_text_task_manager.get_task(
+    key, mod = task_manager.get_task(
         SupportedModule.SyntaxTagging,
         dataset_split_name=DatasetSplitName.eval,
+        config=tiny_text_config,
         mod_options=ModuleOptions(indices=[0]),
     )
     assert mod is not None

     # The length stays the same
-    assert len(tiny_text_task_manager.get_all_tasks_status("Inputs")) == 1
-    assert len(tiny_text_task_manager.get_all_tasks_status("SyntaxTagging")) == 1
-    assert len(tiny_text_task_manager.get_all_tasks_status(task=None)) == 2
+    assert len(task_manager.get_all_tasks_status("Inputs")) == 1
+    assert len(task_manager.get_all_tasks_status("SyntaxTagging")) == 1
+    assert len(task_manager.get_all_tasks_status(task=None)) == 2

     # We can get info on the cluster
-    info = tiny_text_task_manager.status()
+    info = task_manager.status()
     assert "cluster" in info

     # If we request something that does not exists, it is None
-    key, mod = tiny_text_task_manager.get_task("allo", dataset_split_name=DatasetSplitName.eval)
+    key, mod = task_manager.get_task("allo", DatasetSplitName.eval, tiny_text_config)
     assert mod is None
@@ -64,11 +66,12 @@ def get_module_data(simple_text_config):

 def test_clearing_cache(tiny_text_config):
-    task_manager = TaskManager(tiny_text_config)
+    task_manager = TaskManager()

     key, mod = task_manager.get_task(
         SupportedMethod.Predictions,
         dataset_split_name=DatasetSplitName.eval,
+        config=tiny_text_config,
         mod_options=ModuleOptions(pipeline_index=0, indices=[0, 1]),
     )
     assert mod is not None
@@ -86,7 +89,7 @@
     assert not any([m and d for m, d in cached.values()])

-def test_expired_task(tiny_text_task_manager, tiny_text_config):
+def test_expired_task(task_manager, tiny_text_config):
     class ExpirableModule(FilterableModule[SyntaxConfig]):
         def compute(self, batch):
             return ["ExpirableModule"]
@@ -97,87 +100,99 @@ def compute(self, batch):

     current_update = time.time()

-    tiny_text_task_manager.register_task("ExpirableModule", ExpirableModule)
-    tiny_text_task_manager.register_task("NotExpirableModule", NotExpirableModule)
+    task_manager.register_task("ExpirableModule", ExpirableModule)
+    task_manager.register_task("NotExpirableModule", NotExpirableModule)

-    _, not_expirable_task = tiny_text_task_manager.get_task(
+    _, not_expirable_task = task_manager.get_task(
         "NotExpirableModule",
         dataset_split_name=DatasetSplitName.eval,
+        config=tiny_text_config,
         last_update=current_update,
     )
     # Get an expirable task
-    _, expirable_task = tiny_text_task_manager.get_task(
+    _, expirable_task = task_manager.get_task(
         "ExpirableModule",
         dataset_split_name=DatasetSplitName.eval,
         mod_options=ModuleOptions(pipeline_index=0),
+        config=tiny_text_config,
         last_update=current_update,
     )
     assert not not_expirable_task.done() and not expirable_task.done()
     _ = expirable_task.wait(), not_expirable_task.wait()

     # If we don't change the time, the task are cached
-    _, not_expirable_task = tiny_text_task_manager.get_task(
+    _, not_expirable_task = task_manager.get_task(
         "NotExpirableModule",
         dataset_split_name=DatasetSplitName.eval,
+        config=tiny_text_config,
         last_update=current_update,
     )
     # Get an expirable task
-    _, expirable_task = tiny_text_task_manager.get_task(
+    _, expirable_task = task_manager.get_task(
         "ExpirableModule",
         dataset_split_name=DatasetSplitName.eval,
         mod_options=ModuleOptions(pipeline_index=0),
+        config=tiny_text_config,
         last_update=current_update,
     )
     assert not_expirable_task.done() and expirable_task.done()

     # If we update the dataset_split, the expirable task will be recomputed
     current_update = time.time()
-    _, not_expirable_task = tiny_text_task_manager.get_task(
+    _, not_expirable_task = task_manager.get_task(
         "NotExpirableModule",
         dataset_split_name=DatasetSplitName.eval,
+        config=tiny_text_config,
         last_update=current_update,
     )
     # Get an expirable task
-    _, expirable_task = tiny_text_task_manager.get_task(
+    _, expirable_task = task_manager.get_task(
         "ExpirableModule",
         dataset_split_name=DatasetSplitName.eval,
         mod_options=ModuleOptions(pipeline_index=0),
+        config=tiny_text_config,
         last_update=current_update,
     )
     assert not_expirable_task.done() and not expirable_task.done()

-def test_lock(tiny_text_task_manager):
-    assert not tiny_text_task_manager.is_locked
+def test_lock(task_manager, tiny_text_config):
+    assert not task_manager.is_locked
     # Can't unlock a TaskManager that is Unlock
     with pytest.raises(ValueError):
-        tiny_text_task_manager.unlock()
-    tiny_text_task_manager.lock()
+        task_manager.unlock()
+    task_manager.lock()
     # Can't lock a TaskManager that is lock
     with pytest.raises(ValueError):
-        tiny_text_task_manager.lock()
+        task_manager.lock()

     # Cant schedule task:
     with pytest.raises(TaskManagerLockedException):
-        tiny_text_task_manager.get_task(
+        task_manager.get_task(
             SupportedMethod.Predictions,
             dataset_split_name=DatasetSplitName.eval,
+            config=tiny_text_config,
         )

-    tiny_text_task_manager.unlock()
-    _ = tiny_text_task_manager.get_task(
+    task_manager.unlock()
+    _ = task_manager.get_task(
         SupportedMethod.Predictions,
         dataset_split_name=DatasetSplitName.eval,
+        config=tiny_text_config,
         mod_options=ModuleOptions(pipeline_index=0, indices=[0, 1]),
     )

-def test_custom_query(tiny_text_task_manager):
-    _, pred_task = tiny_text_task_manager.get_custom_task(
+def test_custom_query(
+    task_manager,
+    tiny_text_config,
+):
+    _, pred_task = task_manager.get_custom_task(
         SupportedMethod.Predictions,
         custom_query={
-            tiny_text_task_manager.config.columns.text_input: ["hello, this is fred"],
-            tiny_text_task_manager.config.columns.label: [-1],
+            tiny_text_config.columns.text_input: ["hello, this is fred"],
+            tiny_text_config.columns.label: [-1],
         },
+        config=tiny_text_config,
         mod_options=ModuleOptions(pipeline_index=0),
     )
     assert not pred_task.done()
@@ -186,12 +201,13 @@
     # Wait for callbacks
     pred_task.wait()
     # Check that we have the same module
-    _, pred_task2 = tiny_text_task_manager.get_custom_task(
+    _, pred_task2 = task_manager.get_custom_task(
         SupportedMethod.Predictions,
         custom_query={
-            tiny_text_task_manager.config.columns.text_input: ["hello, this is fred"],
-            tiny_text_task_manager.config.columns.label: [-1],
+            tiny_text_config.columns.text_input: ["hello, this is fred"],
+            tiny_text_config.columns.label: [-1],
         },
+        config=tiny_text_config,
         mod_options=ModuleOptions(pipeline_index=0),
     )
     assert pred_task2 is pred_task
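Finally, the "stop killing it" half of the subject line comes from initialize_managers_and_config in the azimuth/app.py hunk above: on a config update the existing manager is recycled rather than replaced. A condensed sketch of that flow (globals reduced to one module variable; not a drop-in implementation):

    from typing import Optional

    from azimuth.config import AzimuthConfig
    from azimuth.task_manager import TaskManager

    _task_manager: Optional[TaskManager] = None

    def initialize_managers_and_config(azimuth_config: AzimuthConfig) -> None:
        global _task_manager
        if _task_manager:
            # Reuse the running manager: drop worker caches and restart workers,
            # keeping the Dask cluster and its task history alive.
            _task_manager.clear_worker_cache()
            _task_manager.restart()
        else:
            _task_manager = TaskManager(large_dask_cluster=azimuth_config.large_dask_cluster)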