diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e6da8491..a2cc0ecf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: end-of-file-fixer - id: requirements-txt-fixer @@ -13,13 +13,13 @@ repos: - id: debug-statements - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black language_version: python3 - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.0.282" + rev: "v0.0.292" hooks: - id: ruff diff --git a/CHANGELOG.md b/CHANGELOG.md index be0dc2dc..06d84471 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +## [Version 1.5.1] - 2023-10-17 + +- `MorphologicalFilterTask` adapted to work on boolean values. +- Added `temporal_subset` method to `EOPatch`, which can be used to extract a subset of an `EOPatch` by filtering out temporal slices. Also added a corresponding `TemporalSubsetTask`. +- `EOExecutor` now has an option to treat `TemporalDimensionWarning` as an exception. +- String representation of `EOPatch` objects was revisited to avoid edge cases where the output would print enormous objects. + ## [Version 1.5.0] - 2023-09-06 The release focuses on making `eo-learn` much simpler to install, reducing the number of dependencies, and improving validation of soundness of `EOPatch` data. diff --git a/docs/source/conf.py b/docs/source/conf.py index bcb1a622..11609690 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -134,7 +134,7 @@ # When Sphinx documents class signature it prioritizes __new__ method over __init__ method. The following hack puts # EOTask.__new__ method to the blacklist so that __init__ method signature will be taken instead. This seems the # cleanest way even though a private object is accessed. -sphinx.ext.autodoc._CLASS_NEW_BLACKLIST.append("{0.__module__}.{0.__qualname__}".format(EOTask.__new__)) # noqa: SLF001 +sphinx.ext.autodoc._CLASS_NEW_BLACKLIST.append(f"{EOTask.__module__}.{EOTask.__new__.__qualname__}") # noqa: SLF001 EXAMPLES_FOLDER = "./examples" diff --git a/eolearn/__init__.py b/eolearn/__init__.py index 61e39397..5c4a15d8 100644 --- a/eolearn/__init__.py +++ b/eolearn/__init__.py @@ -1,5 +1,5 @@ """Main module of the `eolearn` package.""" -__version__ = "1.5.0" +__version__ = "1.5.1" import importlib.util import warnings diff --git a/eolearn/core/__init__.py b/eolearn/core/__init__.py index 168775b6..766dc951 100644 --- a/eolearn/core/__init__.py +++ b/eolearn/core/__init__.py @@ -19,6 +19,7 @@ RemoveFeatureTask, RenameFeatureTask, SaveTask, + TemporalSubsetTask, ZipFeatureTask, ) from .eodata import EOPatch diff --git a/eolearn/core/core_tasks.py b/eolearn/core/core_tasks.py index 21ba4648..31b9f953 100644 --- a/eolearn/core/core_tasks.py +++ b/eolearn/core/core_tasks.py @@ -413,6 +413,30 @@ def execute(self, src_eopatch: EOPatch, dst_eopatch: EOPatch) -> EOPatch: return dst_eopatch +class TemporalSubsetTask(EOTask): + """Extracts a temporal subset of the EOPatch.""" + + def __init__( + self, timestamps: None | list[dt.datetime] | list[int] | Callable[[list[dt.datetime]], Iterable[bool]] = None + ): + """ + :param timestamps: Input for the `temporal_subset` method of EOPatch. Can also be provided in execution + arguments. Value in execution arguments takes precedence. + """ + self.timestamps = timestamps + + def execute( + self, + eopatch: EOPatch, + *, + timestamps: None | list[dt.datetime] | list[int] | Callable[[list[dt.datetime]], Iterable[bool]] = None, + ) -> EOPatch: + timestamps = timestamps if timestamps is not None else self.timestamps + if timestamps is None: + raise ValueError("Value for `timestamps` must be provided on initialization or as an execution argument.") + return eopatch.temporal_subset(timestamps) + + class MapFeatureTask(EOTask): """Applies a function to each feature in input_features of a patch and stores the results in a set of output_features. diff --git a/eolearn/core/eodata.py b/eolearn/core/eodata.py index 5f0dda5a..93660204 100644 --- a/eolearn/core/eodata.py +++ b/eolearn/core/eodata.py @@ -455,11 +455,7 @@ def __repr__(self) -> str: @staticmethod def _repr_value(value: object) -> str: - """Creates a representation string for different types of data. - - :param value: data in any type - :return: representation string - """ + """Creates a representation string for different types of data.""" if isinstance(value, np.ndarray): return f"{EOPatch._repr_value_class(value)}(shape={value.shape}, dtype={value.dtype})" @@ -467,24 +463,28 @@ def _repr_value(value: object) -> str: crs = CRS(value.crs).ogc_string() if value.crs else value.crs return f"{EOPatch._repr_value_class(value)}(columns={list(value)}, length={len(value)}, crs={crs})" + repr_str = str(value) + if len(repr_str) <= MAX_DATA_REPR_LEN: + return repr_str + if isinstance(value, (list, tuple, dict)) and value: - repr_str = str(value) - if len(repr_str) <= MAX_DATA_REPR_LEN: - return repr_str + lb, rb = ("[", "]") if isinstance(value, list) else ("(", ")") if isinstance(value, tuple) else ("{", "}") - l_bracket, r_bracket = ("[", "]") if isinstance(value, list) else ("(", ")") - if isinstance(value, (list, tuple)) and len(value) > 2: - repr_str = f"{l_bracket}{value[0]!r}, ..., {value[-1]!r}{r_bracket}" + if isinstance(value, dict): # generate representation of first element or (key, value) pair + some_key = next(iter(value)) + repr_of_el = f"{EOPatch._repr_value(some_key)}: {EOPatch._repr_value(value[some_key])}" + else: + repr_of_el = EOPatch._repr_value(value[0]) - if len(repr_str) > MAX_DATA_REPR_LEN and isinstance(value, (list, tuple)) and len(value) > 1: - repr_str = f"{l_bracket}{value[0]!r}, ...{r_bracket}" + many_elements_visual = ", ..." if len(value) > 1 else "" # add ellipsis if there are multiple elements + repr_str = f"{lb}{repr_of_el}{many_elements_visual}{rb}" if len(repr_str) > MAX_DATA_REPR_LEN: repr_str = str(type(value)) - return f"{repr_str}, length={len(value)}" + return f"{repr_str}" - return repr(value) + return str(type(value)) @staticmethod def _repr_value_class(value: object) -> str: @@ -726,6 +726,7 @@ def merge( self, *eopatches, features=features, time_dependent_op=time_dependent_op, timeless_op=timeless_op ) + @deprecated_function(EODeprecationWarning, "Please use the method `temporal_subset` instead.") def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.datetime]: """Removes all frames from the EOPatch with a date not found in the provided timestamps list. @@ -750,6 +751,45 @@ def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.dateti return remove_from_patch + def temporal_subset( + self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[list[dt.datetime]], Iterable[bool]] + ) -> EOPatch: + """Returns an EOPatch that only contains data for the temporal subset corresponding to `timestamps`. + + For array-based data appropriate temporal slices are extracted. For vector data a filtration is performed. + + :param timestamps: Parameter that defines the temporal subset. Can be a collection of timestamps, a + collection of timestamp indices. It is possible to also provide a callable that maps a list of timestamps + to a sequence of booleans, which determine if a given timestamp is included in the subset or not. + """ + timestamp_indices = self._parse_temporal_subset_input(timestamps) + new_timestamps = [ts for i, ts in enumerate(self.get_timestamps()) if i in timestamp_indices] + new_patch = EOPatch(bbox=self.bbox, timestamps=new_timestamps) + + for ftype, fname in self.get_features(): + if ftype.is_timeless() or ftype.is_meta(): + new_patch[ftype, fname] = self[ftype, fname] + elif ftype.is_vector(): + gdf: gpd.GeoDataFrame = self[ftype, fname] + new_patch[ftype, fname] = gdf[gdf[TIMESTAMP_COLUMN].isin(new_timestamps)] + else: + new_patch[ftype, fname] = self[ftype, fname][timestamp_indices] + + return new_patch + + def _parse_temporal_subset_input( + self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[list[dt.datetime]], Iterable[bool]] + ) -> list[int]: + """Parses input into a list of timestamp indices. Also adds implicit support for strings via `parse_time`.""" + if callable(timestamps): + accepted_timestamps = timestamps(self.get_timestamps()) + return [i for i, accepted in enumerate(accepted_timestamps) if accepted] + ts_or_idx = list(timestamps) + if all(isinstance(ts, int) for ts in ts_or_idx): + return ts_or_idx # type: ignore[return-value] + parsed_timestamps = {parse_time(ts, force_datetime=True) for ts in ts_or_idx} # type: ignore[call-overload] + return [i for i, ts in enumerate(self.get_timestamps()) if ts in parsed_timestamps] + def plot( self, feature: Feature, diff --git a/eolearn/core/eoexecution.py b/eolearn/core/eoexecution.py index de224990..6ce1f165 100644 --- a/eolearn/core/eoexecution.py +++ b/eolearn/core/eoexecution.py @@ -27,7 +27,7 @@ from .eonode import EONode from .eoworkflow import EOWorkflow, WorkflowResults -from .exceptions import EORuntimeWarning +from .exceptions import EORuntimeWarning, TemporalDimensionWarning from .utils.fs import get_base_filesystem_and_path, get_full_path, pickle_fs, unpickle_fs from .utils.logging import LogFileFilter from .utils.parallelize import _decide_processing_type, _ProcessingType, parallelize @@ -36,8 +36,7 @@ class _HandlerWithFsFactoryType(Protocol): """Type definition for a callable that accepts a path and a filesystem object""" - def __call__(self, path: str, filesystem: FS, **kwargs: Any) -> Handler: - ... + def __call__(self, path: str, filesystem: FS, **kwargs: Any) -> Handler: ... # pylint: disable=invalid-name @@ -56,6 +55,7 @@ class _ProcessingData: filter_logs_by_thread: bool logs_filter: Filter | None logs_handler_factory: _HandlerFactoryType + raise_on_temporal_mismatch: bool @dataclass(frozen=True) @@ -87,6 +87,7 @@ def __init__( filesystem: FS | None = None, logs_filter: Filter | None = None, logs_handler_factory: _HandlerFactoryType = FileHandler, + raise_on_temporal_mismatch: bool = False, ): """ :param workflow: A prepared instance of EOWorkflow class @@ -109,6 +110,7 @@ def __init__( object. The 2nd option is chosen only if `filesystem` parameter exists in the signature. + :param raise_on_temporal_mismatch: Whether to treat `TemporalDimensionWarning` as an exception. """ self.workflow = workflow self.execution_kwargs = self._parse_and_validate_execution_kwargs(execution_kwargs) @@ -117,6 +119,7 @@ def __init__( self.filesystem, self.logs_folder = self._parse_logs_filesystem(filesystem, logs_folder) self.logs_filter = logs_filter self.logs_handler_factory = logs_handler_factory + self.raise_on_temporal_mismatch = raise_on_temporal_mismatch self.start_time: dt.datetime | None = None self.report_folder: str | None = None @@ -194,6 +197,7 @@ def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs: filter_logs_by_thread=filter_logs_by_thread, logs_filter=self.logs_filter, logs_handler_factory=self.logs_handler_factory, + raise_on_temporal_mismatch=self.raise_on_temporal_mismatch, ) for workflow_kwargs, log_path in zip(self.execution_kwargs, log_paths) ] @@ -264,7 +268,10 @@ def _execute_workflow(cls, data: _ProcessingData) -> WorkflowResults: data.logs_handler_factory, ) - results = data.workflow.execute(data.workflow_kwargs, raise_errors=False) + with warnings.catch_warnings(): + if data.raise_on_temporal_mismatch: + warnings.simplefilter("error", TemporalDimensionWarning) + results = data.workflow.execute(data.workflow_kwargs, raise_errors=False) cls._try_remove_logging(data.log_path, logger, handler) return results diff --git a/eolearn/core/eoworkflow.py b/eolearn/core/eoworkflow.py index 3c7e5fe6..37c58fa3 100644 --- a/eolearn/core/eoworkflow.py +++ b/eolearn/core/eoworkflow.py @@ -307,12 +307,10 @@ def get_nodes(self) -> list[EONode]: return self._nodes[:] @overload - def get_node_with_uid(self, uid: str, fail_if_missing: Literal[True] = ...) -> EONode: - ... + def get_node_with_uid(self, uid: str, fail_if_missing: Literal[True] = ...) -> EONode: ... @overload - def get_node_with_uid(self, uid: str, fail_if_missing: Literal[False] = ...) -> EONode | None: - ... + def get_node_with_uid(self, uid: str, fail_if_missing: Literal[False] = ...) -> EONode | None: ... def get_node_with_uid(self, uid: str, fail_if_missing: bool = False) -> EONode | None: """Returns node with give uid, if it exists in the workflow.""" diff --git a/eolearn/core/extra/ray.py b/eolearn/core/extra/ray.py index ba0d2a1c..de2bcda4 100644 --- a/eolearn/core/extra/ray.py +++ b/eolearn/core/extra/ray.py @@ -61,7 +61,7 @@ def _get_processing_type(*_: Any, **__: Any) -> _ProcessingType: def _ray_workflow_executor(workflow_args: _ProcessingData) -> WorkflowResults: """Called to execute a workflow on a ray worker""" # pylint: disable=protected-access - return RayExecutor._execute_workflow(workflow_args) + return RayExecutor._execute_workflow(workflow_args) # noqa: SLF001 def parallelize_with_ray( diff --git a/eolearn/core/utils/parallelize.py b/eolearn/core/utils/parallelize.py index 216672e4..4fb454e7 100644 --- a/eolearn/core/utils/parallelize.py +++ b/eolearn/core/utils/parallelize.py @@ -49,9 +49,7 @@ def _decide_processing_type(workers: int | None, multiprocess: bool) -> _Process """ if workers == 1: return _ProcessingType.SINGLE_PROCESS - if multiprocess: - return _ProcessingType.MULTIPROCESSING - return _ProcessingType.MULTITHREADING + return _ProcessingType.MULTIPROCESSING if multiprocess else _ProcessingType.MULTITHREADING def parallelize( @@ -74,10 +72,7 @@ def parallelize( :return: A list of function results. """ if not params: - raise ValueError( - "At least 1 list of parameters should be given. Otherwise it is not clear how many times the" - "function has to be executed." - ) + return [] processing_type = _decide_processing_type(workers=workers, multiprocess=multiprocess) if processing_type is _ProcessingType.SINGLE_PROCESS: @@ -105,7 +100,6 @@ def execute_with_mp_lock(function: Callable[..., OutputType], *args: Any, **kwar :param function: A function :param args: Function's positional arguments :param kwargs: Function's keyword arguments - :return: Function's results """ if multiprocessing.current_process().name == "MainProcess" or MULTIPROCESSING_LOCK is None: return function(*args, **kwargs) @@ -165,10 +159,7 @@ def join_futures_iter( """ def _wait_function(remaining_futures: Collection[Future]) -> tuple[Collection[Future], Collection[Future]]: - done, not_done = concurrent.futures.wait( - remaining_futures, timeout=float(update_interval), return_when=FIRST_COMPLETED - ) - return done, not_done + return concurrent.futures.wait(remaining_futures, timeout=float(update_interval), return_when=FIRST_COMPLETED) def _get_result(future: Future) -> Any: return future.result() @@ -184,8 +175,6 @@ def _base_join_futures_iter( ) -> Generator[tuple[int, OutputType], None, None]: """A generalized utility function that resolves futures, monitors progress, and serves as an iterator over results.""" - if not isinstance(futures, list): - raise ValueError(f"Parameters 'futures' should be a list but {type(futures)} was given") remaining_futures: Collection[FutureType] = _make_copy_and_empty_given(futures) id_to_position_map = {id(future): index for index, future in enumerate(remaining_futures)} @@ -195,9 +184,8 @@ def _base_join_futures_iter( done, remaining_futures = wait_function(remaining_futures) for future in done: result = get_result_function(future) - result_position = id_to_position_map[id(future)] pbar.update(1) - yield result_position, result + yield id_to_position_map[id(future)], result def _make_copy_and_empty_given(items: list[T]) -> list[T]: diff --git a/eolearn/features/extra/clustering.py b/eolearn/features/extra/clustering.py index 35172c3c..443ff3e5 100644 --- a/eolearn/features/extra/clustering.py +++ b/eolearn/features/extra/clustering.py @@ -88,7 +88,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: # All connections to masked pixels are removed if self.mask_name is not None: - mask = eopatch.mask_timeless[self.mask_name].squeeze() + mask = eopatch.mask_timeless[self.mask_name].squeeze(axis=-1) graph_args["mask"] = mask data = data[np.ravel(mask) != 0] diff --git a/eolearn/geometry/morphology.py b/eolearn/geometry/morphology.py index dbf039fe..b4ee86b9 100644 --- a/eolearn/geometry/morphology.py +++ b/eolearn/geometry/morphology.py @@ -48,7 +48,7 @@ def __init__( self.no_data_label = no_data_label def execute(self, eopatch: EOPatch) -> EOPatch: - feature_array = eopatch[(self.mask_type, self.mask_name)].squeeze().copy() + feature_array = eopatch[(self.mask_type, self.mask_name)].squeeze(axis=-1).copy() all_labels = np.unique(feature_array) erode_labels = self.erode_labels if self.erode_labels else all_labels @@ -148,6 +148,10 @@ def __init__( def map_method(self, feature: np.ndarray) -> np.ndarray: """Applies the morphological operation to a raster feature.""" feature = feature.copy() + is_bool = feature.dtype == bool + if is_bool: + feature = feature.astype(np.uint8) + morph_func = partial(cv2.morphologyEx, kernel=self.struct_elem, op=self.morph_operation) if feature.ndim == 3: for channel in range(feature.shape[2]): @@ -158,4 +162,4 @@ def map_method(self, feature: np.ndarray) -> np.ndarray: else: raise ValueError(f"Invalid number of dimensions: {feature.ndim}") - return feature + return feature.astype(bool) if is_bool else feature diff --git a/tests/core/test_core_tasks.py b/tests/core/test_core_tasks.py index c80fa32f..cfa91de8 100644 --- a/tests/core/test_core_tasks.py +++ b/tests/core/test_core_tasks.py @@ -40,6 +40,7 @@ RemoveFeatureTask, RenameFeatureTask, SaveTask, + TemporalSubsetTask, ZipFeatureTask, ) from eolearn.core.core_tasks import ExplodeBandsTask @@ -277,6 +278,27 @@ def test_merge_features(axis: int, features_to_merge: list[Feature], feature: Fe assert_array_equal(patch[feature], expected) +@pytest.mark.parametrize( + "timestamps", + [ + [1, 2, 4], + [datetime(2019, 4, 2), datetime(2019, 7, 2), datetime(2019, 12, 31)], + lambda _: [False, True, True, False, True], + ], +) +def test_temporal_subset_task(patch: EOPatch, timestamps): + """The correctness is tested in the method test, so we focus on testing that parameters are passed correctly.""" + task_init = TemporalSubsetTask(timestamps) + result_init = task_init.execute(patch) + + task_exec = TemporalSubsetTask() + result_exec = task_exec.execute(patch, timestamps=timestamps) + + assert result_init == result_exec + assert len(result_exec.get_timestamps()) == 3 + assert_array_equal(result_exec.data["bands"], patch.data["bands"][[1, 2, 4]]) + + @pytest.mark.parametrize( ("input_features", "output_feature", "zip_function", "kwargs"), [ diff --git a/tests/core/test_eodata.py b/tests/core/test_eodata.py index 68ccae5f..d16c7979 100644 --- a/tests/core/test_eodata.py +++ b/tests/core/test_eodata.py @@ -6,7 +6,7 @@ """ from __future__ import annotations -import datetime +import datetime as dt import warnings from typing import Any @@ -17,6 +17,7 @@ from sentinelhub import CRS, BBox from eolearn.core import EOPatch, FeatureType +from eolearn.core.constants import TIMESTAMP_COLUMN from eolearn.core.eodata_io import FeatureIO from eolearn.core.exceptions import EODeprecationWarning, TemporalDimensionWarning from eolearn.core.types import Feature, FeaturesSpecification @@ -96,7 +97,7 @@ def test_bbox_feature_type(invalid_bbox: Any) -> None: @pytest.mark.parametrize( - "valid_entry", [["2018-01-01", "15.2.1992"], (datetime.datetime(2017, 1, 1, 10, 4, 7), datetime.date(2017, 1, 11))] + "valid_entry", [["2018-01-01", "15.2.1992"], (dt.datetime(2017, 1, 1, 10, 4, 7), dt.date(2017, 1, 11))] ) def test_timestamp_valid_feature_type(valid_entry: Any) -> None: eop = EOPatch(bbox=DUMMY_BBOX, timestamps=valid_entry) @@ -106,9 +107,9 @@ def test_timestamp_valid_feature_type(valid_entry: Any) -> None: @pytest.mark.parametrize( "invalid_timestamps", [ - [datetime.datetime(2017, 1, 1, 10, 4, 7), None, datetime.datetime(2017, 1, 11, 10, 3, 51)], + [dt.datetime(2017, 1, 1, 10, 4, 7), None, dt.datetime(2017, 1, 11, 10, 3, 51)], "something", - datetime.datetime(2017, 1, 1, 10, 4, 7), + dt.datetime(2017, 1, 1, 10, 4, 7), ], ) def test_timestamps_invalid_feature_type(invalid_timestamps: Any) -> None: @@ -398,19 +399,20 @@ def test_get_features(patch: EOPatch, expected_features: list[Feature]) -> None: assert patch.get_features() == expected_features +@pytest.mark.filterwarnings("ignore::eolearn.core.exceptions.EODeprecationWarning") def test_timestamp_consolidation() -> None: # 10 frames timestamps = [ - datetime.datetime(2017, 1, 1, 10, 4, 7), - datetime.datetime(2017, 1, 4, 10, 14, 5), - datetime.datetime(2017, 1, 11, 10, 3, 51), - datetime.datetime(2017, 1, 14, 10, 13, 46), - datetime.datetime(2017, 1, 24, 10, 14, 7), - datetime.datetime(2017, 2, 10, 10, 1, 32), - datetime.datetime(2017, 2, 20, 10, 6, 35), - datetime.datetime(2017, 3, 2, 10, 0, 20), - datetime.datetime(2017, 3, 12, 10, 7, 6), - datetime.datetime(2017, 3, 15, 10, 12, 14), + dt.datetime(2017, 1, 1, 10, 4, 7), + dt.datetime(2017, 1, 4, 10, 14, 5), + dt.datetime(2017, 1, 11, 10, 3, 51), + dt.datetime(2017, 1, 14, 10, 13, 46), + dt.datetime(2017, 1, 24, 10, 14, 7), + dt.datetime(2017, 2, 10, 10, 1, 32), + dt.datetime(2017, 2, 20, 10, 6, 35), + dt.datetime(2017, 3, 2, 10, 0, 20), + dt.datetime(2017, 3, 12, 10, 7, 6), + dt.datetime(2017, 3, 15, 10, 12, 14), ] data = np.random.rand(10, 100, 100, 3) @@ -430,7 +432,7 @@ def test_timestamp_consolidation() -> None: good_timestamps = timestamps.copy() del good_timestamps[0] del good_timestamps[-1] - good_timestamps.append(datetime.datetime(2017, 12, 1)) + good_timestamps.append(dt.datetime(2017, 12, 1)) removed_frames = eop.consolidate_timestamps(good_timestamps) @@ -444,20 +446,48 @@ def test_timestamp_consolidation() -> None: assert np.array_equal(mask_timeless, eop.mask_timeless["MASK_TIMELESS"]) -def test_timestamps_deprecation(): - eop = EOPatch(bbox=DUMMY_BBOX, timestamps=[datetime.datetime(1234, 5, 6)]) - - with pytest.warns(EODeprecationWarning): - assert eop.timestamp == [datetime.datetime(1234, 5, 6)] - - with pytest.warns(EODeprecationWarning): - eop.timestamp = [datetime.datetime(4321, 5, 6)] - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=EODeprecationWarning) - # so the warnings get ignored in pytest summary - assert eop.timestamp == [datetime.datetime(4321, 5, 6)] - assert eop.timestamp == eop.timestamps +@pytest.mark.parametrize( + "method_input", + [ + ["2017-04-08", "2017-09-17"], + [1, 2], + lambda dates: (dt.datetime(2017, 4, 1) < x < dt.datetime(2017, 10, 10) for x in dates), + ], +) +def test_temporal_subset(method_input): + eop = generate_eopatch( + { + FeatureType.DATA: ["data1", "data2"], + FeatureType.MASK_TIMELESS: ["mask_timeless"], + FeatureType.SCALAR_TIMELESS: ["scalar_timeless"], + FeatureType.MASK: ["mask"], + }, + timestamps=[ + dt.datetime(2017, 1, 5), + dt.datetime(2017, 4, 8), + dt.datetime(2017, 9, 17), + dt.datetime(2018, 1, 5), + dt.datetime(2018, 12, 1), + ], + ) + vector_data = GeoDataFrame( + {TIMESTAMP_COLUMN: eop.get_timestamps()}, geometry=[eop.bbox.geometry.buffer(i) for i in range(5)], crs=32633 + ) + eop.vector["vector"] = vector_data + subset_timestamps = eop.timestamps[1:3] + + subset_eop = eop.temporal_subset(method_input) + assert subset_eop.timestamps == subset_timestamps + for feature in eop.get_features(): + if feature[0].is_timeless(): + assert_feature_data_equal(eop[feature], subset_eop[feature]) + elif feature[0].is_array(): + assert_feature_data_equal(eop[feature][1:3, ...], subset_eop[feature]) + + assert_feature_data_equal( + subset_eop.vector["vector"], + vector_data[1:3], + ) def test_bbox_none_deprecation(): diff --git a/tests/core/test_eodata_io.py b/tests/core/test_eodata_io.py index e9692cd9..bfb56bfe 100644 --- a/tests/core/test_eodata_io.py +++ b/tests/core/test_eodata_io.py @@ -623,8 +623,7 @@ def test_partial_temporal_saving_into_existing(eopatch: EOPatch, temporal_select io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES") eopatch.save(**io_kwargs, use_zarr=True) - partial_patch = eopatch.copy(deep=True) - partial_patch.consolidate_timestamps(np.array(partial_patch.timestamps)[temporal_selection or ...]) + partial_patch = eopatch.copy(deep=True).temporal_subset(np.array(eopatch.timestamps)[temporal_selection or ...]) partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2) partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection=temporal_selection) @@ -668,8 +667,7 @@ def test_partial_temporal_saving_infer(eopatch: EOPatch): io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES") eopatch.save(**io_kwargs, use_zarr=True) - partial_patch = eopatch.copy(deep=True) - partial_patch.consolidate_timestamps(eopatch.timestamps[1:2] + eopatch.timestamps[3:5]) + partial_patch = eopatch.copy(deep=True).temporal_subset([1, 3, 4]) partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2) partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection="infer") diff --git a/tests/core/test_eodata_merge.py b/tests/core/test_eodata_merge.py index 131e8706..4ce7bd3d 100644 --- a/tests/core/test_eodata_merge.py +++ b/tests/core/test_eodata_merge.py @@ -206,8 +206,7 @@ def test_lazy_loading(test_eopatch_path): def test_temporally_independent_merge(test_eopatch_path): full_patch = EOPatch.load(test_eopatch_path) - part1, part2 = full_patch.copy(deep=True), full_patch.copy(deep=True) - part1.consolidate_timestamps(full_patch.get_timestamps()[:10]) - part2.consolidate_timestamps(full_patch.get_timestamps()[10:]) + part1 = full_patch.copy(deep=True).temporal_subset(range(10)) + part2 = full_patch.copy(deep=True).temporal_subset(range(10, len(full_patch.timestamps))) assert full_patch == merge_eopatches(part1, part2, time_dependent_op="concatenate") diff --git a/tests/core/test_eoexecutor.py b/tests/core/test_eoexecutor.py index 06fa594a..48e80675 100644 --- a/tests/core/test_eoexecutor.py +++ b/tests/core/test_eoexecutor.py @@ -19,7 +19,21 @@ import pytest from fs.base import FS -from eolearn.core import EOExecutor, EONode, EOTask, EOWorkflow, OutputTask, WorkflowResults, execute_with_mp_lock +from sentinelhub import CRS, BBox + +from eolearn.core import ( + CreateEOPatchTask, + EOExecutor, + EONode, + EOTask, + EOWorkflow, + FeatureType, + InitializeFeatureTask, + OutputTask, + WorkflowResults, + execute_with_mp_lock, + linearly_connect_tasks, +) from eolearn.core.utils.fs import get_full_path FULL_LOG_LINE_COUNT = 12 @@ -251,3 +265,21 @@ def test_without_lock(num_workers): assert len(lines) == 2 * num_workers assert len(set(lines[:num_workers])) == num_workers, "All processes should start" assert len(set(lines[num_workers:])) == num_workers, "All processes should finish" + + +@pytest.mark.parametrize("multiprocess", [True, False]) +def test_temporal_dim_error(multiprocess): + workflow = EOWorkflow( + linearly_connect_tasks( + CreateEOPatchTask(bbox=BBox((0, 0, 1, 1), CRS.POP_WEB)), + InitializeFeatureTask([FeatureType.DATA, "data"], (2, 5, 5, 1)), + ) + ) + + executor = EOExecutor(workflow, [{}, {}]) + for result in executor.run(workers=2, multiprocess=multiprocess): + assert result.error_node_uid is None + + executor = EOExecutor(workflow, [{}, {}], raise_on_temporal_mismatch=True) + for result in executor.run(workers=2, multiprocess=multiprocess): + assert result.error_node_uid is not None diff --git a/tests/features/conftest.py b/tests/features/conftest.py index 77c533c4..8959c6b1 100644 --- a/tests/features/conftest.py +++ b/tests/features/conftest.py @@ -34,5 +34,4 @@ def small_ndvi_eopatch_fixture(example_eopatch: EOPatch): ndvi = example_eopatch.data["NDVI"][:, :20, :20] ndvi[np.isnan(ndvi)] = 0 example_eopatch.data["NDVI"] = ndvi - example_eopatch.consolidate_timestamps(example_eopatch.get_timestamps()[:10]) - return example_eopatch + return example_eopatch.temporal_subset(range(10)) diff --git a/tests/features/extra/test_clustering.py b/tests/features/extra/test_clustering.py index b143b080..80b3707b 100644 --- a/tests/features/extra/test_clustering.py +++ b/tests/features/extra/test_clustering.py @@ -43,14 +43,14 @@ def test_clustering(example_eopatch): remove_small=10, ).execute(example_eopatch) - clusters = example_eopatch.data_timeless["clusters_small"].squeeze() + clusters = example_eopatch.data_timeless["clusters_small"].squeeze(axis=-1) assert len(np.unique(clusters)) == 22, "Wrong number of clusters." assert np.median(clusters) == 2 assert np.mean(clusters) == pytest.approx(2.19109 if sys.version_info < (3, 9) else 2.201188) - clusters = example_eopatch.data_timeless["clusters_mask"].squeeze() + clusters = example_eopatch.data_timeless["clusters_mask"].squeeze(axis=-1) assert len(np.unique(clusters)) == 8, "Wrong number of clusters." assert np.median(clusters) == 0 diff --git a/tests/geometry/test_morphology.py b/tests/geometry/test_morphology.py index 646e80ec..f4d69d7a 100644 --- a/tests/geometry/test_morphology.py +++ b/tests/geometry/test_morphology.py @@ -24,8 +24,11 @@ def patch_fixture() -> EOPatch: config = PatchGeneratorConfig(max_integer_value=10, raster_shape=(50, 100), depth_range=(3, 4)) patch = generate_eopatch([MASK_FEATURE, MASK_TIMELESS_FEATURE], config=config) - for feat in [MASK_FEATURE, MASK_TIMELESS_FEATURE]: - patch[feat] = patch[feat].astype(np.uint8) + patch[MASK_FEATURE] = patch[MASK_FEATURE].astype(np.uint8) + patch[MASK_TIMELESS_FEATURE] = patch[MASK_TIMELESS_FEATURE] < 1 + patch[MASK_TIMELESS_FEATURE][10:20, 20:32] = 0 + patch[MASK_TIMELESS_FEATURE][30:, 50:] = 1 + return patch @@ -64,22 +67,32 @@ def test_erosion_partial(test_eopatch): MorphologicalOperations.DILATION, None, [6, 34, 172, 768, 2491, 7405, 19212, 44912], - [1, 2, 16, 104, 466, 1490, 3870, 9051], + [4882, 10118], + ), + ( + MorphologicalOperations.EROSION, + MorphologicalStructFactory.get_disk(4), + [54555, 15639, 3859, 770, 153, 19, 5], + [12391, 2609], ), - (MorphologicalOperations.EROSION, MorphologicalStructFactory.get_disk(11), [74957, 42, 1], [14994, 6]), - (MorphologicalOperations.OPENING, MorphologicalStructFactory.get_disk(11), [73899, 1051, 50], [14837, 163]), - (MorphologicalOperations.CLOSING, MorphologicalStructFactory.get_disk(11), [770, 74230], [425, 14575]), ( MorphologicalOperations.OPENING, - MorphologicalStructFactory.get_rectangle(5, 6), - [48468, 24223, 2125, 169, 15], - [10146, 4425, 417, 3, 9], + MorphologicalStructFactory.get_disk(3), + [8850, 13652, 16866, 14632, 11121, 6315, 2670, 761, 133], + [11981, 3019], + ), + (MorphologicalOperations.CLOSING, MorphologicalStructFactory.get_disk(11), [770, 74230], [661, 14339]), + ( + MorphologicalOperations.OPENING, + MorphologicalStructFactory.get_rectangle(3, 3), + [15026, 23899, 20363, 9961, 4328, 1128, 280, 15], + [12000, 3000], ), ( MorphologicalOperations.DILATION, MorphologicalStructFactory.get_rectangle(5, 6), [2, 19, 198, 3929, 70852], - [32, 743, 14225], + [803, 14197], ), ], ) @@ -91,5 +104,6 @@ def test_morphological_filter(patch, morph_operation, struct_element, mask_count assert patch[MASK_FEATURE].shape == (5, 50, 100, 3) assert patch[MASK_TIMELESS_FEATURE].shape == (50, 100, 3) + assert patch[MASK_TIMELESS_FEATURE].dtype == bool assert_array_equal(np.unique(patch[MASK_FEATURE], return_counts=True)[1], mask_counts) assert_array_equal(np.unique(patch[MASK_TIMELESS_FEATURE], return_counts=True)[1], mask_timeless_counts)