diff --git a/CHANGELOG.md b/CHANGELOG.md index 29800cbd..669385c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ - `MorphologicalFilterTask` adapted to work on boolean values. - Added `temporal_subset` method to `EOPatch`, which can be used to extract a subset of an `EOPatch` by filtering out temporal slices. Also added a corresponding `TemporalSubsetTask`. +- `EOExecutor` now has an option to treat `TemporalDimensionWarning` as an exception. ## [Version 1.5.0] - 2023-09-06 diff --git a/eolearn/core/eoexecution.py b/eolearn/core/eoexecution.py index 16f4a281..6ce1f165 100644 --- a/eolearn/core/eoexecution.py +++ b/eolearn/core/eoexecution.py @@ -27,7 +27,7 @@ from .eonode import EONode from .eoworkflow import EOWorkflow, WorkflowResults -from .exceptions import EORuntimeWarning +from .exceptions import EORuntimeWarning, TemporalDimensionWarning from .utils.fs import get_base_filesystem_and_path, get_full_path, pickle_fs, unpickle_fs from .utils.logging import LogFileFilter from .utils.parallelize import _decide_processing_type, _ProcessingType, parallelize @@ -55,6 +55,7 @@ class _ProcessingData: filter_logs_by_thread: bool logs_filter: Filter | None logs_handler_factory: _HandlerFactoryType + raise_on_temporal_mismatch: bool @dataclass(frozen=True) @@ -86,6 +87,7 @@ def __init__( filesystem: FS | None = None, logs_filter: Filter | None = None, logs_handler_factory: _HandlerFactoryType = FileHandler, + raise_on_temporal_mismatch: bool = False, ): """ :param workflow: A prepared instance of EOWorkflow class @@ -108,6 +110,7 @@ def __init__( object. The 2nd option is chosen only if `filesystem` parameter exists in the signature. + :param raise_on_temporal_mismatch: Whether to treat `TemporalDimensionWarning` as an exception. """ self.workflow = workflow self.execution_kwargs = self._parse_and_validate_execution_kwargs(execution_kwargs) @@ -116,6 +119,7 @@ def __init__( self.filesystem, self.logs_folder = self._parse_logs_filesystem(filesystem, logs_folder) self.logs_filter = logs_filter self.logs_handler_factory = logs_handler_factory + self.raise_on_temporal_mismatch = raise_on_temporal_mismatch self.start_time: dt.datetime | None = None self.report_folder: str | None = None @@ -193,6 +197,7 @@ def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs: filter_logs_by_thread=filter_logs_by_thread, logs_filter=self.logs_filter, logs_handler_factory=self.logs_handler_factory, + raise_on_temporal_mismatch=self.raise_on_temporal_mismatch, ) for workflow_kwargs, log_path in zip(self.execution_kwargs, log_paths) ] @@ -263,7 +268,10 @@ def _execute_workflow(cls, data: _ProcessingData) -> WorkflowResults: data.logs_handler_factory, ) - results = data.workflow.execute(data.workflow_kwargs, raise_errors=False) + with warnings.catch_warnings(): + if data.raise_on_temporal_mismatch: + warnings.simplefilter("error", TemporalDimensionWarning) + results = data.workflow.execute(data.workflow_kwargs, raise_errors=False) cls._try_remove_logging(data.log_path, logger, handler) return results diff --git a/tests/core/test_eoexecutor.py b/tests/core/test_eoexecutor.py index 06fa594a..48e80675 100644 --- a/tests/core/test_eoexecutor.py +++ b/tests/core/test_eoexecutor.py @@ -19,7 +19,21 @@ import pytest from fs.base import FS -from eolearn.core import EOExecutor, EONode, EOTask, EOWorkflow, OutputTask, WorkflowResults, execute_with_mp_lock +from sentinelhub import CRS, BBox + +from eolearn.core import ( + CreateEOPatchTask, + EOExecutor, + EONode, + EOTask, + EOWorkflow, + FeatureType, + InitializeFeatureTask, + OutputTask, + WorkflowResults, + execute_with_mp_lock, + linearly_connect_tasks, +) from eolearn.core.utils.fs import get_full_path FULL_LOG_LINE_COUNT = 12 @@ -251,3 +265,21 @@ def test_without_lock(num_workers): assert len(lines) == 2 * num_workers assert len(set(lines[:num_workers])) == num_workers, "All processes should start" assert len(set(lines[num_workers:])) == num_workers, "All processes should finish" + + +@pytest.mark.parametrize("multiprocess", [True, False]) +def test_temporal_dim_error(multiprocess): + workflow = EOWorkflow( + linearly_connect_tasks( + CreateEOPatchTask(bbox=BBox((0, 0, 1, 1), CRS.POP_WEB)), + InitializeFeatureTask([FeatureType.DATA, "data"], (2, 5, 5, 1)), + ) + ) + + executor = EOExecutor(workflow, [{}, {}]) + for result in executor.run(workers=2, multiprocess=multiprocess): + assert result.error_node_uid is None + + executor = EOExecutor(workflow, [{}, {}], raise_on_temporal_mismatch=True) + for result in executor.run(workers=2, multiprocess=multiprocess): + assert result.error_node_uid is not None