Remove config from task manager and stop killing it

gabegma committed Mar 16, 2023
1 parent 180d3bc commit 14a7cd1

Showing 21 changed files with 173 additions and 127 deletions.
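
In short: modules and routers stop reaching into task_manager.config and instead receive the AzimuthConfig explicitly (via FastAPI dependencies and a new config= argument), so the TaskManager no longer has to be re-created or restarted when the config changes. A minimal before/after sketch of the access pattern being swapped throughout the diff below:

    # Before: config lived on the task manager, so reloading the config
    # meant killing and re-creating the manager.
    model_contract = task_manager.config.model_contract

    # After: config is injected per request and passed along explicitly.
    model_contract = config.model_contract
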
38 changes: 16 additions & 22 deletions azimuth/app.py
@@ -18,7 +18,6 @@
 from azimuth.startup import startup_tasks
 from azimuth.task_manager import TaskManager
 from azimuth.types import DatasetSplitName, ModuleOptions
-from azimuth.utils.cluster import default_cluster
 from azimuth.utils.conversion import JSONResponseIgnoreNan
 from azimuth.utils.logs import set_logger_config
 from azimuth.utils.project import load_dataset_split_managers_from_config, save_config
@@ -101,9 +100,7 @@ def start_app(config_path: Optional[str], load_config_history: bool, debug: bool
     if azimuth_config.dataset is None:
         raise ValueError("No dataset has been specified in the config.")
 
-    local_cluster = default_cluster(large=azimuth_config.large_dask_cluster)
-
-    run_startup_tasks(azimuth_config, local_cluster)
+    run_startup_tasks(azimuth_config)
     assert_not_none(_task_manager).client.run(set_logger_config, level)
 
     app = create_app()
@@ -240,25 +237,23 @@ def create_app() -> FastAPI:
     return app
 
 
-def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
-    """Initialize DatasetSplitManagers and TaskManagers.
+def initialize_managers_and_config(
+    azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
+):
+    """Initialize DatasetSplitManagers and Config.
 
     Args:
-        azimuth_config: Configuration
-        cluster: Dask cluster to use.
+        azimuth_config: Config
+        cluster: Dask cluster to use, if different than default.
     """
     global _task_manager, _dataset_split_managers, _azimuth_config
-    if _task_manager is not None:
-        task_history = _task_manager.current_tasks
+    _azimuth_config = azimuth_config
+    if _task_manager:
         _task_manager.clear_worker_cache()
-        _task_manager.restart()
     else:
-        task_history = {}
-
-    _task_manager = TaskManager(azimuth_config, cluster=cluster)
-
-    _task_manager.current_tasks = task_history
-
-    _azimuth_config = azimuth_config
+        _task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)
 
     _dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)


@@ -295,15 +290,14 @@ def run_validation_module(pipeline_index=None):
     task_manager.restart()
 
 
-def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
+def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
     """Initialize managers, run validation and startup tasks.
 
     Args:
         azimuth_config: Config
-        cluster: Cluster
+        cluster: Dask cluster to use, if different than default.
     """
-    initialize_managers(azimuth_config, cluster)
+    initialize_managers_and_config(azimuth_config, cluster)
 
     task_manager = assert_not_none(get_task_manager())
     # Validate that everything is in order **before** the startup tasks.
@@ -315,5 +309,5 @@ def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
     save_config(azimuth_config)  # Save only after the validation modules ran successfully
 
     global _startup_tasks, _ready_flag
-    _startup_tasks = startup_tasks(_dataset_split_managers, task_manager)
+    _startup_tasks = startup_tasks(_dataset_split_managers, task_manager, azimuth_config)
     _ready_flag = Event()
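
The net effect in app.py: the Dask cluster and TaskManager are created once and kept alive across config reloads, and only the worker cache is cleared. A runnable toy sketch of that lifecycle, with simplified stand-ins rather than Azimuth's real classes:

    from typing import Optional

    class ToyTaskManager:
        """Stand-in for Azimuth's TaskManager: owns a long-lived worker pool."""

        def __init__(self, cluster: Optional[str], large: bool = False):
            # `cluster=None` means "build the default cluster", mirroring the
            # new Optional[SpecCluster] parameter in the diff above.
            self.cluster = cluster or ("large-default" if large else "default")
            self._worker_cache: dict = {}

        def clear_worker_cache(self) -> None:
            self._worker_cache.clear()

    _task_manager: Optional[ToyTaskManager] = None

    def initialize(config: dict, cluster: Optional[str] = None) -> ToyTaskManager:
        """Create the manager once; later calls only clear its cache."""
        global _task_manager
        if _task_manager:
            _task_manager.clear_worker_cache()  # keep workers alive, drop stale state
        else:
            _task_manager = ToyTaskManager(cluster, config.get("large_dask_cluster", False))
        return _task_manager

    # Re-initializing with a new config reuses the same manager instance.
    first = initialize({"large_dask_cluster": False})
    second = initialize({"large_dask_cluster": True})
    assert first is second
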
18 changes: 7 additions & 11 deletions azimuth/routers/app.py
@@ -84,15 +84,12 @@ def get_dataset_info(
     dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
         get_dataset_split_manager_mapping
     ),
     startup_tasks: Dict[str, Module] = Depends(get_startup_tasks),
-    task_manager: TaskManager = Depends(get_task_manager),
     config: AzimuthConfig = Depends(get_config),
 ):
     eval_dm = dataset_split_managers.get(DatasetSplitName.eval)
     training_dm = dataset_split_managers.get(DatasetSplitName.train)
     dm = assert_not_none(eval_dm or training_dm)
 
-    model_contract = task_manager.config.model_contract
-
     return DatasetInfoResponse(
         project_name=config.name,
         class_names=dm.get_class_names(),
@@ -105,19 +102,16 @@ def get_dataset_info(
         if training_dm is not None
         else [],
         startup_tasks={k: v.status() for k, v in startup_tasks.items()},
-        model_contract=model_contract,
-        prediction_available=predictions_available(task_manager.config),
-        perturbation_testing_available=perturbation_testing_available(task_manager.config),
+        model_contract=config.model_contract,
+        prediction_available=predictions_available(config),
+        perturbation_testing_available=perturbation_testing_available(config),
         available_dataset_splits=AvailableDatasetSplits(
             eval=eval_dm is not None, train=training_dm is not None
         ),
-        similarity_available=similarity_available(task_manager.config),
+        similarity_available=similarity_available(config),
         postprocessing_editable=None
         if config.pipelines is None
-        else [
-            postprocessing_editable(task_manager.config, idx)
-            for idx in range(len(config.pipelines))
-        ],
+        else [postprocessing_editable(config, idx) for idx in range(len(config.pipelines))],
     )


@@ -171,6 +165,7 @@ def get_perturbation_testing_summary(
         SupportedModule.PerturbationTestingMerged,
         dataset_split_name=DatasetSplitName.all,
         task_manager=task_manager,
+        config=config,
         last_update=last_update,
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )[0]
@@ -186,6 +181,7 @@ def get_perturbation_testing_summary(
         SupportedModule.PerturbationTestingSummary,
         dataset_split_name=DatasetSplitName.all,
         task_manager=task_manager,
+        config=config,
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )[0]
     return PerturbationTestingSummary(
4 changes: 4 additions & 0 deletions azimuth/routers/class_overlap.py
@@ -72,6 +72,7 @@ def get_class_overlap_plot(
         SupportedModule.ClassOverlap,
         dataset_split_name=dataset_split_name,
         task_manager=task_manager,
+        config=config,
         last_update=-1,
     )[0]
     class_overlap_plot_response: ClassOverlapPlotResponse = make_sankey_plot(
@@ -94,6 +95,7 @@
 def get_class_overlap(
     dataset_split_name: DatasetSplitName,
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
         get_dataset_split_manager_mapping
@@ -106,6 +108,7 @@ def get_class_overlap(
         SupportedModule.ClassOverlap,
         dataset_split_name=dataset_split_name,
         task_manager=task_manager,
+        config=config,
         last_update=-1,
     )[0]
     dataset_class_count = class_overlap_result.s_matrix.shape[0]
@@ -121,6 +124,7 @@ def get_class_overlap(
         SupportedModule.ConfusionMatrix,
         DatasetSplitName.eval,
         task_manager=task_manager,
+        config=config,
         mod_options=ModuleOptions(
             pipeline_index=pipeline_index, cf_normalize=False, cf_reorder_classes=False
         ),
6 changes: 3 additions & 3 deletions azimuth/routers/config.py
@@ -11,7 +11,7 @@
 from azimuth.app import (
     get_config,
     get_task_manager,
-    initialize_managers,
+    initialize_managers_and_config,
     require_editable_config,
     run_startup_tasks,
 )
@@ -81,11 +81,11 @@ def patch_config(
 
     try:
         new_config = update_config(old_config=config, partial_config=partial_config)
-        run_startup_tasks(new_config, task_manager.cluster)
+        run_startup_tasks(new_config)
     except Exception as e:
         log.error("Rollback config update due to error", exc_info=e)
         new_config = config
-        initialize_managers(new_config, task_manager.cluster)
+        initialize_managers_and_config(new_config)
         if isinstance(e, AzimuthValidationError):
             raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
         elif isinstance(e, ValidationError):
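
patch_config keeps its rollback behavior: the new config is kept only if startup succeeds; otherwise the previous one is re-applied. A minimal sketch of that try/rollback pattern with generic names (start_services stands in for Azimuth's run_startup_tasks / initialize_managers_and_config):

    def start_services(config: dict) -> None:
        # Stand-in for re-initializing managers with `config`; raises on bad config.
        if config.get("dataset") is None:
            raise ValueError("No dataset has been specified in the config.")

    def apply_config_update(current: dict, partial: dict) -> dict:
        """Try to apply a partial config; restore the known-good one on failure."""
        candidate = {**current, **partial}
        try:
            start_services(candidate)
            return candidate
        except Exception:
            start_services(current)  # rollback: the previous config is assumed valid
            raise

    # Usage: merge a partial update into the current config, keeping it on success.
    config = apply_config_update({"dataset": "clinc_oos"}, {"batch_size": 32})
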
4 changes: 3 additions & 1 deletion azimuth/routers/custom_utterances.py
@@ -95,11 +95,13 @@ def get_saliency(
     utterances: List[str] = Query([], title="Utterances"),
     pipeline_index: int = Depends(require_pipeline_index),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
 ) -> List[SaliencyResponse]:
     task_result: List[SaliencyResponse] = get_custom_task_result(
         SupportedMethod.Saliency,
         task_manager=task_manager,
-        custom_query={task_manager.config.columns.text_input: utterances},
+        config=config,
+        custom_query={config.columns.text_input: utterances},
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )

5 changes: 4 additions & 1 deletion azimuth/routers/dataset_warnings.py
@@ -6,7 +6,8 @@
 
 from fastapi import APIRouter, Depends
 
-from azimuth.app import get_dataset_split_manager_mapping, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager_mapping, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import DatasetSplitName, SupportedModule
@@ -25,6 +26,7 @@
 )
 def get_dataset_warnings(
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
         get_dataset_split_manager_mapping
     ),
@@ -35,6 +37,7 @@ def get_dataset_warnings(
         dataset_split_name=DatasetSplitName.all,
         task_manager=task_manager,
         last_update=get_last_update(list(dataset_split_managers.values())),
+        config=config,
     )[0]
 
     return task_result.warning_groups
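
All of the router changes in this commit follow the same FastAPI idiom: declare config: AzimuthConfig = Depends(get_config) in the endpoint signature and let FastAPI resolve it per request, instead of reading task_manager.config. A self-contained sketch of that dependency-injection pattern, using a toy config type rather than Azimuth's:

    from fastapi import Depends, FastAPI
    from pydantic import BaseModel

    class ToyConfig(BaseModel):
        name: str = "demo-project"
        model_contract: str = "custom_text_classification"

    _config = ToyConfig()

    def get_config() -> ToyConfig:
        # In Azimuth this returns the live AzimuthConfig; here it is a fixed toy value.
        return _config

    app = FastAPI()

    @app.get("/dataset_info")
    def dataset_info(config: ToyConfig = Depends(get_config)):
        # The endpoint no longer reaches into a task manager for its config.
        return {"project_name": config.name, "model_contract": config.model_contract}

Swapping what get_config returns (for example in tests via app.dependency_overrides) now changes every endpoint at once, which is the point of injecting the config rather than caching it on the TaskManager.
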
12 changes: 6 additions & 6 deletions azimuth/routers/export.py
@@ -108,18 +108,18 @@ def get_export_perturbation_testing_summary(
         SupportedModule.PerturbationTestingSummary,
         DatasetSplitName.all,
         task_manager=task_manager,
+        config=config,
         last_update=last_update,
         mod_options=ModuleOptions(pipeline_index=pipeline_index),
     )[0].all_tests_summary
 
-    cfg = task_manager.config
     df = pd.DataFrame.from_records([t.dict() for t in task_result])
     df["example"] = df["example"].apply(lambda i: i["perturbedUtterance"])
     file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())
 
-    filename = f"azimuth_export_behavioral_testing_summary_{cfg.name}_{file_label}.csv"
+    filename = f"azimuth_export_behavioral_testing_summary_{config.name}_{file_label}.csv"
 
-    path = pjoin(cfg.get_artifact_path(), filename)
+    path = pjoin(config.get_artifact_path(), filename)
 
     df.to_csv(path, index=False)

@@ -143,15 +143,15 @@ def get_export_perturbed_set(
 ) -> FileResponse:
     pipeline_index_not_null = assert_not_none(pipeline_index)
     file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())
-    cfg = task_manager.config
 
-    filename = f"azimuth_export_modified_set_{cfg.name}_{dataset_split_name}_{file_label}.json"
-    path = pjoin(cfg.get_artifact_path(), filename)
+    filename = f"azimuth_export_modified_set_{config.name}_{dataset_split_name}_{file_label}.json"
+    path = pjoin(config.get_artifact_path(), filename)
 
     task_result: List[List[PerturbedUtteranceResult]] = get_standard_task_result(
         SupportedModule.PerturbationTesting,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=ModuleOptions(pipeline_index=pipeline_index_not_null),
     )

5 changes: 4 additions & 1 deletion azimuth/routers/model_performance/confidence_histogram.py
@@ -4,7 +4,8 @@
 
 from fastapi import APIRouter, Depends, Query
 
-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import (
@@ -33,6 +34,7 @@ def get_confidence_histogram(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -47,6 +49,7 @@ def get_confidence_histogram(
         task_name=SupportedModule.ConfidenceHistogram,
         dataset_split_name=dataset_split_name,
         task_manager=task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
5 changes: 4 additions & 1 deletion azimuth/routers/model_performance/confusion_matrix.py
@@ -4,7 +4,8 @@
 
 from fastapi import APIRouter, Depends, Query
 
-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.task_manager import TaskManager
 from azimuth.types import (
@@ -33,6 +34,7 @@ def get_confusion_matrix(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -51,6 +53,7 @@ def get_confusion_matrix(
         SupportedModule.ConfusionMatrix,
         dataset_split_name,
         task_manager=task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
8 changes: 7 additions & 1 deletion azimuth/routers/model_performance/metrics.py
@@ -5,7 +5,8 @@
 
 from fastapi import APIRouter, Depends, Query
 
-from azimuth.app import get_dataset_split_manager, get_task_manager
+from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
+from azimuth.config import AzimuthConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.modules.model_performance.metrics import MetricsModule
 from azimuth.task_manager import TaskManager
@@ -41,6 +42,7 @@ def get_metrics(
     dataset_split_name: DatasetSplitName,
     named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
     without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -55,6 +57,7 @@ def get_metrics(
         SupportedModule.Metrics,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )
@@ -73,6 +76,7 @@ def get_metrics(
 def get_metrics_per_filter(
     dataset_split_name: DatasetSplitName,
     task_manager: TaskManager = Depends(get_task_manager),
+    config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
     pipeline_index: int = Depends(require_pipeline_index),
 ) -> MetricsPerFilterAPIResponse:
@@ -81,6 +85,7 @@ def get_metrics_per_filter(
         SupportedModule.MetricsPerFilter,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
@@ -89,6 +94,7 @@ def get_metrics_per_filter(
         SupportedModule.Metrics,
         dataset_split_name,
         task_manager,
+        config=config,
         mod_options=mod_options,
         last_update=dataset_split_manager.last_update,
     )[0]
(Diff for the remaining 11 changed files was not loaded.)
