Remove config from task manager and stop killing it
gabegma committed Apr 9, 2023
1 parent 4095a73 commit 931f444
Showing 21 changed files with 154 additions and 102 deletions.
37 changes: 15 additions & 22 deletions azimuth/app.py
@@ -30,7 +30,6 @@
from azimuth.startup import startup_tasks
from azimuth.task_manager import TaskManager
from azimuth.types import DatasetSplitName, ModuleOptions
from azimuth.utils.cluster import default_cluster
from azimuth.utils.conversion import JSONResponseIgnoreNan
from azimuth.utils.logs import set_logger_config
from azimuth.utils.validation import assert_not_none
@@ -148,9 +147,7 @@ def start_app(config_path: Optional[str], load_config_history: bool, debug: bool
if azimuth_config.dataset is None:
raise ValueError("No dataset has been specified in the config.")

local_cluster = default_cluster(large=azimuth_config.large_dask_cluster)

run_startup_tasks(azimuth_config, local_cluster)
run_startup_tasks(azimuth_config)
task_manager = assert_not_none(_task_manager)
task_manager.client.run(set_logger_config, level)

@@ -322,25 +319,22 @@ def load_dataset_split_managers_from_config(
}


def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
"""Initialize DatasetSplitManagers and TaskManagers.
def initialize_managers_and_config(
azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
):
"""Initialize DatasetSplitManagers and Config.
Args:
azimuth_config: Configuration
cluster: Dask cluster to use.
azimuth_config: Config
cluster: Dask cluster to use, if different than default.
"""
global _task_manager, _dataset_split_managers, _azimuth_config
_azimuth_config = azimuth_config
if _task_manager is not None:
task_history = _task_manager.current_tasks
if _task_manager:
_task_manager.restart()
else:
task_history = {}

_task_manager = TaskManager(azimuth_config, cluster=cluster)

_task_manager.current_tasks = task_history
_task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)

_azimuth_config = azimuth_config
_dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)


@@ -375,15 +369,14 @@ def run_validation_module(pipeline_index=None):
run_validation_module(pipeline_index)


def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
"""Initialize managers, run validation and startup tasks.
Args:
azimuth_config: Config
cluster: Cluster
cluster: Dask cluster to use, if different than default.
"""
initialize_managers(azimuth_config, cluster)
initialize_managers_and_config(azimuth_config, cluster)

task_manager = assert_not_none(get_task_manager())
# Validate that everything is in order **before** the startup tasks.
@@ -396,5 +389,5 @@ def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
azimuth_config.save() # Save only after the validation modules ran successfully

global _startup_tasks, _ready_flag
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager)
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager, azimuth_config)
_ready_flag = Event()
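
The heart of the app.py change reads better assembled than interleaved in the diff: the helper is renamed to initialize_managers_and_config, an existing TaskManager is restarted in place instead of being killed and rebuilt around the config, and a fresh TaskManager is built from the cluster and the large_dask_cluster flag only. A minimal sketch reconstructed from the hunks above (the SpecCluster import path and the stubbed load_dataset_split_managers_from_config are assumptions; the rest mirrors the diff):

```python
from typing import Dict, Optional

from distributed import SpecCluster  # assumed import path for the Dask cluster type

from azimuth.config import AzimuthConfig
from azimuth.task_manager import TaskManager

# Module-level state, as in azimuth/app.py.
_task_manager: Optional[TaskManager] = None
_azimuth_config: Optional[AzimuthConfig] = None
_dataset_split_managers: Dict = {}


def load_dataset_split_managers_from_config(azimuth_config: AzimuthConfig) -> Dict:
    ...  # existing helper in azimuth/app.py, unchanged by this commit


def initialize_managers_and_config(
    azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
):
    """Initialize DatasetSplitManagers and Config."""
    global _task_manager, _dataset_split_managers, _azimuth_config
    _azimuth_config = azimuth_config
    if _task_manager:
        # Reuse the existing manager and restart it, rather than killing it
        # and rebuilding it around the new config.
        _task_manager.restart()
    else:
        # The TaskManager no longer takes the config; it only needs the
        # cluster (or the flag for a large default cluster).
        _task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)
    _dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)
```

run_startup_tasks then calls this helper and passes azimuth_config explicitly to startup_tasks, since modules can no longer read the config off the task manager.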
2 changes: 2 additions & 0 deletions azimuth/routers/app.py
@@ -165,6 +165,7 @@ def get_perturbation_testing_summary(
SupportedModule.PerturbationTestingMerged,
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
config=config,
last_update=last_update,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0]
@@ -180,6 +181,7 @@ def get_perturbation_testing_summary(
SupportedModule.PerturbationTestingSummary,
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
config=config,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0]
return PerturbationTestingSummary(
4 changes: 4 additions & 0 deletions azimuth/routers/class_overlap.py
@@ -72,6 +72,7 @@ def get_class_overlap_plot(
SupportedModule.ClassOverlap,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
last_update=-1,
)[0]
class_overlap_plot_response: ClassOverlapPlotResponse = make_sankey_plot(
@@ -94,6 +95,7 @@ def get_class_overlap(
def get_class_overlap(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
get_dataset_split_manager_mapping
@@ -106,6 +108,7 @@ def get_class_overlap(
SupportedModule.ClassOverlap,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
last_update=-1,
)[0]
dataset_class_count = class_overlap_result.s_matrix.shape[0]
@@ -121,6 +124,7 @@ def get_class_overlap(
SupportedModule.ConfusionMatrix,
DatasetSplitName.eval,
task_manager=task_manager,
config=config,
mod_options=ModuleOptions(
pipeline_index=pipeline_index, cf_normalize=False, cf_reorder_classes=False
),
4 changes: 2 additions & 2 deletions azimuth/routers/config.py
@@ -11,7 +11,7 @@
from azimuth.app import (
get_config,
get_task_manager,
initialize_managers,
initialize_managers_and_config,
require_editable_config,
run_startup_tasks,
)
@@ -89,7 +89,7 @@ def patch_config(
except Exception as e:
log.error("Rollback config update due to error", exc_info=e)
new_config = config
initialize_managers(new_config, task_manager.cluster)
initialize_managers_and_config(new_config, task_manager.cluster)
log.info("Config update cancelled.")
if isinstance(e, (AzimuthValidationError, ValidationError)):
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
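
The config router keeps its rollback safety net: if applying an updated config fails, the previous config is re-applied through the renamed initialize_managers_and_config, reusing the cluster already held by the task manager. A minimal sketch of that control flow; the helper name, the try body, and the simplified exception handling are assumptions, while the rollback lines mirror the hunk above:

```python
import logging

from fastapi import HTTPException
from starlette.status import HTTP_400_BAD_REQUEST

from azimuth.app import initialize_managers_and_config, run_startup_tasks
from azimuth.config import AzimuthConfig
from azimuth.task_manager import TaskManager

log = logging.getLogger(__name__)


def apply_config_update(
    task_manager: TaskManager, config: AzimuthConfig, partial_update: dict
) -> AzimuthConfig:
    """Try to start up with an updated config; roll back to `config` on failure."""
    new_config = config.copy(update=partial_update)  # assumed pydantic-style update
    try:
        run_startup_tasks(new_config, task_manager.cluster)
    except Exception as e:
        log.error("Rollback config update due to error", exc_info=e)
        new_config = config
        # Re-initialize managers with the previous config so the app stays usable.
        initialize_managers_and_config(new_config, task_manager.cluster)
        log.info("Config update cancelled.")
        raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
    return new_config
```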
4 changes: 3 additions & 1 deletion azimuth/routers/custom_utterances.py
@@ -95,11 +95,13 @@ def get_saliency(
utterances: List[str] = Query([], title="Utterances"),
pipeline_index: int = Depends(require_pipeline_index),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
) -> List[SaliencyResponse]:
task_result: List[SaliencyResponse] = get_custom_task_result(
SupportedMethod.Saliency,
task_manager=task_manager,
custom_query={task_manager.config.columns.text_input: utterances},
config=config,
custom_query={config.columns.text_input: utterances},
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)

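
The router-side changes above all follow the same standard FastAPI dependency-injection pattern: a get_config provider is exposed by azimuth.app, each endpoint that needs the config declares config: AzimuthConfig = Depends(get_config), and the config is then passed explicitly (config=config) to the task-result helpers instead of being read from task_manager.config. A generic, self-contained illustration of that pattern; the names here are illustrative, not Azimuth's actual module layout:

```python
from dataclasses import dataclass

from fastapi import Depends, FastAPI


@dataclass
class AppConfig:
    """Stand-in for AzimuthConfig in this illustration."""

    text_input_column: str = "utterance"


app = FastAPI()
_config = AppConfig()


def get_config() -> AppConfig:
    # Dependency provider: FastAPI calls this for every request whose handler
    # declares `config: AppConfig = Depends(get_config)`.
    return _config


@app.get("/saliency")
def get_saliency(config: AppConfig = Depends(get_config)) -> dict:
    # The endpoint receives the config directly and forwards it explicitly,
    # instead of reaching through another service object (task_manager.config).
    return {"text_column": config.text_input_column}
```

Passing the config around explicitly is what lets the TaskManager drop its config attribute in this commit.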
5 changes: 4 additions & 1 deletion azimuth/routers/dataset_warnings.py
@@ -6,7 +6,8 @@

from fastapi import APIRouter, Depends

from azimuth.app import get_dataset_split_manager_mapping, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager_mapping, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import DatasetSplitName, SupportedModule
@@ -25,6 +26,7 @@
)
def get_dataset_warnings(
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
get_dataset_split_manager_mapping
),
@@ -35,6 +37,7 @@ def get_dataset_warnings(
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
last_update=get_last_update(list(dataset_split_managers.values())),
config=config
)[0]

return task_result.warning_groups
2 changes: 2 additions & 0 deletions azimuth/routers/export.py
@@ -108,6 +108,7 @@ def get_export_perturbation_testing_summary(
SupportedModule.PerturbationTestingSummary,
DatasetSplitName.all,
task_manager=task_manager,
config=config,
last_update=last_update,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0].all_tests_summary
@@ -150,6 +151,7 @@ def get_export_perturbed_set(
SupportedModule.PerturbationTesting,
dataset_split_name,
task_manager,
config=config,
mod_options=ModuleOptions(pipeline_index=pipeline_index_not_null),
)

5 changes: 4 additions & 1 deletion azimuth/routers/model_performance/confidence_histogram.py
@@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
@@ -33,6 +34,7 @@ def get_confidence_histogram(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -47,6 +49,7 @@
task_name=SupportedModule.ConfidenceHistogram,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
5 changes: 4 additions & 1 deletion azimuth/routers/model_performance/confusion_matrix.py
@@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
@@ -33,6 +34,7 @@ def get_confusion_matrix(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -51,6 +53,7 @@
SupportedModule.ConfusionMatrix,
dataset_split_name,
task_manager=task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
8 changes: 7 additions & 1 deletion azimuth/routers/model_performance/metrics.py
@@ -5,7 +5,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.modules.model_performance.metrics import MetricsModule
from azimuth.task_manager import TaskManager
@@ -41,6 +42,7 @@ def get_metrics(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -55,6 +57,7 @@
SupportedModule.Metrics,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)
@@ -73,6 +76,7 @@ def get_metrics(
def get_metrics_per_filter(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
) -> MetricsPerFilterAPIResponse:
@@ -81,6 +85,7 @@ def get_metrics_per_filter(
SupportedModule.MetricsPerFilter,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
@@ -89,6 +94,7 @@ def get_metrics_per_filter(
SupportedModule.Metrics,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
7 changes: 6 additions & 1 deletion azimuth/routers/model_performance/outcome_count.py
@@ -7,7 +7,8 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from starlette.status import HTTP_400_BAD_REQUEST

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
@@ -38,6 +39,7 @@
def get_outcome_count_per_threshold(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
) -> OutcomeCountPerThresholdResponse:
@@ -50,6 +52,7 @@ def get_outcome_count_per_threshold(
SupportedModule.OutcomeCountPerThreshold,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)
@@ -69,6 +72,7 @@ def get_outcome_count_per_filter(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -83,6 +87,7 @@ def get_outcome_count_per_filter(
SupportedModule.OutcomeCountPerFilter,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
5 changes: 4 additions & 1 deletion azimuth/routers/top_words.py
@@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
@@ -33,6 +34,7 @@ def get_top_words(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -49,6 +51,7 @@ def get_top_words(
SupportedModule.TopWords,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
[Diff for the remaining changed files was not loaded.]