Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove config from task manager and stop killing it #400

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions azimuth/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from azimuth.startup import startup_tasks
from azimuth.task_manager import TaskManager
from azimuth.types import DatasetSplitName, ModuleOptions, SupportedModule
from azimuth.utils.cluster import default_cluster
from azimuth.utils.conversion import JSONResponseIgnoreNan
from azimuth.utils.logs import set_logger_config
from azimuth.utils.validation import assert_not_none
Expand Down Expand Up @@ -147,9 +146,7 @@ def start_app(config_path: Optional[str], load_config_history: bool, debug: bool
if azimuth_config.dataset is None:
raise ValueError("No dataset has been specified in the config.")

local_cluster = default_cluster(large=azimuth_config.large_dask_cluster)

run_startup_tasks(azimuth_config, local_cluster)
run_startup_tasks(azimuth_config)
task_manager = assert_not_none(_task_manager)
task_manager.client.run(set_logger_config, level)

Expand Down Expand Up @@ -321,25 +318,20 @@ def load_dataset_split_managers_from_config(
}


def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
"""Initialize DatasetSplitManagers and TaskManagers.

def initialize_managers_and_config(
azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
):
"""Initialize DatasetSplitManagers and Config.

Args:
azimuth_config: Configuration
cluster: Dask cluster to use.
azimuth_config: Config
cluster: Dask cluster to use, if different than default.
"""
global _task_manager, _dataset_split_managers, _azimuth_config
_azimuth_config = azimuth_config
if _task_manager is not None:
task_history = _task_manager.current_tasks
else:
task_history = {}

_task_manager = TaskManager(azimuth_config, cluster=cluster)

_task_manager.current_tasks = task_history
if not _task_manager:
_task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)

_azimuth_config = azimuth_config
_dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)


Expand All @@ -361,6 +353,7 @@ def run_validation_module(pipeline_index=None):
_, task = task_manager.get_task(
task_name=SupportedModule.Validation,
dataset_split_name=dataset_split,
config=config,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)
# Will raise exceptions as needed.
Expand All @@ -373,15 +366,14 @@ def run_validation_module(pipeline_index=None):
run_validation_module(pipeline_index)


def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
"""Initialize managers, run validation and startup tasks.

Args:
azimuth_config: Config
cluster: Cluster

cluster: Dask cluster to use, if different than default.
"""
initialize_managers(azimuth_config, cluster)
initialize_managers_and_config(azimuth_config, cluster)

task_manager = assert_not_none(get_task_manager())
# Validate that everything is in order **before** the startup tasks.
Expand All @@ -393,5 +385,5 @@ def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
azimuth_config.save() # Save only after the validation modules ran successfully

global _startup_tasks, _ready_flag
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager)
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager, azimuth_config)
_ready_flag = Event()
2 changes: 2 additions & 0 deletions azimuth/routers/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def get_perturbation_testing_summary(
SupportedModule.PerturbationTestingMerged,
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
config=config,
last_update=last_update,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0]
Expand All @@ -180,6 +181,7 @@ def get_perturbation_testing_summary(
SupportedModule.PerturbationTestingSummary,
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
config=config,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0]
return PerturbationTestingSummary(
Expand Down
4 changes: 4 additions & 0 deletions azimuth/routers/class_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_class_overlap_plot(
SupportedModule.ClassOverlap,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
last_update=-1,
)[0]
class_overlap_plot_response: ClassOverlapPlotResponse = make_sankey_plot(
Expand All @@ -94,6 +95,7 @@ def get_class_overlap_plot(
def get_class_overlap(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
get_dataset_split_manager_mapping
Expand All @@ -106,6 +108,7 @@ def get_class_overlap(
SupportedModule.ClassOverlap,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
last_update=-1,
)[0]
dataset_class_count = class_overlap_result.s_matrix.shape[0]
Expand All @@ -121,6 +124,7 @@ def get_class_overlap(
SupportedModule.ConfusionMatrix,
DatasetSplitName.eval,
task_manager=task_manager,
config=config,
mod_options=ModuleOptions(
pipeline_index=pipeline_index, cf_normalize=False, cf_reorder_classes=False
),
Expand Down
4 changes: 2 additions & 2 deletions azimuth/routers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from azimuth.app import (
get_config,
get_task_manager,
initialize_managers,
initialize_managers_and_config,
require_editable_config,
run_startup_tasks,
)
Expand Down Expand Up @@ -89,7 +89,7 @@ def patch_config(
except Exception as e:
log.error("Rollback config update due to error", exc_info=e)
new_config = config
initialize_managers(new_config, task_manager.cluster)
initialize_managers_and_config(new_config, task_manager.cluster)
log.info("Config update cancelled.")
if isinstance(e, (AzimuthValidationError, ValidationError)):
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
Expand Down
4 changes: 3 additions & 1 deletion azimuth/routers/custom_utterances.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ def get_saliency(
utterances: List[str] = Query([], title="Utterances"),
pipeline_index: int = Depends(require_pipeline_index),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
) -> List[SaliencyResponse]:
task_result: List[SaliencyResponse] = get_custom_task_result(
SupportedMethod.Saliency,
task_manager=task_manager,
custom_query={task_manager.config.columns.text_input: utterances},
config=config,
custom_query={config.columns.text_input: utterances},
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)

Expand Down
5 changes: 4 additions & 1 deletion azimuth/routers/dataset_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from fastapi import APIRouter, Depends

from azimuth.app import get_dataset_split_manager_mapping, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager_mapping, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import DatasetSplitName, SupportedModule
Expand All @@ -25,6 +26,7 @@
)
def get_dataset_warnings(
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
get_dataset_split_manager_mapping
),
Expand All @@ -35,6 +37,7 @@ def get_dataset_warnings(
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
last_update=get_last_update(list(dataset_split_managers.values())),
config=config,
)[0]

return task_result.warning_groups
2 changes: 2 additions & 0 deletions azimuth/routers/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def get_export_perturbation_testing_summary(
SupportedModule.PerturbationTestingSummary,
DatasetSplitName.all,
task_manager=task_manager,
config=config,
last_update=last_update,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0].all_tests_summary
Expand Down Expand Up @@ -150,6 +151,7 @@ def get_export_perturbed_set(
SupportedModule.PerturbationTesting,
dataset_split_name,
task_manager,
config=config,
mod_options=ModuleOptions(pipeline_index=pipeline_index_not_null),
)

Expand Down
5 changes: 4 additions & 1 deletion azimuth/routers/model_performance/confidence_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
Expand Down Expand Up @@ -33,6 +34,7 @@ def get_confidence_histogram(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -47,6 +49,7 @@ def get_confidence_histogram(
task_name=SupportedModule.ConfidenceHistogram,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
5 changes: 4 additions & 1 deletion azimuth/routers/model_performance/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
Expand Down Expand Up @@ -33,6 +34,7 @@ def get_confusion_matrix(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -51,6 +53,7 @@ def get_confusion_matrix(
SupportedModule.ConfusionMatrix,
dataset_split_name,
task_manager=task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
8 changes: 7 additions & 1 deletion azimuth/routers/model_performance/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.modules.model_performance.metrics import MetricsModule
from azimuth.task_manager import TaskManager
Expand Down Expand Up @@ -41,6 +42,7 @@ def get_metrics(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -55,6 +57,7 @@ def get_metrics(
SupportedModule.Metrics,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)
Expand All @@ -73,6 +76,7 @@ def get_metrics(
def get_metrics_per_filter(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
) -> MetricsPerFilterAPIResponse:
Expand All @@ -81,6 +85,7 @@ def get_metrics_per_filter(
SupportedModule.MetricsPerFilter,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand All @@ -89,6 +94,7 @@ def get_metrics_per_filter(
SupportedModule.Metrics,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
7 changes: 6 additions & 1 deletion azimuth/routers/model_performance/outcome_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from starlette.status import HTTP_400_BAD_REQUEST

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
Expand Down Expand Up @@ -38,6 +39,7 @@
def get_outcome_count_per_threshold(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
) -> OutcomeCountPerThresholdResponse:
Expand All @@ -50,6 +52,7 @@ def get_outcome_count_per_threshold(
SupportedModule.OutcomeCountPerThreshold,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)
Expand All @@ -69,6 +72,7 @@ def get_outcome_count_per_filter(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -83,6 +87,7 @@ def get_outcome_count_per_filter(
SupportedModule.OutcomeCountPerFilter,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
5 changes: 4 additions & 1 deletion azimuth/routers/top_words.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
Expand Down Expand Up @@ -33,6 +34,7 @@ def get_top_words(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -49,6 +51,7 @@ def get_top_words(
SupportedModule.TopWords,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
Loading