diff --git a/.flake8 b/.flake8
deleted file mode 100644
index a364178eb..000000000
--- a/.flake8
+++ /dev/null
@@ -1,10 +0,0 @@
-[flake8]
-# B028 is ignored because !r flags cannot be used in python < 3.8
-ignore = E203, W503, C408, B028
-exclude = .git, __pycache__, build, dist
-max-line-length= 120
-max-complexity = 15
-min_python_version = 3.7.0
-per-file-ignores =
-    # imported but unused
-    __init__.py: F401
diff --git a/.github/workflows/ci_action.yml b/.github/workflows/ci_action.yml
index 6e122919f..37bdbc9e0 100644
--- a/.github/workflows/ci_action.yml
+++ b/.github/workflows/ci_action.yml
@@ -6,9 +6,12 @@ on:
     branches:
       - "master"
       - "develop"
-  schedule:
-    # Schedule events are triggered by whoever last changed the cron schedule
-    - cron: "5 0 * * *"
+  workflow_call:
+
+concurrency:
+  # This will cancel outdated runs on the same pull-request, but not runs for other triggers
+  group: ${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
 
 env:
   # The only way to simulate if-else statement
@@ -19,46 +22,45 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout branch
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
         with:
          ref: ${{ env.CHECKOUT_BRANCH }}
 
       - name: Setup Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: "3.8"
-          architecture: x64
-
-      - name: Prepare pre-commit validators
-        run: |
-          pip install pre-commit
-
-      - name: Check code compliance with pre-commit validators
-        run: pre-commit run --all-files
+      - uses: pre-commit/action@v3.0.0
+        with:
+          extra_args: --all-files --verbose
 
   check-code-pylint-and-mypy:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout branch
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
        with:
          ref: ${{ env.CHECKOUT_BRANCH }}
 
       - name: Setup Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: "3.8"
-          architecture: x64
+          # cache: pip # uncomment when all requirements are in `pyproject.toml`
+          # caching the entire environment is faster when cache exists but slower for cache creation
 
       - name: Install packages
         run: |
-          pip install -r requirements-dev.txt --upgrade
+          pip install -r requirements-dev.txt --upgrade --upgrade-strategy eager
           python install_all.py
+          pip install -r ml_tools/requirements-tdigest.txt
 
       - name: Run pylint
         run: make pylint
 
       - name: Run mypy
+        if: success() || failure()
         run: |
           mypy \
             core/eolearn/core \
@@ -77,6 +79,7 @@ jobs:
         python-version:
           - "3.9"
          - "3.10"
+          - "3.11"
         include:
           # A flag marks whether full or partial tests should be run
           # We don't run integration tests on pull requests from outside repos, because they don't have secrets
            full_test_suite: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
     steps:
       - name: Checkout branch
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
         with:
           ref: ${{ env.CHECKOUT_BRANCH }}
 
       - name: Setup Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
-          architecture: x64
+          # cache: pip # uncomment when all requirements are in `pyproject.toml`
 
       - name: Install packages
         run: |
-          sudo apt-get update
-          sudo apt-get install -y build-essential libgdal-dev graphviz proj-bin gcc libproj-dev libspatialindex-dev
-          pip install -r requirements-dev.txt --upgrade
+          pip install -r requirements-dev.txt --upgrade --upgrade-strategy eager
           python install_all.py -e
 
       - name: Run full tests and code coverage
@@ -111,8 +112,7 @@
       - name: Run reduced tests
         if: ${{ !matrix.full_test_suite }}
-        run: |
-          pytest -m "not sh_integration"
+        run: pytest -m "not sh_integration"
 
       - name: Upload code coverage
         if: ${{ matrix.full_test_suite && github.event_name == 'push' }}
diff --git a/.github/workflows/scheduled_caller.yml b/.github/workflows/scheduled_caller.yml
new file mode 100644
index 000000000..8a72a0c18
--- /dev/null
+++ b/.github/workflows/scheduled_caller.yml
@@ -0,0 +1,11 @@
+name: scheduled build caller
+
+on:
+  schedule:
+    # Schedule events are triggered by whoever last changed the cron schedule
+    - cron: "0 0 * * *"
+
+jobs:
+  call-workflow:
+    uses: sentinel-hub/eo-learn/.github/workflows/ci_action.yml@develop
+    secrets: inherit
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 041ba4ea6..0693d2f1d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,41 +13,18 @@ repos:
       - id: debug-statements
 
   - repo: https://github.com/psf/black
-    rev: 23.1.0
+    rev: 23.3.0
     hooks:
       - id: black
        language_version: python3
 
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: "v0.0.269"
     hooks:
-      - id: isort
-        name: isort (python)
-
-  - repo: https://github.com/PyCQA/autoflake
-    rev: v2.0.1
-    hooks:
-      - id: autoflake
-        args:
-          [
-            --remove-all-unused-imports,
-            --in-place,
-            --ignore-init-module-imports,
-          ]
-
-  - repo: https://github.com/pycqa/flake8
-    rev: 6.0.0
-    hooks:
-      - id: flake8
-        additional_dependencies:
-          - flake8-bugbear==23.2.13
-          - flake8-comprehensions==3.10.1
-          - flake8-simplify==0.19.3
-          - flake8-typing-imports==1.14.0
+      - id: ruff
 
   - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.6.3
+    rev: 1.7.0
     hooks:
       - id: nbqa-black
-      - id: nbqa-isort
-      - id: nbqa-flake8
+      - id: nbqa-ruff
diff --git a/.zenodo.json b/.zenodo.json
index 4f7820ddb..1857c50b3 100644
--- a/.zenodo.json
+++ b/.zenodo.json
@@ -141,6 +141,7 @@
     },
     {
       "name": "Colin Moldenhauer",
+      "affiliation": "Technical University of Munich",
       "type": "Other"
     },
     {
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 234efd937..e27a80770 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,18 @@
+## [Version 1.4.2] - 2023-3-14
+
+- Introduced support for Python 3.11.
+- Removed support for Python 3.7.
+- Added T-Digest `EOTask` in the scope of the Global Earth Monitor Project, contributed by @meengel.
+- Used the evalscript generation utility from `sentinelhub-py` in SH related `EOTasks`.
+- Deprecated the `EOPatch.merge` method and extracted it as a function.
+- Deprecated the `OVERWRITE_PATCH` permission and started enforcing the usage of explicit string permissions.
+- Encapsulated the `FeatureDict` class as a `Mapping` and removed its inheritance from `dict`.
+- Switched to new-style typed annotations.
+- Introduced the `ruff` Python linter and removed `flake8` and `isort` (both covered by `ruff`).
+- Fixed an issue with occasionally failing scheduled builds on the `master` branch.
+- Various refactoring efforts and dependency improvements.
+- Various improvements to tests and code.
+
 ## [Version 1.4.1] - 2023-3-14
 
 - The codebase is now fully annotated and type annotations are mandatory for all new code.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 97c0d1ec7..c3152f648 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -131,7 +131,7 @@ This section assumes you have installed all packages in `requirements-dev.txt`.
 
 Most of the automated code-checking is packaged into [pre-commit hooks](https://pre-commit.com/). You can activate them by running `pre-commit install`. If you wish to check all code, you can do so by running `pre-commit run --all-files`. This takes care of:
 
 - auto-formatting the code using `black`, `isort`, and `autoflake`
-- checking the code with `flake8`
+- checking the code with `ruff`
 - checking and formatting any Jupyter notebooks with `nbqa`
 - various other helpful things (correcting line-endings etc.)
diff --git a/CREDITS.md b/CREDITS.md
index ba09c9662..4f85327f2 100644
--- a/CREDITS.md
+++ b/CREDITS.md
@@ -37,11 +37,12 @@ page or mine the [commit history](https://github.com/sentinel-hub/eo-learn/commi
 
 ## Other contributors
 
 * Drew Bollinger (DevelopmentSeed)
+* Michael Engel (Technical University of Munich)
 * Peter Fogh
 * Hugo Fournier (Magellium)
 * Ben Huff
 * Filip Koprivec (Jožef Stefan Institute)
-* Colin Moldenhauer
+* Colin Moldenhauer (Technical University of Munich)
 * William Ouellette (TomTom)
 * Radoslav Pitoňák
 * Johannes Schmid (GeoVille)
diff --git a/MANIFEST.in b/MANIFEST.in
index 0705b0f50..07c61fc6f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,5 @@
 include requirements*.txt
 include README.md
+include CREDITS.md
 include LICENSE
diff --git a/README.md b/README.md
index 8780819c8..dab165af6 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,7 @@ At the moment there are the following subpackages:
 
 ### PyPi distribution
 
-The package requires Python version **>=3.7** . It can be installed with:
+The package requires Python version **>=3.8**. It can be installed with:
 
 ```bash
 pip install eo-learn
 ```
@@ -92,7 +92,7 @@ Some subpackages contain extension modules under `extra` subfolder. Those module
 
 ### Conda Forge distribution
 
-The package requires a Python environment **>=3.7**.
+The package requires a Python environment **>=3.8**.
 
 Thanks to the maintainers of the conda forge feedstock (@benhuff, @dcunn, @mwilson8, @oblute, @rluria14), `eo-learn` can be installed using `conda-forge` as follows:
@@ -149,7 +149,7 @@ Examples and introductions to the package can be found [here](https://github.com
 
 ## Contributions
 
-If you would like to contribute to `eo-learn`, check out our [contribution guidelines](./CONTRIBUTING.md).
+The list of all `eo-learn` contributors can be found [here](./CREDITS.md). If you would like to contribute to `eo-learn`, please check our [contribution guidelines](./CONTRIBUTING.md).
 
 ## Blog posts and papers
 
diff --git a/core/eolearn/core/__init__.py b/core/eolearn/core/__init__.py
index 6665563cf..dd050c14e 100644
--- a/core/eolearn/core/__init__.py
+++ b/core/eolearn/core/__init__.py
@@ -22,6 +22,7 @@
     ZipFeatureTask,
 )
 from .eodata import EOPatch
+from .eodata_merge import merge_eopatches
 from .eoexecution import EOExecutor
 from .eonode import EONode, linearly_connect_tasks
 from .eotask import EOTask
@@ -32,4 +33,4 @@
 from .utils.parallelize import execute_with_mp_lock, join_futures, join_futures_iter, parallelize
 from .utils.parsing import FeatureParser
 
-__version__ = "1.4.1"
+__version__ = "1.4.2"
diff --git a/core/eolearn/core/constants.py b/core/eolearn/core/constants.py
index fbdf37587..05f87f5d7 100644
--- a/core/eolearn/core/constants.py
+++ b/core/eolearn/core/constants.py
@@ -8,7 +8,7 @@
 """
 import warnings
 from enum import Enum, EnumMeta
-from typing import Any, Optional
+from typing import Any, Optional, TypeVar
 
 from sentinelhub import BBox, MimeType
 from sentinelhub.exceptions import deprecated_function
@@ -16,29 +16,30 @@
 from .exceptions import EODeprecationWarning
 
 TIMESTAMP_COLUMN = "TIMESTAMP"
+T = TypeVar("T")
 
 
-def _warn_and_adjust(name: str) -> str:
+def _warn_and_adjust(name: T) -> T:
     # since we stick with `UPPER` for attributes and `lower` for values, we include both to reuse function
-    deprecation_msg = None
+    deprecation_msg = None  # placeholder
     if name in ("TIMESTAMP", "timestamp"):
-        name = "TIMESTAMPS" if name == "TIMESTAMP" else "timestamps"
+        name = "TIMESTAMPS" if name == "TIMESTAMP" else "timestamps"  # type: ignore[assignment]
 
     if deprecation_msg:
-        warnings.warn(deprecation_msg, category=EODeprecationWarning, stacklevel=3)  # type: ignore
+        warnings.warn(deprecation_msg, category=EODeprecationWarning, stacklevel=3)  # type: ignore[unreachable]
     return name
 
 
 class EnumWithDeprecations(EnumMeta):
     """A custom EnumMeta class for catching the deprecated Enum members of the FeatureType Enum class."""
 
-    def __getattribute__(cls, name: str) -> Any:
+    def __getattribute__(cls, name: str) -> Any:  # noqa[N805]
         return super().__getattribute__(_warn_and_adjust(name))
 
-    def __getitem__(cls, name: str) -> Any:
+    def __getitem__(cls, name: str) -> Any:  # noqa[N805]
         return super().__getitem__(_warn_and_adjust(name))
 
-    def __call__(cls, value: str, *args: Any, **kwargs: Any) -> Any:
+    def __call__(cls, value: str, *args: Any, **kwargs: Any) -> Any:  # noqa[N805]
         return super().__call__(_warn_and_adjust(value), *args, **kwargs)
@@ -292,17 +293,54 @@ class FeatureTypeSet(metaclass=DeprecatedCollectionClass):
     RASTER_TYPES_1D = frozenset([FeatureType.SCALAR_TIMELESS, FeatureType.LABEL_TIMELESS])
 
 
-class OverwritePermission(Enum):
-    """Enum class which specifies which content of saved EOPatch can be overwritten when saving new content.
+def _warn_and_adjust_permissions(name: T) -> T:
+    if isinstance(name, str) and name.upper() == "OVERWRITE_PATCH":
+        warnings.warn(
+            '"OVERWRITE_PATCH" permission is deprecated and will be removed in a future version',
+            category=EODeprecationWarning,
+            stacklevel=3,
+        )
+    return name
+
+
+class PermissionsWithDeprecations(EnumMeta):
+    """A custom EnumMeta class for catching the deprecated Enum members of the OverwritePermission Enum class."""
+
+    def __getattribute__(cls, name: str) -> Any:  # noqa[N805]
+        return super().__getattribute__(_warn_and_adjust_permissions(name))
+
+    def __getitem__(cls, name: str) -> Any:  # noqa[N805]
+        return super().__getitem__(_warn_and_adjust_permissions(name))
+
+    def __call__(cls, value: str, *args: Any, **kwargs: Any) -> Any:  # noqa[N805]
+        return super().__call__(_warn_and_adjust_permissions(value), *args, **kwargs)
+
+
+class OverwritePermission(Enum, metaclass=PermissionsWithDeprecations):
+    """Enum class which specifies which content of the saved EOPatch can be overwritten when saving new content.
 
     Permissions are in the following hierarchy:
 
     - `ADD_ONLY` - Only new features can be added, anything that is already saved cannot be changed.
     - `OVERWRITE_FEATURES` - Overwrite only data for features which have to be saved. The remaining content of
       saved EOPatch will stay unchanged.
-    - `OVERWRITE_PATCH` - Overwrite entire content of saved EOPatch and replace it with the new content.
     """
 
-    ADD_ONLY = 0
-    OVERWRITE_FEATURES = 1
-    OVERWRITE_PATCH = 2
+    ADD_ONLY = "ADD_ONLY"
+    OVERWRITE_FEATURES = "OVERWRITE_FEATURES"
+    OVERWRITE_PATCH = "OVERWRITE_PATCH"
+
+    @classmethod
+    def _missing_(cls, value: object) -> "OverwritePermission":
+        permissions_mapping = {0: "ADD_ONLY", 1: "OVERWRITE_FEATURES", 2: "OVERWRITE_PATCH"}
+        if isinstance(value, int) and value in permissions_mapping:
+            deprecation_msg = (
+                f"Please use strings to instantiate overwrite permissions, e.g., instead of {value} use"
+                f" {permissions_mapping[value]!r}"
+            )
+            warnings.warn(deprecation_msg, category=EODeprecationWarning, stacklevel=3)
+
+            return cls(permissions_mapping[value])
+        if isinstance(value, str) and value.upper() in cls._value2member_map_:
+            return cls(value.upper())
+        return super()._missing_(value)
diff --git a/core/eolearn/core/core_tasks.py b/core/eolearn/core/core_tasks.py
index 47904dcc9..33a0d1eb8 100644
--- a/core/eolearn/core/core_tasks.py
+++ b/core/eolearn/core/core_tasks.py
@@ -10,7 +10,7 @@
 
 import copy
 from abc import ABCMeta
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
+from typing import Any, Callable, Iterable, Tuple, Union, cast
 
 import fs
 import numpy as np
@@ -20,6 +20,7 @@
 
 from .constants import FeatureType
 from .eodata import EOPatch
+from .eodata_merge import merge_eopatches
 from .eotask import EOTask
 from .types import FeatureSpec, FeaturesSpecification, SingleFeatureSpec
 from .utils.fs import get_filesystem, pickle_fs, unpickle_fs
@@ -48,12 +49,10 @@ def execute(self, eopatch: EOPatch) -> EOPatch:
         return eopatch.copy(features=self.features, deep=True)
 
 
-class IOTask(EOTask, metaclass=ABCMeta):  # noqa B024
+class IOTask(EOTask, metaclass=ABCMeta):
     """An abstract Input/Output task that can handle a path and a filesystem object."""
 
-    def __init__(
-        self, path: str, filesystem: Optional[FS] = None, create: bool = False, config: Optional[SHConfig] = None
-    ):
+    def __init__(self, path: str, filesystem: FS | None = None, create: bool = False, config: SHConfig | None = None):
         """
         :param path: root path where all EOPatches are saved
         :param filesystem: An existing filesystem object. If not given it will be initialized according to the EOPatch
@@ -83,7 +82,7 @@ def filesystem(self) -> FS:
 class SaveTask(IOTask):
     """Saves the given EOPatch to a filesystem."""
 
-    def __init__(self, path: str, filesystem: Optional[FS] = None, config: Optional[SHConfig] = None, **kwargs: Any):
+    def __init__(self, path: str, filesystem: FS | None = None, config: SHConfig | None = None, **kwargs: Any):
         """
         :param path: root path where all EOPatches are saved
         :param filesystem: An existing filesystem object. If not given it will be initialized according to the EOPatch
@@ -114,7 +113,7 @@ def execute(self, eopatch: EOPatch, *, eopatch_folder: str = "") -> EOPatch:
 class LoadTask(IOTask):
     """Loads an EOPatch from a filesystem."""
 
-    def __init__(self, path: str, filesystem: Optional[FS] = None, config: Optional[SHConfig] = None, **kwargs: Any):
+    def __init__(self, path: str, filesystem: FS | None = None, config: SHConfig | None = None, **kwargs: Any):
         """
         :param path: root directory where all EOPatches are saved
         :param filesystem: An existing filesystem object. If not given it will be initialized according to the EOPatch
@@ -127,7 +126,7 @@ def __init__(self, path: str, filesystem: Optional[FS] = None, config: Optional[
         self.kwargs = kwargs
         super().__init__(path, filesystem=filesystem, create=False, config=config)
 
-    def execute(self, eopatch: Optional[EOPatch] = None, *, eopatch_folder: str = "") -> EOPatch:
+    def execute(self, eopatch: EOPatch | None = None, *, eopatch_folder: str = "") -> EOPatch:
         """Loads the EOPatch from disk: `folder/eopatch_folder`.
 
         :param eopatch: Optional input EOPatch. If given the loaded features are merged onto it, otherwise a new EOPatch
@@ -137,9 +136,7 @@ def execute(self, eopatch: Optional[EOPatch] = None, *, eopatch_folder: str = "")
         """
         path = fs.path.combine(self.filesystem_path, eopatch_folder)
         loaded_patch = EOPatch.load(path, filesystem=self.filesystem, **self.kwargs)
-        if eopatch is None:
-            return loaded_patch
-        return eopatch.merge(loaded_patch)
+        return loaded_patch if eopatch is None else merge_eopatches(eopatch, loaded_patch)
 
 
 class AddFeatureTask(EOTask):
@@ -257,9 +254,9 @@ class InitializeFeatureTask(EOTask):
     def __init__(
         self,
         features: FeaturesSpecification,
-        shape: Union[Tuple[int, ...], FeatureSpec],
+        shape: tuple[int, ...] | FeatureSpec,
         init_value: int = 0,
-        dtype: Union[np.dtype, type] = np.uint8,
+        dtype: np.dtype | type = np.uint8,
     ):
         """
         :param features: A collection of features to initialize.
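
Review note (not part of the patch): the `FS | None` and `tuple[int, ...]` annotations introduced throughout this diff are PEP 604/585 spellings, which are only valid at runtime from Python 3.10 onward. They stay safe in a package that still supports Python 3.8 as long as each touched module postpones annotation evaluation (the `from __future__ import annotations` line visible at the top of the `eodata.py` hunk below). A minimal sketch of the mechanism, outside of any eo-learn code:

```python
from __future__ import annotations  # annotations are stored as strings, never evaluated at runtime


def load(path: str | None = None) -> list[int] | None:  # accepted on Python 3.8+ despite the new syntax
    return None if path is None else [len(path)]


# The union survives only as text, so `str | None` is never executed on old interpreters:
print(load.__annotations__)  # {'path': 'str | None', 'return': 'list[int] | None'}
```
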
@@ -270,8 +267,8 @@ def __init__(
         """
         self.features = self.parse_features(features)
 
-        self.shape_feature: Optional[Tuple[FeatureType, Optional[str]]]
-        self.shape: Union[None, Tuple[int, int, int], Tuple[int, int, int, int]]
+        self.shape_feature: tuple[FeatureType, str | None] | None
+        self.shape: None | tuple[int, int, int] | tuple[int, int, int, int]
 
         try:
             self.shape_feature = self.parse_feature(shape)  # type: ignore[arg-type]
@@ -373,7 +370,7 @@ def __init__(
         self,
         input_features: FeaturesSpecification,
         output_features: FeaturesSpecification,
-        map_function: Optional[Callable] = None,
+        map_function: Callable | None = None,
         **kwargs: Any,
     ):
         """
@@ -453,7 +450,7 @@ def __init__(
         self,
         input_features: FeaturesSpecification,
         output_feature: SingleFeatureSpec,
-        zip_function: Optional[Callable] = None,
+        zip_function: Callable | None = None,
         **kwargs: Any,
     ):
         """
@@ -490,7 +487,7 @@ def zip_method(self, *features: Any) -> Any:
 class MergeFeatureTask(ZipFeatureTask):
     """Merges multiple features together by concatenating their data along the specified axis."""
 
-    def zip_method(self, *f: np.ndarray, dtype: Union[None, np.dtype, type] = None, axis: int = -1) -> np.ndarray:
+    def zip_method(self, *f: np.ndarray, dtype: None | np.dtype | type = None, axis: int = -1) -> np.ndarray:
         """Concatenates the data of features along the specified axis."""
         return np.concatenate(f, axis=axis, dtype=dtype)  # pylint: disable=unexpected-keyword-arg
 
@@ -498,7 +495,7 @@ def zip_method(self, *f: np.ndarray, dtype: Union[None, np.dtype, type] = None,
 class ExtractBandsTask(MapFeatureTask):
     """Moves a subset of bands from one feature to a new one."""
 
-    def __init__(self, input_feature: FeaturesSpecification, output_feature: FeaturesSpecification, bands: List[int]):
+    def __init__(self, input_feature: FeaturesSpecification, output_feature: FeaturesSpecification, bands: list[int]):
         """
         :param input_feature: A source feature from which to take the subset of bands.
         :param output_feature: An output feature to which to write the bands.
@@ -519,8 +516,8 @@ class ExplodeBandsTask(EOTask):
 
     def __init__(
         self,
-        input_feature: Tuple[FeatureType, str],
-        output_mapping: Dict[Tuple[FeatureType, str], Union[int, Iterable[int]]],
+        input_feature: tuple[FeatureType, str],
+        output_mapping: dict[tuple[FeatureType, str], int | Iterable[int]],
     ):
         """
         :param input_feature: A source feature from which to take the subset of bands.
@@ -553,12 +550,12 @@ def execute(self, **kwargs: Any) -> EOPatch:
 class MergeEOPatchesTask(EOTask):
     """Merge content from multiple EOPatches into a single EOPatch.
 
-    Check :func:`EOPatch.merge` for more information about the merging process.
+    Check :func:`merge_eopatches` for more information.
     """
 
     def __init__(self, **merge_kwargs: Any):
         """
-        :param merge_kwargs: Keyword arguments defined for `EOPatch.merge` method.
+        :param merge_kwargs: Keyword arguments defined for the `merge_eopatches` function.
         """
         self.merge_kwargs = merge_kwargs
 
@@ -570,4 +567,4 @@ def execute(self, *eopatches: EOPatch) -> EOPatch:
         if not eopatches:
             raise ValueError("At least one EOPatch should be given")
 
-        return eopatches[0].merge(*eopatches[1:], **self.merge_kwargs)
+        return merge_eopatches(*eopatches, **self.merge_kwargs)
diff --git a/core/eolearn/core/eodata.py b/core/eolearn/core/eodata.py
index 0960c75d0..990858c3d 100644
--- a/core/eolearn/core/eodata.py
+++ b/core/eolearn/core/eodata.py
@@ -9,26 +9,35 @@
 from __future__ import annotations
 
 import concurrent.futures
+import contextlib
 import copy
 import datetime as dt
 import logging
 from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, cast, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    Literal,
+    Mapping,
+    MutableMapping,
+    TypeVar,
+    cast,
+    overload,
+)
 from warnings import warn
 
-import attr
-import dateutil.parser
 import geopandas as gpd
 import numpy as np
 from fs.base import FS
-from typing_extensions import Literal
 
-from sentinelhub import CRS, BBox
+from sentinelhub import CRS, BBox, parse_time
 from sentinelhub.exceptions import deprecated_function
 
 from .constants import TIMESTAMP_COLUMN, FeatureType, OverwritePermission
-from .eodata_io import FeatureIO, load_eopatch_content, save_eopatch
-from .eodata_merge import merge_eopatches
+from .eodata_io import FeatureIO, FeatureIOJson, load_eopatch_content, save_eopatch
 from .exceptions import EODeprecationWarning
 from .types import EllipsisType, FeatureSpec, FeaturesSpecification
 from .utils.common import deep_eq, is_discrete_type
@@ -36,7 +45,6 @@
 from .utils.parsing import parse_features
 
 T = TypeVar("T")
-Self = TypeVar("Self")
 
 LOGGER = logging.getLogger(__name__)
 MISSING_BBOX_WARNING = (
@@ -44,18 +52,16 @@
     " EOPatches represent geolocated data and so any EOPatch without a BBox is ill-formed. Consider"
     " using a different data structure for non-geolocated data."
 )
+TIMESTAMP_RENAME_WARNING = "The attribute `timestamp` is deprecated, use `timestamps` instead."
 
 MAX_DATA_REPR_LEN = 100
 
 if TYPE_CHECKING:
-    try:
-        from eolearn.visualization import PlotBackend
-        from eolearn.visualization.eopatch_base import BasePlotConfig
-    except ImportError:
-        pass
+    with contextlib.suppress(ImportError):
+        from eolearn.visualization import PlotBackend, PlotConfig
 
 
-class _FeatureDict(Dict[str, Union[T, FeatureIO[T]]], metaclass=ABCMeta):
+class _FeatureDict(MutableMapping[str, T], metaclass=ABCMeta):
     """A dictionary structure that holds features of a certain feature type.
 
     It checks that features have a correct type and dimension. It also supports lazy loading by accepting a function
     as a value.
     """
 
     FORBIDDEN_CHARS = {".", "/", "\\", "|", ";", ":", "\n", "\t"}
 
-    def __init__(self, feature_dict: Dict[str, Union[T, FeatureIO[T]]], feature_type: FeatureType):
+    def __init__(self, feature_dict: Mapping[str, T | FeatureIO[T]], feature_type: FeatureType):
         """
         :param feature_dict: A dictionary of feature names and values
         :param feature_type: Type of features
@@ -72,27 +78,16 @@ def __init__(self, feature_dict: Dict[str, Union[T, FeatureIO[T]]], feature_type
         super().__init__()
 
         self.feature_type = feature_type
+        self._content = dict(feature_dict.items())
 
-        for feature_name, value in feature_dict.items():
-            self[feature_name] = value
-
-    @classmethod
-    def empty_factory(cls: Type[Self], feature_type: FeatureType) -> Callable[[], Self]:
-        """Returns a factory function for creating empty feature dictionaries with an appropriate feature type."""
-
-        def factory() -> Self:
-            return cls(feature_dict={}, feature_type=feature_type)  # type: ignore[call-arg]
-
-        return factory
-
-    def __setitem__(self, feature_name: str, value: Union[T, FeatureIO[T]]) -> None:
+    def __setitem__(self, feature_name: str, value: T | FeatureIO[T]) -> None:
         """Before setting value to the dictionary it checks that value is of correct type and dimension and
         tries to transform value in correct form.
         """
         if not isinstance(value, FeatureIO):
             value = self._parse_feature_value(value, feature_name)
         self._check_feature_name(feature_name)
-        super().__setitem__(feature_name, value)
+        self._content[feature_name] = value
 
     def _check_feature_name(self, feature_name: str) -> None:
         """Ensures that feature names are strings and do not contain forbidden characters."""
@@ -101,42 +96,41 @@ def _check_feature_name(self, feature_name: str) -> None:
 
         for char in feature_name:
             if char in self.FORBIDDEN_CHARS:
-                raise ValueError(
-                    f"The name of feature ({self.feature_type}, {feature_name}) contains an illegal character '{char}'."
-                )
+                raise ValueError(f"The feature name {feature_name} contains an illegal character '{char}'.")
 
         if feature_name == "":
             raise ValueError("Feature name cannot be an empty string.")
 
-    @overload
-    def __getitem__(self, feature_name: str, load: Literal[True] = ...) -> T:
-        ...
-
-    @overload
-    def __getitem__(self, feature_name: str, load: Literal[False] = ...) -> Union[T, FeatureIO[T]]:
-        ...
-
-    def __getitem__(self, feature_name: str, load: bool = True) -> Union[T, FeatureIO[T]]:
+    def __getitem__(self, feature_name: str) -> T:
         """Implements lazy loading."""
-        value = super().__getitem__(feature_name)
+        value = self._content[feature_name]
 
-        if isinstance(value, FeatureIO) and load:
+        if isinstance(value, FeatureIO):
+            value = cast(FeatureIO[T], value)  # not sure why mypy borks this one
             value = value.load()
-            self[feature_name] = value
+            self._content[feature_name] = value
 
         return value
 
+    def _get_unloaded(self, feature_name: str) -> T | FeatureIO[T]:
+        """Returns the value, bypassing lazy-loading mechanisms."""
+        return self._content[feature_name]
+
+    def __delitem__(self, feature_name: str) -> None:
+        del self._content[feature_name]
+
     def __eq__(self, other: object) -> bool:
-        """Compares its content against a content of another feature type dictionary."""
+        # default doesn't know how to compare numpy arrays
         return deep_eq(self, other)
 
-    def __ne__(self, other: object) -> bool:
-        """Compares its content against a content of another feature type dictionary."""
-        return not self.__eq__(other)
+    def __len__(self) -> int:
+        return len(self._content)
+
+    def __contains__(self, key: object) -> bool:
+        return key in self._content
 
-    def get_dict(self) -> Dict[str, T]:
-        """Returns a Python dictionary of features and value."""
-        return dict(self)
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._content)
 
     @abstractmethod
     def _parse_feature_value(self, value: object, feature_name: str) -> T:
@@ -149,22 +143,15 @@ def _parse_feature_value(self, value: object, feature_name: str) -> T:
 class _FeatureDictNumpy(_FeatureDict[np.ndarray]):
     """_FeatureDict object specialized for Numpy arrays."""
 
-    def __init__(self, feature_dict: Dict[str, Union[np.ndarray, FeatureIO[np.ndarray]]], feature_type: FeatureType):
-        ndim = feature_type.ndim()
-        if ndim is None:
-            raise ValueError(f"Feature type {feature_type} does not represent a Numpy based feature.")
-        self.ndim = ndim
-        super().__init__(feature_dict, feature_type)
-
     def _parse_feature_value(self, value: object, feature_name: str) -> np.ndarray:
         if not isinstance(value, np.ndarray):
             raise ValueError(f"{self.feature_type} feature has to be a numpy array.")
-        if not hasattr(self, "ndim"):  # Because of serialization/deserialization during multiprocessing
-            return value
-        if value.ndim != self.ndim:
+
+        expected_ndim = cast(int, self.feature_type.ndim())  # numpy features have ndim
+        if value.ndim != expected_ndim:
             raise ValueError(
-                f"Numpy array of {self.feature_type} feature has to have {self.ndim} "
-                f"dimension{'s' if self.ndim > 1 else ''} but feature {feature_name} has {value.ndim}."
+                f"Numpy array of {self.feature_type} feature has to have {expected_ndim} "
+                f"dimension{'s' if expected_ndim > 1 else ''} but feature {feature_name} has {value.ndim}."
             )
 
         if self.feature_type.is_discrete() and not is_discrete_type(value.dtype):
@@ -179,11 +166,6 @@ def _parse_feature_value(self, value: object, feature_name: str) -> np.ndarray:
 class _FeatureDictGeoDf(_FeatureDict[gpd.GeoDataFrame]):
     """_FeatureDict object specialized for GeoDataFrames."""
 
-    def __init__(self, feature_dict: Dict[str, gpd.GeoDataFrame], feature_type: FeatureType):
-        if not feature_type.is_vector():
-            raise ValueError(f"Feature type {feature_type} does not represent a vector feature.")
-        super().__init__(feature_dict, feature_type)
-
     def _parse_feature_value(self, value: object, feature_name: str) -> gpd.GeoDataFrame:
         if isinstance(value, gpd.GeoSeries):
             value = gpd.GeoDataFrame(geometry=value, crs=value.crs)
@@ -210,7 +192,7 @@ def _parse_feature_value(self, value: object, _: str) -> Any:
         return value
 
 
-def _create_feature_dict(feature_type: FeatureType, value: Dict[str, Any]) -> _FeatureDict:
+def _create_feature_dict(feature_type: FeatureType, value: Mapping[str, Any]) -> _FeatureDict:
     """Creates the correct FeatureDict, corresponding to the FeatureType."""
     if feature_type.is_vector():
         return _FeatureDictGeoDf(value, feature_type)
@@ -219,7 +201,6 @@ def _create_feature_dict(feature_type: FeatureType, value: Dict[str, Any]) -> _F
         return _FeatureDictNumpy(value, feature_type)
 
 
-@attr.s(repr=False, eq=False, kw_only=True)
 class EOPatch:
     """The basic data object for multi-temporal remotely sensed data, such as satellite imagery and its derivatives.
@@ -240,149 +221,148 @@ class EOPatch:
         arrays in other attributes.
     """
 
-    data: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.DATA))
-    mask: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.MASK))
-    scalar: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.SCALAR))
-    label: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.LABEL))
-    vector: _FeatureDictGeoDf = attr.ib(factory=_FeatureDictGeoDf.empty_factory(FeatureType.VECTOR))
-    data_timeless: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.DATA_TIMELESS))
-    mask_timeless: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.MASK_TIMELESS))
-    scalar_timeless: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.SCALAR_TIMELESS))
-    label_timeless: _FeatureDictNumpy = attr.ib(factory=_FeatureDictNumpy.empty_factory(FeatureType.LABEL_TIMELESS))
-    vector_timeless: _FeatureDictGeoDf = attr.ib(factory=_FeatureDictGeoDf.empty_factory(FeatureType.VECTOR_TIMELESS))
-    meta_info: _FeatureDictJson = attr.ib(factory=_FeatureDictJson.empty_factory(FeatureType.META_INFO))
-    bbox: Optional[BBox] = attr.ib(default=None)
-    timestamps: List[dt.datetime] = attr.ib(factory=list)
-
-    def __attrs_post_init__(self) -> None:
-        if self.bbox is None:
-            warn(MISSING_BBOX_WARNING, category=EODeprecationWarning, stacklevel=2)
+    # establish types of property value holders
+    _timestamps: list[dt.datetime]
+    _bbox: BBox | None
+    _meta_info: FeatureIOJson | _FeatureDictJson
 
-    @property
-    def timestamp(self) -> List[dt.datetime]:
-        """A property for handling the deprecated timestamp attribute.
-
-        :return: A list of EOPatch timestamps
-        """
-        warn(
-            "The attribute `timestamp` is deprecated, use `timestamps` instead.",
-            category=EODeprecationWarning,
-            stacklevel=2,
+    def __init__(
+        self,
+        *,
+        data: Mapping[str, np.ndarray] | None = None,
+        mask: Mapping[str, np.ndarray] | None = None,
+        scalar: Mapping[str, np.ndarray] | None = None,
+        label: Mapping[str, np.ndarray] | None = None,
+        vector: Mapping[str, gpd.GeoDataFrame] | None = None,
+        data_timeless: Mapping[str, np.ndarray] | None = None,
+        mask_timeless: Mapping[str, np.ndarray] | None = None,
+        scalar_timeless: Mapping[str, np.ndarray] | None = None,
+        label_timeless: Mapping[str, np.ndarray] | None = None,
+        vector_timeless: Mapping[str, gpd.GeoDataFrame] | None = None,
+        meta_info: Mapping[str, Any] | None = None,
+        bbox: BBox | None = None,
+        timestamps: list[dt.datetime] | None = None,
+    ):
+        self.data: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(data or {}, FeatureType.DATA)
+        self.mask: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(mask or {}, FeatureType.MASK)
+        self.scalar: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(scalar or {}, FeatureType.SCALAR)
+        self.label: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(label or {}, FeatureType.LABEL)
+        self.vector: MutableMapping[str, gpd.GeoDataFrame] = _FeatureDictGeoDf(vector or {}, FeatureType.VECTOR)
+        self.data_timeless: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(
+            data_timeless or {}, FeatureType.DATA_TIMELESS
         )
+        self.mask_timeless: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(
+            mask_timeless or {}, FeatureType.MASK_TIMELESS
+        )
+        self.scalar_timeless: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(
+            scalar_timeless or {}, FeatureType.SCALAR_TIMELESS
+        )
+        self.label_timeless: MutableMapping[str, np.ndarray] = _FeatureDictNumpy(
+            label_timeless or {}, FeatureType.LABEL_TIMELESS
+        )
+        self.vector_timeless: MutableMapping[str, gpd.GeoDataFrame] = _FeatureDictGeoDf(
+            vector_timeless or {}, FeatureType.VECTOR_TIMELESS
+        )
+        self.meta_info: MutableMapping[str, Any] = _FeatureDictJson(meta_info or {}, FeatureType.META_INFO)
+        self.bbox = bbox
+        self.timestamps = timestamps or []
+
+    @property
+    def timestamp(self) -> list[dt.datetime]:
+        """A property for handling the deprecated timestamp attribute."""
+        warn(TIMESTAMP_RENAME_WARNING, category=EODeprecationWarning, stacklevel=2)
         return self.timestamps
 
     @timestamp.setter
-    def timestamp(self, value: List[dt.datetime]) -> None:
-        warn(
-            "The attribute `timestamp` is deprecated, use `timestamps` instead.",
-            category=EODeprecationWarning,
-            stacklevel=2,
-        )
+    def timestamp(self, value: list[dt.datetime]) -> None:
+        warn(TIMESTAMP_RENAME_WARNING, category=EODeprecationWarning, stacklevel=2)
         self.timestamps = value
 
-    def __setattr__(self, key: str, value: object, feature_name: Union[str, None, EllipsisType] = None) -> None:
-        """Raises TypeError if feature type attributes are not of correct type.
-
-        In case they are a dictionary they are cast to _FeatureDict class.
-        """
-        if feature_name not in (None, Ellipsis) and FeatureType.has_value(key):
-            self.__getattribute__(key)[feature_name] = value
-            return
-
-        if FeatureType.has_value(key) and not isinstance(value, FeatureIO):
-            feature_type = FeatureType(key)
-            value = self._parse_feature_type_value(feature_type, value)
-
-        super().__setattr__(key, value)
-
-    @staticmethod
-    def _parse_feature_type_value(
-        feature_type: FeatureType, value: object
-    ) -> Union[_FeatureDict, BBox, List[dt.date], None]:
-        """Checks or parses value which will be assigned to a feature type attribute of `EOPatch`. If the value
-        cannot be parsed correctly it raises an error.
-
-        :raises: TypeError, ValueError
-        """
-
-        if feature_type is FeatureType.BBOX and (value is None or isinstance(value, BBox)):
-            if value is None:
-                warn(MISSING_BBOX_WARNING, category=EODeprecationWarning, stacklevel=2)
-            return value
+    @property
+    def timestamps(self) -> list[dt.datetime]:
+        """A property for handling the `timestamps` attribute."""
+        return self._timestamps
+
+    @timestamps.setter
+    def timestamps(self, value: Iterable[dt.datetime]) -> None:
+        if isinstance(value, Iterable) and all(isinstance(time, (dt.date, str)) for time in value):
+            self._timestamps = [parse_time(time, force_datetime=True) for time in value]
+        else:
+            raise TypeError(f"Cannot assign {value} as timestamps. Should be a sequence of datetime.datetime objects.")
 
-        if feature_type is FeatureType.TIMESTAMPS and isinstance(value, (tuple, list)):
-            return [
-                timestamp if isinstance(timestamp, dt.date) else dateutil.parser.parse(timestamp) for timestamp in value
-            ]
+    @property
+    def bbox(self) -> BBox | None:
+        """A property for handling the `bbox` attribute."""
+        return self._bbox
+
+    @bbox.setter
+    def bbox(self, value: BBox | None) -> None:
+        if not (isinstance(value, BBox) or value is None):
+            raise TypeError(f"Cannot assign {value} as bbox. Should be a `BBox` object.")
+        if value is None:
+            warn(MISSING_BBOX_WARNING, category=EODeprecationWarning, stacklevel=2)
+        self._bbox = value
 
-        if isinstance(value, dict):
-            return value if isinstance(value, _FeatureDict) else _create_feature_dict(feature_type, value)
+    @property
+    def meta_info(self) -> MutableMapping[str, Any]:
+        """A property for handling the `meta_info` attribute."""
+        # once META_INFO becomes regular (in terms of IO) this can be removed
+        if isinstance(self._meta_info, FeatureIOJson):
+            self.meta_info = self._meta_info.load()  # assigned to `meta_info` property to trigger validation
+        return self._meta_info  # type: ignore[return-value] # mypy cannot verify due to mutations
 
-        raise TypeError(f"Cannot parse given value {value} for feature type {feature_type}. Possible type missmatch.")
+    @meta_info.setter
+    def meta_info(self, value: Mapping[str, Any] | FeatureIOJson) -> None:
+        self._meta_info = value if isinstance(value, FeatureIOJson) else _FeatureDictJson(value, FeatureType.META_INFO)
 
-    def __getattribute__(self, key: str, load: bool = True, feature_name: Union[str, None, EllipsisType] = None) -> Any:
-        """Handles lazy loading and can even provide a single feature from _FeatureDict."""
-        value = super().__getattribute__(key)
+    def __setattr__(self, key: str, value: object) -> None:
+        """Casts dictionaries to _FeatureDict objects for non-meta features."""
 
-        if isinstance(value, FeatureIO) and load:
-            value = value.load()
-            setattr(self, key, value)
-            value = getattr(self, key)
+        if FeatureType.has_value(key) and not FeatureType(key).is_meta():
+            if not isinstance(value, (dict, _FeatureDict)):
+                raise TypeError(f"Cannot parse {value} for attribute {key}. Should be a dictionary.")
+            value = _create_feature_dict(FeatureType(key), value)
 
-        if feature_name not in (None, Ellipsis) and isinstance(value, _FeatureDict):
-            feature_name = cast(str, feature_name)  # the above check deals with ... and None
-            return value[feature_name]
-
-        return value
+        super().__setattr__(key, value)
 
     @overload
-    def __getitem__(self, key: Union[Literal[FeatureType.BBOX], Tuple[Literal[FeatureType.BBOX], Any]]) -> BBox:
+    def __getitem__(self, key: Literal[FeatureType.BBOX] | tuple[Literal[FeatureType.BBOX], Any]) -> BBox:
         ...
 
     @overload
     def __getitem__(
-        self, key: Union[Literal[FeatureType.TIMESTAMPS], Tuple[Literal[FeatureType.TIMESTAMPS], Any]]
-    ) -> List[dt.datetime]:
+        self, key: Literal[FeatureType.TIMESTAMPS] | tuple[Literal[FeatureType.TIMESTAMPS], Any]
+    ) -> list[dt.datetime]:
         ...
 
     @overload
-    def __getitem__(self, key: Union[FeatureType, Tuple[FeatureType, Union[str, None, EllipsisType]]]) -> Any:
+    def __getitem__(self, key: FeatureType | tuple[FeatureType, str | None | EllipsisType]) -> Any:
         ...
 
-    def __getitem__(self, key: Union[FeatureType, Tuple[FeatureType, Union[str, None, EllipsisType]]]) -> Any:
+    def __getitem__(self, key: FeatureType | tuple[FeatureType, str | None | EllipsisType]) -> Any:
         """Provides features of requested feature type. It can also accept a tuple of (feature_type, feature_name).
 
         :param key: Feature type or a (feature_type, feature_name) pair.
         """
-        if isinstance(key, tuple):
-            feature_type, feature_name = key
-        else:
-            feature_type, feature_name = key, None
+        feature_type, feature_name = key if isinstance(key, tuple) else (key, None)
+        value = getattr(self, FeatureType(feature_type).value)
+        if feature_name not in (None, Ellipsis) and isinstance(value, _FeatureDict):
+            feature_name = cast(str, feature_name)  # the above check deals with ... and None
+            return value[feature_name]
+        return value
 
-        ftype = FeatureType(feature_type).value
-        return self.__getattribute__(ftype, feature_name=feature_name)  # type: ignore[call-arg]
+    def __setitem__(self, key: FeatureType | tuple[FeatureType, str | None | EllipsisType], value: Any) -> None:
+        """Sets a new value to the given FeatureType or tuple of (feature_type, feature_name)."""
+        feature_type, feature_name = key if isinstance(key, tuple) else (key, None)
+        ftype_attr = FeatureType(feature_type).value
 
-    def __setitem__(
-        self, key: Union[FeatureType, Tuple[FeatureType, Union[str, None, EllipsisType]]], value: Any
-    ) -> None:
-        """Sets a new dictionary / list to the given FeatureType. As a key it can also accept a tuple of
-        (feature_type, feature_name).
-
-        :param key: Type of EOPatch feature
-        :param value: New dictionary or list
-        """
-        if isinstance(key, tuple):
-            feature_type, feature_name = key
+        if feature_name not in (None, Ellipsis):
+            getattr(self, ftype_attr)[feature_name] = value
         else:
-            feature_type, feature_name = key, None
-
-        return self.__setattr__(FeatureType(feature_type).value, value, feature_name=feature_name)
-
-    def __delitem__(self, feature: Union[FeatureType, FeatureSpec]) -> None:
-        """Deletes the selected feature type or feature.
+            setattr(self, ftype_attr, value)
 
-        :param feature: EOPatch feature
-        """
+    def __delitem__(self, feature: FeatureType | FeatureSpec) -> None:
+        """Deletes the selected feature type or feature."""
         if isinstance(feature, tuple):
             feature_type, feature_name = feature
             if feature_type in [FeatureType.BBOX, FeatureType.TIMESTAMPS]:
@@ -421,6 +401,7 @@ def __contains__(self, key: object) -> bool:
             "`(feature_type, feature_name)` pairs."
         )
 
+    @deprecated_function(EODeprecationWarning, "Use the `merge` method instead.")
     def __add__(self, other: EOPatch) -> EOPatch:
         """Merges two EOPatches into a new EOPatch."""
         return self.merge(other)
@@ -432,7 +413,8 @@ def __repr__(self) -> str:
             if not content:
                 continue
 
-            if isinstance(content, dict):
+            if isinstance(content, _FeatureDict):
+                content = {k: content._get_unloaded(k) for k in content}  # noqa: SLF001
                 inner_content_repr = "\n    ".join(
                     [f"{label}: {self._repr_value(value)}" for label, value in sorted(content.items())]
                 )
@@ -467,10 +449,10 @@ def _repr_value(value: object) -> str:
             l_bracket, r_bracket = ("[", "]") if isinstance(value, list) else ("(", ")")
             if isinstance(value, (list, tuple)) and len(value) > 2:
-                repr_str = f"{l_bracket}{repr(value[0])}, ..., {repr(value[-1])}{r_bracket}"
+                repr_str = f"{l_bracket}{value[0]!r}, ..., {value[-1]!r}{r_bracket}"
 
             if len(repr_str) > MAX_DATA_REPR_LEN and isinstance(value, (list, tuple)) and len(value) > 1:
-                repr_str = f"{l_bracket}{repr(value[0])}, ...{r_bracket}"
+                repr_str = f"{l_bracket}{value[0]!r}, ...{r_bracket}"
 
             if len(repr_str) > MAX_DATA_REPR_LEN:
                 repr_str = str(type(value))
@@ -498,10 +480,10 @@ def __copy__(self, features: FeaturesSpecification = ...) -> EOPatch:
             if feature_type in (FeatureType.BBOX, FeatureType.TIMESTAMPS):
                 new_eopatch[feature_type] = copy.copy(self[feature_type])
             else:
-                new_eopatch[feature_type][feature_name] = self[feature_type].__getitem__(feature_name, load=False)
+                new_eopatch[feature_type][feature_name] = self[feature_type]._get_unloaded(feature_name)  # noqa: SLF001
         return new_eopatch
 
-    def __deepcopy__(self, memo: Optional[dict] = None, features: FeaturesSpecification = ...) -> EOPatch:
+    def __deepcopy__(self, memo: dict | None = None, features: FeaturesSpecification = ...) -> EOPatch:
         """Returns a new EOPatch with deep copies of given features.
 
         :param memo: built-in parameter for memoization
@@ -515,7 +497,7 @@ def __deepcopy__(self, memo: Optional[dict] = None, features: FeaturesSpecificat
             if feature_type in (FeatureType.BBOX, FeatureType.TIMESTAMPS):
                 new_eopatch[feature_type] = copy.deepcopy(self[feature_type], memo=memo)
             else:
-                value = self[feature_type].__getitem__(feature_name, load=False)
+                value = self[feature_type]._get_unloaded(feature_name)  # noqa: SLF001
 
                 if isinstance(value, FeatureIO):
                     # We cannot deepcopy the entire object because of the filesystem attribute
@@ -555,7 +537,7 @@ def reset_feature_type(self, feature_type: FeatureType) -> None:
         else:
             self[feature_type] = {}
 
-    def get_spatial_dimension(self, feature_type: FeatureType, feature_name: str) -> Tuple[int, int]:
+    def get_spatial_dimension(self, feature_type: FeatureType, feature_name: str) -> tuple[int, int]:
         """
         Returns a tuple of spatial dimensions (height, width) of a feature.
 
@@ -568,12 +550,12 @@ def get_spatial_dimension(self, feature_type: FeatureType, feature_name: str) ->
 
         raise ValueError(f"Features of type {feature_type} do not have a spatial dimension or are not arrays.")
 
-    def get_features(self) -> List[FeatureSpec]:
+    def get_features(self) -> list[FeatureSpec]:
         """Returns a list of all non-empty features of EOPatch.
 
         :return: List of non-empty features
         """
-        feature_list: List[FeatureSpec] = []
+        feature_list: list[FeatureSpec] = []
         for feature_type in FeatureType:
             if feature_type is FeatureType.BBOX or feature_type is FeatureType.TIMESTAMPS:
                 if feature_type in self:
@@ -589,7 +571,7 @@ def save(
         features: FeaturesSpecification = ...,
         overwrite_permission: OverwritePermission = OverwritePermission.ADD_ONLY,
         compress_level: int = 0,
-        filesystem: Optional[FS] = None,
+        filesystem: FS | None = None,
     ) -> None:
         """Method to save an EOPatch from memory to a storage.
 
@@ -617,7 +599,7 @@
     @staticmethod
     def load(
-        path: str, features: FeaturesSpecification = ..., lazy_loading: bool = False, filesystem: Optional[FS] = None
+        path: str, features: FeaturesSpecification = ..., lazy_loading: bool = False, filesystem: FS | None = None
     ) -> EOPatch:
         """Method to load an EOPatch from a storage into memory.
 
@@ -632,11 +614,11 @@ def load(
             filesystem = get_filesystem(path, create=False)
             path = "/"
 
-        bbox, timestamps, meta_info, features_dict = load_eopatch_content(filesystem, path, features=features)
-        eopatch = EOPatch(bbox=bbox)  # type: ignore[arg-type]
+        bbox_io, timestamps_io, meta_info, features_dict = load_eopatch_content(filesystem, path, features=features)
+        eopatch = EOPatch(bbox=None if bbox_io is None else bbox_io.load())
 
-        if timestamps is not None:
-            eopatch.timestamps = timestamps  # type: ignore[assignment]
+        if timestamps_io is not None:
+            eopatch.timestamps = timestamps_io.load()
         if meta_info is not None:
             eopatch.meta_info = meta_info  # type: ignore[assignment]
         for feature, feature_io in features_dict.items():
@@ -646,12 +628,13 @@ def load(
         _trigger_loading_for_eopatch_features(eopatch)
         return eopatch
 
+    @deprecated_function(EODeprecationWarning, "Use the function `eolearn.core.merge_eopatches` instead.")
     def merge(
         self,
         *eopatches: EOPatch,
         features: FeaturesSpecification = ...,
-        time_dependent_op: Union[Literal[None, "concatenate", "min", "max", "mean", "median"], Callable] = None,
-        timeless_op: Union[Literal[None, "concatenate", "min", "max", "mean", "median"], Callable] = None,
+        time_dependent_op: Literal[None, "concatenate", "min", "max", "mean", "median"] | Callable = None,
+        timeless_op: Literal[None, "concatenate", "min", "max", "mean", "median"] | Callable = None,
     ) -> EOPatch:
         """Merge features of given EOPatches into a new EOPatch.
 
@@ -678,17 +661,13 @@
         - 'median': Join arrays by taking median values. Ignore NaN values.
         :return: A merged EOPatch
         """
-        eopatch_content = merge_eopatches(
+        from .eodata_merge import merge_eopatches  # pylint: disable=import-outside-toplevel, cyclic-import
+
+        return merge_eopatches(
             self, *eopatches, features=features, time_dependent_op=time_dependent_op, timeless_op=timeless_op
         )
 
-        merged_eopatch = EOPatch(bbox=eopatch_content[(FeatureType.BBOX, None)])
-        for feature, value in eopatch_content.items():
-            merged_eopatch[feature] = value
-
-        return merged_eopatch
-
-    def consolidate_timestamps(self, timestamps: List[dt.datetime]) -> Set[dt.datetime]:
+    def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.datetime]:
         """Removes all frames from the EOPatch with a date not found in the provided timestamps list.
 
         :param timestamps: keep frames with date found in this list
@@ -711,12 +690,12 @@ def plot(
         self,
         feature: FeatureSpec,
         *,
-        times: Union[List[int], slice, None] = None,
-        channels: Union[List[int], slice, None] = None,
-        channel_names: Optional[List[str]] = None,
-        rgb: Optional[Tuple[int, int, int]] = None,
-        backend: Union[str, PlotBackend] = "matplotlib",
-        config: Optional[BasePlotConfig] = None,
+        times: list[int] | slice | None = None,
+        channels: list[int] | slice | None = None,
+        channel_names: list[str] | None = None,
+        rgb: tuple[int, int, int] | None = None,
+        backend: str | PlotBackend = "matplotlib",
+        config: PlotConfig | None = None,
         **kwargs: Any,
     ) -> object:
         """Plots an `EOPatch` feature.
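
Review note (not part of the patch): with `EOPatch.merge` and `EOPatch.__add__` deprecated above, merging now goes through the `merge_eopatches` function that this diff exports from `eolearn.core`. A hedged usage sketch; the paths are hypothetical and the keyword arguments are the ones listed in the deprecated method's signature:

```python
from eolearn.core import EOPatch, merge_eopatches

patch_a = EOPatch.load("data/patch_a")  # hypothetical example paths
patch_b = EOPatch.load("data/patch_b")

# Previously: patch_a.merge(patch_b) or patch_a + patch_b (both now emit EODeprecationWarning)
merged = merge_eopatches(patch_a, patch_b, time_dependent_op="concatenate")
```
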
diff --git a/core/eolearn/core/eodata_io.py b/core/eolearn/core/eodata_io.py index 83ec2278d..cf04af851 100644 --- a/core/eolearn/core/eodata_io.py +++ b/core/eolearn/core/eodata_io.py @@ -19,7 +19,19 @@ from collections import defaultdict from dataclasses import dataclass, field from functools import partial -from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + Dict, + Generic, + Iterator, + List, + Mapping, + Optional, + Tuple, + TypeVar, +) import dateutil.parser import fs @@ -39,7 +51,6 @@ from .exceptions import EODeprecationWarning from .types import EllipsisType, FeatureSpec, FeaturesSpecification from .utils.parsing import FeatureParser -from .utils.vector_io import infer_schema if TYPE_CHECKING: from .eodata import EOPatch @@ -62,12 +73,12 @@ class FilesystemDataInfo: """Information about data that is present on the filesystem. Fields represent paths to relevant file.""" - timestamps: Optional[str] = None - bbox: Optional[str] = None - meta_info: Optional[str] = None - features: Dict[FeatureType, Dict[str, str]] = field(default_factory=lambda: defaultdict(dict)) + timestamps: str | None = None + bbox: str | None = None + meta_info: str | None = None + features: dict[FeatureType, dict[str, str]] = field(default_factory=lambda: defaultdict(dict)) - def iterate_features(self) -> Iterator[Tuple[Tuple[FeatureType, str], str]]: + def iterate_features(self) -> Iterator[tuple[tuple[FeatureType, str], str]]: """Yields `(ftype, fname), path` tuples from `features`.""" for ftype, ftype_dict in self.features.items(): for fname, path in ftype_dict.items(): @@ -93,8 +104,10 @@ def save_eopatch( # Data must be collected before any tinkering with files due to lazy-loading data_for_saving = list(_yield_features_to_save(eopatch, eopatch_features, patch_location)) - if overwrite_permission is OverwritePermission.OVERWRITE_PATCH and patch_exists: - _remove_old_eopatch(filesystem, patch_location) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=EODeprecationWarning) + if overwrite_permission is OverwritePermission.OVERWRITE_PATCH and patch_exists: + _remove_old_eopatch(filesystem, patch_location) ftype_folders = {fs.path.dirname(path) for _, _, path in data_for_saving} for folder in ftype_folders: @@ -104,8 +117,10 @@ def save_eopatch( save_function = partial(_save_single_feature, filesystem=filesystem, compress_level=compress_level) list(executor.map(save_function, data_for_saving)) # Wrapped in a list to get better exceptions - if overwrite_permission is not OverwritePermission.OVERWRITE_PATCH: - remove_redundant_files(filesystem, eopatch_features, file_information, compress_level) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=EODeprecationWarning) + if overwrite_permission is not OverwritePermission.OVERWRITE_PATCH: + remove_redundant_files(filesystem, eopatch_features, file_information, compress_level) def _remove_old_eopatch(filesystem: FS, patch_location: str) -> None: @@ -114,8 +129,8 @@ def _remove_old_eopatch(filesystem: FS, patch_location: str) -> None: def _yield_features_to_save( - eopatch: EOPatch, eopatch_features: List[FeatureSpec], patch_location: str -) -> Iterator[Tuple[Type[FeatureIO], Any, str]]: + eopatch: EOPatch, eopatch_features: list[FeatureSpec], patch_location: str +) -> Iterator[tuple[type[FeatureIO], Any, str]]: """Prepares a triple `(featureIO, data, path)` so that the `featureIO` can save 
`data` to `path`.""" get_file_path = partial(fs.path.join, patch_location) meta_features = {ftype for ftype, _ in eopatch_features if ftype.is_meta()} @@ -134,20 +149,20 @@ def _yield_features_to_save( yield (_get_feature_io_constructor(ftype), eopatch[(ftype, fname)], get_file_path(ftype.value, fname)) -def _save_single_feature(save_spec: Tuple[Type[FeatureIO[T]], T, str], *, filesystem: FS, compress_level: int) -> None: +def _save_single_feature(save_spec: tuple[type[FeatureIO[T]], T, str], *, filesystem: FS, compress_level: int) -> None: feature_io, data, feature_path = save_spec feature_io.save(data, filesystem, feature_path, compress_level) def remove_redundant_files( filesystem: FS, - eopatch_features: List[FeatureSpec], + eopatch_features: list[FeatureSpec], preexisting_files: FilesystemDataInfo, current_compress_level: int, ) -> None: """Removes files that should have been overwritten but were not due to different compression levels.""" - def has_different_compression(path: Optional[str]) -> bool: + def has_different_compression(path: str | None) -> bool: return path is not None and MimeType.GZIP.matches_extension(path) != (current_compress_level > 0) files_to_remove = [] @@ -180,7 +195,7 @@ def load_eopatch_content( file_information = get_filesystem_data_info(filesystem, patch_location, features) bbox, timestamps, meta_info = _load_meta_features(filesystem, file_information, features) - features_dict: Dict[Tuple[FeatureType, str], FeatureIO] = {} + features_dict: dict[tuple[FeatureType, str], FeatureIO] = {} for ftype, fname in FeatureParser(features).get_feature_specifications(): if ftype.is_meta(): continue @@ -199,7 +214,7 @@ def load_eopatch_content( def _load_meta_features( filesystem: FS, file_information: FilesystemDataInfo, features: FeaturesSpecification -) -> Tuple[Optional[FeatureIOBBox], Optional[FeatureIOTimestamps], Optional[FeatureIOJson]]: +) -> tuple[FeatureIOBBox | None, FeatureIOTimestamps | None, FeatureIOJson | None]: requested = {ftype for ftype, _ in FeatureParser(features).get_feature_specifications() if ftype.is_meta()} err_msg = "Feature {} is specified to be loaded but does not exist in EOPatch." @@ -278,7 +293,7 @@ def get_filesystem_data_info( @deprecated_function(category=EODeprecationWarning) def walk_filesystem( filesystem: FS, patch_location: str, features: FeaturesSpecification = ... -) -> Iterator[Tuple[FeatureType, Union[str, EllipsisType], str]]: +) -> Iterator[tuple[FeatureType, str | EllipsisType, str]]: """Interface to the old walk_filesystem function which yields tuples of (feature_type, feature_name, file_path).""" file_information = get_filesystem_data_info(filesystem, patch_location, features) @@ -295,7 +310,7 @@ def walk_filesystem( yield (*feature, path) -def walk_feature_type_folder(filesystem: FS, folder_path: str) -> Iterator[Tuple[str, str]]: +def walk_feature_type_folder(filesystem: FS, folder_path: str) -> Iterator[tuple[str, str]]: """Walks a feature type subfolder of EOPatch and yields tuples (feature name, path in filesystem). Skips folders and files in subfolders. 
""" @@ -305,7 +320,7 @@ def walk_feature_type_folder(filesystem: FS, folder_path: str) -> Iterator[Tuple def _check_collisions( - overwrite_permission: OverwritePermission, eopatch_features: List[FeatureSpec], existing_files: FilesystemDataInfo + overwrite_permission: OverwritePermission, eopatch_features: list[FeatureSpec], existing_files: FilesystemDataInfo ) -> None: """Checks for possible name collisions to avoid unintentional overwriting.""" if overwrite_permission is OverwritePermission.ADD_ONLY: @@ -319,7 +334,7 @@ def _check_collisions( _check_letter_case_collisions(eopatch_features, FilesystemDataInfo()) -def _check_add_only_permission(eopatch_features: List[FeatureSpec], filesystem_features: FilesystemDataInfo) -> None: +def _check_add_only_permission(eopatch_features: list[FeatureSpec], filesystem_features: FilesystemDataInfo) -> None: """Checks that no existing feature will be overwritten.""" unique_filesystem_features = {_to_lowercase(*feature) for feature, _ in filesystem_features.iterate_features()} unique_eopatch_features = {_to_lowercase(*feature) for feature in eopatch_features} @@ -329,7 +344,7 @@ def _check_add_only_permission(eopatch_features: List[FeatureSpec], filesystem_f raise ValueError(f"Cannot save features {intersection} with overwrite_permission=OverwritePermission.ADD_ONLY") -def _check_letter_case_collisions(eopatch_features: List[FeatureSpec], filesystem_features: FilesystemDataInfo) -> None: +def _check_letter_case_collisions(eopatch_features: list[FeatureSpec], filesystem_features: FilesystemDataInfo) -> None: """Check that features have no name clashes (ignoring case) with other EOPatch features and saved features.""" lowercase_features = {_to_lowercase(*feature) for feature in eopatch_features} @@ -344,7 +359,7 @@ def _check_letter_case_collisions(eopatch_features: List[FeatureSpec], filesyste ) -def _to_lowercase(ftype: FeatureType, fname: Optional[str], *_: Any) -> Tuple[FeatureType, Optional[str]]: +def _to_lowercase(ftype: FeatureType, fname: str | None, *_: Any) -> tuple[FeatureType, str | None]: """Transforms a feature to it's lowercase representation.""" return ftype, fname if fname is None else fname.lower() @@ -372,7 +387,7 @@ def __init__(self, path: str, filesystem: FS): self.path = path self.filesystem = filesystem - self.loaded_value: Optional[T] = None + self.loaded_value: T | None = None @classmethod @abstractmethod @@ -399,7 +414,7 @@ def load(self) -> T: return self.loaded_value @abstractmethod - def _read_from_file(self, file: Union[BinaryIO, gzip.GzipFile]) -> T: + def _read_from_file(self, file: BinaryIO | gzip.GzipFile) -> T: """Loads from a file and decodes content.""" @classmethod @@ -436,7 +451,7 @@ def _save(cls, data: T, filesystem: FS, path: str, compress_level: int) -> None: @classmethod @abstractmethod - def _write_to_file(cls, data: T, file: Union[BinaryIO, gzip.GzipFile], path: str) -> None: + def _write_to_file(cls, data: T, file: BinaryIO | gzip.GzipFile, path: str) -> None: """Writes data to a file in the appropriate way.""" @@ -447,11 +462,11 @@ class FeatureIONumpy(FeatureIO[np.ndarray]): def get_file_format(cls) -> MimeType: return MimeType.NPY - def _read_from_file(self, file: Union[BinaryIO, gzip.GzipFile]) -> np.ndarray: + def _read_from_file(self, file: BinaryIO | gzip.GzipFile) -> np.ndarray: return np.load(file, allow_pickle=True) @classmethod - def _write_to_file(cls, data: np.ndarray, file: Union[BinaryIO, gzip.GzipFile], _: str) -> None: + def _write_to_file(cls, data: np.ndarray, file: BinaryIO | 
gzip.GzipFile, _: str) -> None: return np.save(file, data) @@ -462,7 +477,7 @@ class FeatureIOGeoDf(FeatureIO[gpd.GeoDataFrame]): def get_file_format(cls) -> MimeType: return MimeType.GPKG - def _read_from_file(self, file: Union[BinaryIO, gzip.GzipFile]) -> gpd.GeoDataFrame: + def _read_from_file(self, file: BinaryIO | gzip.GzipFile) -> gpd.GeoDataFrame: dataframe = gpd.read_file(file) if dataframe.crs is not None: @@ -477,22 +492,15 @@ def _read_from_file(self, file: Union[BinaryIO, gzip.GzipFile]) -> gpd.GeoDataFr return dataframe @classmethod - def _write_to_file(cls, data: gpd.GeoDataFrame, file: Union[BinaryIO, gzip.GzipFile], path: str) -> None: + def _write_to_file(cls, data: gpd.GeoDataFrame, file: BinaryIO | gzip.GzipFile, path: str) -> None: layer = fs.path.basename(path) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message="You are attempting to write an empty DataFrame to file*", - category=UserWarning, - ) - return data.to_file(file, driver="GPKG", encoding="utf-8", layer=layer, index=False) - except ValueError as err: - # This workaround is only required for geopandas<0.11.0 and will be removed in the future. - if data.empty: - schema = infer_schema(data) - return data.to_file(file, driver="GPKG", encoding="utf-8", layer=layer, schema=schema) - raise err + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="You are attempting to write an empty DataFrame to file*", + category=UserWarning, + ) + return data.to_file(file, driver="GPKG", encoding="utf-8", layer=layer, index=False) class FeatureIOJson(FeatureIO[T]): @@ -502,13 +510,13 @@ class FeatureIOJson(FeatureIO[T]): def get_file_format(cls) -> MimeType: return MimeType.JSON - def _read_from_file(self, file: Union[BinaryIO, gzip.GzipFile]) -> T: + def _read_from_file(self, file: BinaryIO | gzip.GzipFile) -> T: return json.load(file) @classmethod - def _write_to_file(cls, data: T, file: Union[BinaryIO, gzip.GzipFile], path: str) -> None: + def _write_to_file(cls, data: T, file: BinaryIO | gzip.GzipFile, path: str) -> None: try: - json_data = json.dumps(data, indent=2, default=_jsonify_timestamp) + json_data = json.dumps(data, indent=2, default=_better_jsonify) except TypeError as exception: raise TypeError( f"Failed to serialize when saving JSON file to {path}. 
Make sure that this feature type " @@ -521,7 +529,7 @@ def _write_to_file(cls, data: T, file: Union[BinaryIO, gzip.GzipFile], path: str class FeatureIOTimestamps(FeatureIOJson[List[datetime.datetime]]): """FeatureIOJson object specialized for List[dt.datetime].""" - def _read_from_file(self, file: Union[BinaryIO, gzip.GzipFile]) -> List[datetime.datetime]: + def _read_from_file(self, file: BinaryIO | gzip.GzipFile) -> list[datetime.datetime]: data = json.load(file) return [dateutil.parser.parse(timestamp) for timestamp in data] @@ -533,24 +541,26 @@ class FeatureIOBBox(FeatureIO[BBox]): def get_file_format(cls) -> MimeType: return MimeType.GEOJSON - def _read_from_file(self, file: Union[BinaryIO, gzip.GzipFile]) -> BBox: + def _read_from_file(self, file: BinaryIO | gzip.GzipFile) -> BBox: json_data = json.load(file) return Geometry.from_geojson(json_data).bbox @classmethod - def _write_to_file(cls, data: BBox, file: Union[BinaryIO, gzip.GzipFile], _: str) -> None: + def _write_to_file(cls, data: BBox, file: BinaryIO | gzip.GzipFile, _: str) -> None: json_data = json.dumps(data.geojson, indent=2) file.write(json_data.encode()) -def _jsonify_timestamp(param: object) -> str: - """Adds the option to serialize datetime.date objects via isoformat.""" +def _better_jsonify(param: object) -> Any: + """Adds the option to serialize datetime.date objects via isoformat and FeatureDict objects by casting them to dictionaries.""" if isinstance(param, datetime.date): return param.isoformat() + if isinstance(param, Mapping): + return dict(param.items()) raise TypeError(f"Object of type {type(param)} is not yet supported in jsonify utility function") -def _get_feature_io_constructor(ftype: FeatureType) -> Type[FeatureIO]: +def _get_feature_io_constructor(ftype: FeatureType) -> type[FeatureIO]: """Creates the correct FeatureIO, corresponding to the FeatureType.""" if ftype is FeatureType.BBOX: return FeatureIOBBox diff --git a/core/eolearn/core/eodata_merge.py b/core/eolearn/core/eodata_merge.py index f744162cc..e536e6485 100644 --- a/core/eolearn/core/eodata_merge.py +++ b/core/eolearn/core/eodata_merge.py @@ -12,23 +12,20 @@ import functools import itertools as it import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Callable, Literal, Sequence, Union, cast import numpy as np import pandas as pd from geopandas import GeoDataFrame -from typing_extensions import Literal from sentinelhub import BBox from .constants import FeatureType +from .eodata import EOPatch from .exceptions import EORuntimeWarning from .types import FeatureSpec, FeaturesSpecification from .utils.parsing import FeatureParser -if TYPE_CHECKING: - from .eodata import EOPatch - OperationInputType = Union[Literal[None, "concatenate", "min", "max", "mean", "median"], Callable] @@ -37,76 +34,69 @@ def merge_eopatches( features: FeaturesSpecification = ..., time_dependent_op: OperationInputType = None, timeless_op: OperationInputType = None, -) -> Dict[FeatureSpec, Any]: +) -> EOPatch: """Merge features of given EOPatches into a new EOPatch. - :param eopatches: Any number of EOPatches to be merged together + :param eopatches: Any number of EOPatches to be merged together. :param features: A collection of features to be merged together. By default, all features will be merged. - :param time_dependent_op: An operation to be used to join data for any time-dependent raster feature. Before - joining time slices of all arrays will be sorted.
Supported options are: + :param time_dependent_op: An operation for joining data for time-dependent raster features. Before joining time + slices of all arrays will be sorted. Supported options are: - - None (default): If time slices with matching timestamps have the same values, take one. Raise an error - otherwise. + - None: If time slices with matching timestamps have the same values, take one. Raise an error otherwise. - 'concatenate': Keep all time slices, even the ones with matching timestamps - 'min': Join time slices with matching timestamps by taking minimum values. Ignore NaN values. - 'max': Join time slices with matching timestamps by taking maximum values. Ignore NaN values. - 'mean': Join time slices with matching timestamps by taking mean values. Ignore NaN values. - 'median': Join time slices with matching timestamps by taking median values. Ignore NaN values. + :param timeless_op: An operation for joining data for timeless raster features. Supported options are: - :param timeless_op: An operation to be used to join data for any timeless raster feature. Supported options - are: - - - None (default): If arrays are the same, take one. Raise an error otherwise. + - None: If arrays are the same, take one. Raise an error otherwise. - 'concatenate': Join arrays over the last (i.e. bands) dimension - 'min': Join arrays by taking minimum values. Ignore NaN values. - 'max': Join arrays by taking maximum values. Ignore NaN values. - 'mean': Join arrays by taking mean values. Ignore NaN values. - 'median': Join arrays by taking median values. Ignore NaN values. - - :return: Contents of a merged EOPatch + :return: A merged EOPatch """ + reduce_timestamps = time_dependent_op != "concatenate" time_dependent_operation = _parse_operation(time_dependent_op, is_timeless=False) timeless_operation = _parse_operation(timeless_op, is_timeless=True) feature_parser = FeatureParser(features) all_features = {feature for eopatch in eopatches for feature in feature_parser.get_features(eopatch)} - eopatch_content: Dict[FeatureSpec, object] = {} timestamps, order_mask_per_eopatch = _merge_timestamps(eopatches, reduce_timestamps) optimize_raster_temporal = _check_if_optimize(eopatches, time_dependent_op) + merged_eopatch = EOPatch(bbox=_get_common_bbox(eopatches), timestamps=timestamps) + for feature in all_features: feature_type, feature_name = feature if feature_type.is_array(): if feature_type.is_temporal(): - eopatch_content[feature] = _merge_time_dependent_raster_feature( + merged_eopatch[feature] = _merge_time_dependent_raster_feature( eopatches, feature, time_dependent_operation, order_mask_per_eopatch, optimize_raster_temporal ) else: - eopatch_content[feature] = _merge_timeless_raster_feature(eopatches, feature, timeless_operation) + merged_eopatch[feature] = _merge_timeless_raster_feature(eopatches, feature, timeless_operation) if feature_type.is_vector(): - eopatch_content[feature] = _merge_vector_feature(eopatches, feature) - - if feature_type is FeatureType.TIMESTAMPS: - eopatch_content[feature] = timestamps + merged_eopatch[feature] = _merge_vector_feature(eopatches, feature) if feature_type is FeatureType.META_INFO: feature_name = cast(str, feature_name) # parser makes sure of it - eopatch_content[feature] = _select_meta_info_feature(eopatches, feature_name) - - eopatch_content[(FeatureType.BBOX, None)] = _get_common_bbox(eopatches) + merged_eopatch[feature] = _select_meta_info_feature(eopatches, feature_name) - return eopatch_content + return merged_eopatch def 
_parse_operation(operation_input: OperationInputType, is_timeless: bool) -> Callable: """Transforms operation's instruction (i.e. an input string) into a function that can be applied to a list of arrays. If the input already is a function it returns it. """ - defaults: Dict[Optional[str], Callable] = { + defaults: dict[str | None, Callable] = { None: _return_if_equal_operation, "concatenate": functools.partial(np.concatenate, axis=-1 if is_timeless else 0), "mean": functools.partial(np.nanmean, axis=0), @@ -131,7 +121,7 @@ def _return_if_equal_operation(arrays: np.ndarray) -> bool: def _merge_timestamps( eopatches: Sequence[EOPatch], reduce_timestamps: bool -) -> Tuple[List[dt.datetime], List[np.ndarray]]: +) -> tuple[list[dt.datetime], list[np.ndarray]]: """Merges together timestamps from EOPatches. It also prepares a list of masks, one for each EOPatch, describing how timestamps should be ordered and joined together. """ @@ -215,7 +205,7 @@ def _extract_and_join_time_dependent_feature_values( feature: FeatureSpec, order_mask_per_eopatch: Sequence[np.ndarray], optimize: bool, -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """Collects feature arrays from EOPatches that have them and joins them together. It also joins together corresponding order masks. """ @@ -297,7 +287,7 @@ def _select_meta_info_feature(eopatches: Sequence[EOPatch], feature_name: str) - return values[0] -def _get_common_bbox(eopatches: Sequence[EOPatch]) -> Optional[BBox]: +def _get_common_bbox(eopatches: Sequence[EOPatch]) -> BBox | None: """Makes sure that all EOPatches, which define a bounding box and CRS, define the same ones.""" bboxes = [eopatch.bbox for eopatch in eopatches if eopatch.bbox is not None] @@ -309,13 +299,13 @@ def _get_common_bbox(eopatches: Sequence[EOPatch]) -> Optional[BBox]: raise ValueError("Cannot merge EOPatches because they are defined for different bounding boxes.") -def _extract_feature_values(eopatches: Sequence[EOPatch], feature: FeatureSpec) -> List[Any]: +def _extract_feature_values(eopatches: Sequence[EOPatch], feature: FeatureSpec) -> list[Any]: """A helper function that extracts the values of a feature from those EOPatches where the feature exists.""" feature_type, feature_name = feature return [eopatch[feature] for eopatch in eopatches if feature_name in eopatch[feature_type]] -def _all_equal(values: Union[Sequence[Any], np.ndarray]) -> bool: +def _all_equal(values: Sequence[Any] | np.ndarray) -> bool: """A helper function that checks if all values in a given list are equal to each other.""" first_value = values[0] diff --git a/core/eolearn/core/eoexecution.py b/core/eolearn/core/eoexecution.py index 9e8b1ee39..2636c75b0 100644 --- a/core/eolearn/core/eoexecution.py +++ b/core/eolearn/core/eoexecution.py @@ -10,6 +10,8 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree.
""" +from __future__ import annotations + import concurrent.futures import datetime as dt import inspect @@ -18,11 +20,10 @@ import warnings from dataclasses import dataclass from logging import FileHandler, Filter, Handler, Logger -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Protocol, Sequence, Union import fs from fs.base import FS -from typing_extensions import Protocol from .eonode import EONode from .eoworkflow import EOWorkflow, WorkflowResults @@ -49,11 +50,11 @@ class _ProcessingData: serializable with pickle.""" workflow: EOWorkflow - workflow_kwargs: Dict[EONode, Dict[str, object]] + workflow_kwargs: dict[EONode, dict[str, object]] pickled_filesystem: bytes - log_path: Optional[str] + log_path: str | None filter_logs_by_thread: bool - logs_filter: Optional[Filter] + logs_filter: Filter | None logs_handler_factory: _HandlerFactoryType @@ -61,9 +62,9 @@ class _ProcessingData: class _ExecutionRunParams: """Parameters that are used during execution run.""" - workers: Optional[int] + workers: int | None multiprocess: bool - tqdm_kwargs: Dict[str, Any] + tqdm_kwargs: dict[str, Any] class EOExecutor: @@ -78,13 +79,13 @@ class EOExecutor: def __init__( self, workflow: EOWorkflow, - execution_kwargs: Sequence[Dict[EONode, Dict[str, object]]], + execution_kwargs: Sequence[dict[EONode, dict[str, object]]], *, - execution_names: Optional[List[str]] = None, + execution_names: list[str] | None = None, save_logs: bool = False, logs_folder: str = ".", - filesystem: Optional[FS] = None, - logs_filter: Optional[Filter] = None, + filesystem: FS | None = None, + logs_filter: Filter | None = None, logs_handler_factory: _HandlerFactoryType = FileHandler, ): """ @@ -117,15 +118,15 @@ def __init__( self.logs_filter = logs_filter self.logs_handler_factory = logs_handler_factory - self.start_time: Optional[dt.datetime] = None - self.report_folder: Optional[str] = None - self.general_stats: Dict[str, object] = {} - self.execution_results: List[WorkflowResults] = [] + self.start_time: dt.datetime | None = None + self.report_folder: str | None = None + self.general_stats: dict[str, object] = {} + self.execution_results: list[WorkflowResults] = [] @staticmethod def _parse_and_validate_execution_kwargs( - execution_kwargs: Sequence[Dict[EONode, Dict[str, object]]] - ) -> List[Dict[EONode, Dict[str, object]]]: + execution_kwargs: Sequence[dict[EONode, dict[str, object]]] + ) -> list[dict[EONode, dict[str, object]]]: """Parses and validates execution arguments provided by user and raises an error if something is wrong.""" if not isinstance(execution_kwargs, (list, tuple)): raise ValueError("Parameter 'execution_kwargs' should be a list.") @@ -136,7 +137,7 @@ def _parse_and_validate_execution_kwargs( return [input_kwargs or {} for input_kwargs in execution_kwargs] @staticmethod - def _parse_execution_names(execution_names: Optional[List[str]], execution_kwargs: Sequence) -> List[str]: + def _parse_execution_names(execution_names: list[str] | None, execution_kwargs: Sequence) -> list[str]: """Parses a list of execution names.""" if execution_names is None: return [str(num) for num in range(1, len(execution_kwargs) + 1)] @@ -148,13 +149,13 @@ def _parse_execution_names(execution_names: Optional[List[str]], execution_kwarg return execution_names @staticmethod - def _parse_logs_filesystem(filesystem: Optional[FS], logs_folder: str) -> Tuple[FS, str]: + def _parse_logs_filesystem(filesystem: FS | None, logs_folder: str) -> tuple[FS, str]: """Ensures 
a filesystem and a file path relative to it.""" if filesystem is None: return get_base_filesystem_and_path(logs_folder) return filesystem, logs_folder - def run(self, workers: Optional[int] = 1, multiprocess: bool = True, **tqdm_kwargs: Any) -> List[WorkflowResults]: + def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs: Any) -> list[WorkflowResults]: """Runs the executor with n workers. :param workers: Maximum number of workflows which will be executed in parallel. Default value is `1` which will @@ -177,7 +178,7 @@ def run(self, workers: Optional[int] = 1, multiprocess: bool = True, **tqdm_kwar if self.save_logs: self.filesystem.makedirs(self.report_folder, recreate=True) - log_paths: Sequence[Optional[str]] + log_paths: Sequence[str | None] if self.save_logs: log_paths = self.get_log_paths(full_path=False) else: @@ -207,8 +208,8 @@ def run(self, workers: Optional[int] = 1, multiprocess: bool = True, **tqdm_kwar @classmethod def _run_execution( - cls, processing_args: List[_ProcessingData], run_params: _ExecutionRunParams - ) -> List[WorkflowResults]: + cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams + ) -> list[WorkflowResults]: """Parallelizes the execution for each item of processing_args list.""" return parallelize( cls._execute_workflow, @@ -221,12 +222,12 @@ def _run_execution( @classmethod def _try_add_logging( cls, - log_path: Optional[str], + log_path: str | None, pickled_filesystem: bytes, filter_logs_by_thread: bool, - logs_filter: Optional[Filter], + logs_filter: Filter | None, logs_handler_factory: _HandlerFactoryType, - ) -> Tuple[Optional[Logger], Optional[Handler]]: + ) -> tuple[Logger | None, Handler | None]: """Adds a handler to a logger and returns them both. In case this fails it shows a warning.""" if log_path: try: @@ -238,19 +239,19 @@ def _try_add_logging( logger.addHandler(handler) return logger, handler except BaseException as exception: - warnings.warn(f"Failed to start logging with exception: {repr(exception)}", category=EORuntimeWarning) + warnings.warn(f"Failed to start logging with exception: {exception!r}", category=EORuntimeWarning) return None, None @classmethod - def _try_remove_logging(cls, log_path: Optional[str], logger: Optional[Logger], handler: Optional[Handler]) -> None: + def _try_remove_logging(cls, log_path: str | None, logger: Logger | None, handler: Handler | None) -> None: """Removes a handler from a logger in case that handler exists.""" if log_path and logger and handler: try: handler.close() logger.removeHandler(handler) except BaseException as exception: - warnings.warn(f"Failed to end logging with exception: {repr(exception)}", category=EORuntimeWarning) + warnings.warn(f"Failed to end logging with exception: {exception!r}", category=EORuntimeWarning) @classmethod def _execute_workflow(cls, data: _ProcessingData) -> WorkflowResults: @@ -273,7 +274,7 @@ def _build_log_handler( log_path: str, pickled_filesystem: bytes, filter_logs_by_thread: bool, - logs_filter: Optional[Filter], + logs_filter: Filter | None, logs_handler_factory: _HandlerFactoryType, ) -> Handler: """Provides object which handles logs.""" @@ -299,11 +300,11 @@ def _build_log_handler( return handler @staticmethod - def _get_processing_type(workers: Optional[int], multiprocess: bool) -> _ProcessingType: + def _get_processing_type(workers: int | None, multiprocess: bool) -> _ProcessingType: """Provides a type of processing according to given parameters.""" return _decide_processing_type(workers=workers, 
multiprocess=multiprocess) - def _prepare_general_stats(self, workers: Optional[int], processing_type: _ProcessingType) -> Dict[str, object]: + def _prepare_general_stats(self, workers: int | None, processing_type: _ProcessingType) -> dict[str, object]: """Prepares a dictionary with general statistics about executions.""" failed_count = sum(results.workflow_failed() for results in self.execution_results) return { @@ -315,7 +316,7 @@ def _prepare_general_stats(self, workers: Optional[int], processing_type: _Proce "workers": workers, } - def get_successful_executions(self) -> List[int]: + def get_successful_executions(self) -> list[int]: """Returns a list of IDs of successful executions. The IDs are integers from interval `[0, len(execution_kwargs) - 1]`, sorted in increasing order. @@ -323,7 +324,7 @@ def get_successful_executions(self) -> List[int]: """ return [idx for idx, results in enumerate(self.execution_results) if not results.workflow_failed()] - def get_failed_executions(self) -> List[int]: + def get_failed_executions(self) -> list[int]: """Returns a list of IDs of failed executions. The IDs are integers from interval `[0, len(execution_kwargs) - 1]`, sorted in increasing order. @@ -362,7 +363,7 @@ def make_report(self, include_logs: bool = True) -> None: return EOExecutorVisualization(self).make_report(include_logs=include_logs) - def get_log_paths(self, full_path: bool = True) -> List[str]: + def get_log_paths(self, full_path: bool = True) -> list[str]: """Returns a list of file paths containing logs. :param full_path: A flag to specify if it should return full absolute paths or paths relative to the @@ -376,7 +377,7 @@ def get_log_paths(self, full_path: bool = True) -> List[str]: return [get_full_path(self.filesystem, path) for path in log_paths] return log_paths - def read_logs(self) -> List[Optional[str]]: + def read_logs(self) -> list[str | None]: """Loads the content of log files if logs have been saved.""" if not self.save_logs: return [None] * len(self.execution_kwargs) @@ -391,5 +392,5 @@ def _read_log_file(self, log_path: str) -> str: with self.filesystem.open(log_path, "r") as file_handle: return file_handle.read() except BaseException as exception: - warnings.warn(f"Failed to load logs with exception: {repr(exception)}", category=EORuntimeWarning) + warnings.warn(f"Failed to load logs with exception: {exception!r}", category=EORuntimeWarning) return "Failed to load logs" diff --git a/core/eolearn/core/eotask.py b/core/eolearn/core/eotask.py index 839b80cf5..24023b2e2 100644 --- a/core/eolearn/core/eotask.py +++ b/core/eolearn/core/eotask.py @@ -17,7 +17,7 @@ import logging from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Iterable, TypeVar from sentinelhub.exceptions import deprecated_function @@ -46,11 +46,11 @@ class EOTask(metaclass=ABCMeta): deprecated_function(EODeprecationWarning, PARSE_RENAMED_DEPRECATE_MSG)(parse_renamed_features) ) - def __new__(cls: Type[Self], *args: Any, **kwargs: Any) -> Self: + def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: """Stores initialization parameters and their order in the instance attribute `init_args`.""" self = super().__new__(cls) # type: ignore[misc] - init_args: Dict[str, object] = {} + init_args: dict[str, object] = {} for arg, value in zip(inspect.getfullargspec(self.__init__).args[1 : len(args) + 1], args): init_args[arg] = repr(value) for arg in
inspect.getfullargspec(self.__init__).args[len(args) + 1 :]: @@ -82,25 +82,25 @@ def execute(self, *eopatches, **kwargs): # type: ignore[no-untyped-def] # must @staticmethod def parse_feature( feature: SingleFeatureSpec, - eopatch: Optional[EOPatch] = None, - allowed_feature_types: Union[EllipsisType, Iterable[FeatureType], Callable[[FeatureType], bool]] = ..., - ) -> Tuple[FeatureType, Optional[str]]: + eopatch: EOPatch | None = None, + allowed_feature_types: EllipsisType | Iterable[FeatureType] | Callable[[FeatureType], bool] = ..., + ) -> tuple[FeatureType, str | None]: """See `eolearn.core.utils.parse_feature`.""" return parse_feature(feature, eopatch, allowed_feature_types) @staticmethod def parse_features( features: FeaturesSpecification, - eopatch: Optional[EOPatch] = None, - allowed_feature_types: Union[EllipsisType, Iterable[FeatureType], Callable[[FeatureType], bool]] = ..., - ) -> List[FeatureSpec]: + eopatch: EOPatch | None = None, + allowed_feature_types: EllipsisType | Iterable[FeatureType] | Callable[[FeatureType], bool] = ..., + ) -> list[FeatureSpec]: """See `eolearn.core.utils.parse_features`.""" return parse_features(features, eopatch, allowed_feature_types) @staticmethod def get_feature_parser( features: FeaturesSpecification, - allowed_feature_types: Union[EllipsisType, Iterable[FeatureType], Callable[[FeatureType], bool]] = ..., + allowed_feature_types: EllipsisType | Iterable[FeatureType] | Callable[[FeatureType], bool] = ..., ) -> FeatureParser: """See :class:`FeatureParser`.""" return FeatureParser(features, allowed_feature_types=allowed_feature_types) @@ -113,4 +113,4 @@ class _PrivateTaskConfig: :param init_args: A dictionary of parameters and values used for EOTask initialization """ - init_args: Dict[str, object] + init_args: dict[str, object] diff --git a/core/eolearn/core/eoworkflow.py b/core/eolearn/core/eoworkflow.py index 4c41a8dca..a432108e7 100644 --- a/core/eolearn/core/eoworkflow.py +++ b/core/eolearn/core/eoworkflow.py @@ -17,11 +17,13 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ +from __future__ import annotations + import datetime as dt import logging import traceback from dataclasses import dataclass, field, fields -from typing import Dict, List, Optional, Sequence, Set, Tuple, cast +from typing import Sequence, Tuple, cast from .eodata import EOPatch from .eonode import EONode, NodeStats @@ -81,7 +83,7 @@ def _parse_and_validate_nodes(nodes: Sequence[EONode]) -> Sequence[EONode]: return nodes @staticmethod - def _make_uid_dict(nodes: Sequence[EONode]) -> Dict[str, EONode]: + def _make_uid_dict(nodes: Sequence[EONode]) -> dict[str, EONode]: """Creates a dictionary mapping node IDs to nodes while checking uniqueness of tasks. 
:param nodes: The sequence of workflow nodes defining the computational graph @@ -116,17 +118,17 @@ def _create_dag(self, nodes: Sequence[EONode]) -> DirectedGraph[str]: return dag @classmethod - def from_endnodes(cls, *endnodes: EONode) -> "EOWorkflow": + def from_endnodes(cls, *endnodes: EONode) -> EOWorkflow: """Constructs the EOWorkflow from the end-nodes by recursively extracting nodes in the workflow structure.""" - all_nodes: Set[EONode] = set() - memo: Dict[EONode, Set[EONode]] = {} + all_nodes: set[EONode] = set() + memo: dict[EONode, set[EONode]] = {} for endnode in endnodes: all_nodes = all_nodes.union(endnode.get_dependencies(_memo=memo)) return cls(list(all_nodes)) def execute( - self, input_kwargs: Optional[Dict[EONode, Dict[str, object]]] = None, raise_errors: bool = True - ) -> "WorkflowResults": + self, input_kwargs: dict[EONode, dict[str, object]] | None = None, raise_errors: bool = True + ) -> WorkflowResults: """Executes the workflow. :param input_kwargs: External input arguments to the workflow. They have to be in a form of a dictionary where @@ -140,7 +142,7 @@ def execute( """ start_time = dt.datetime.now() - out_degrees: Dict[str, int] = self.uid_dag.get_outdegrees() + out_degrees: dict[str, int] = self.uid_dag.get_outdegrees() input_kwargs = input_kwargs or {} self.validate_input_kwargs(input_kwargs) @@ -160,7 +162,7 @@ def execute( return results @staticmethod - def validate_input_kwargs(input_kwargs: Dict[EONode, Dict[str, object]]) -> None: + def validate_input_kwargs(input_kwargs: dict[EONode, dict[str, object]]) -> None: """Validates EOWorkflow input arguments provided by user and raises an error if something is wrong. :param input_kwargs: A dictionary mapping tasks to task execution arguments @@ -185,8 +187,8 @@ def validate_input_kwargs(input_kwargs: Dict[EONode, Dict[str, object]]) -> None ) def _execute_nodes( - self, *, uid_input_kwargs: Dict[str, Dict[str, object]], out_degrees: Dict[str, int], raise_errors: bool - ) -> Tuple[dict, dict]: + self, *, uid_input_kwargs: dict[str, dict[str, object]], out_degrees: dict[str, int], raise_errors: bool + ) -> tuple[dict, dict]: """Executes workflow nodes in the predetermined order. :param uid_input_kwargs: External input arguments to the workflow. @@ -194,7 +196,7 @@ def _execute_nodes( of tasks that depend on this task.) :return: Results of a workflow """ - intermediate_results: Dict[str, object] = {} + intermediate_results: dict[str, object] = {} output_results = {} stats_dict = {} @@ -219,8 +221,8 @@ def _execute_nodes( return output_results, stats_dict def _execute_node( - self, *, node: EONode, node_input_values: List[object], node_input_kwargs: Dict[str, object], raise_errors: bool - ) -> Tuple[object, NodeStats]: + self, *, node: EONode, node_input_values: list[object], node_input_kwargs: dict[str, object], raise_errors: bool + ) -> tuple[object, NodeStats]: """Executes a node in the workflow by running its task and returning the results. :param node: A node of the workflow. 
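For orientation between hunks, here is a small usage sketch of the execution API retyped above. The node wiring and the value 42 are illustrative; `InputTask` and `OutputTask` are the tasks from `eoworkflow_tasks.py` further down in this diff, assumed here to be importable from `eolearn.core`:

from eolearn.core import EONode, EOWorkflow, InputTask, OutputTask

# Build a two-node graph: an input node feeding an output node.
input_node = EONode(InputTask())
output_node = EONode(OutputTask(name="result"), inputs=[input_node])
workflow = EOWorkflow.from_endnodes(output_node)

# `execute` takes a dict mapping nodes to their keyword arguments and
# returns a WorkflowResults object; outputs are keyed by OutputTask name.
results = workflow.execute({input_node: {"value": 42}})
assert results.outputs["result"] == 42
assert not results.workflow_failed()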
@@ -257,8 +259,8 @@ def _execute_node( @staticmethod def _execute_task( - task: EOTask, task_args: List[object], task_kwargs: Dict[str, object], raise_errors: bool - ) -> Tuple[object, bool]: + task: EOTask, task_args: list[object], task_kwargs: dict[str, object], raise_errors: bool + ) -> tuple[object, bool]: """Executes an EOTask and handles any potential exceptions.""" if raise_errors: return task.execute(*task_args, **task_kwargs), True @@ -273,7 +275,7 @@ def _execute_task( @staticmethod def _relax_dependencies( - *, node: EONode, out_degrees: Dict[str, int], intermediate_results: Dict[str, object] + *, node: EONode, out_degrees: dict[str, int], intermediate_results: dict[str, object] ) -> None: """Relaxes dependencies incurred by `node` after it has been successfully executed. All the nodes it depended on are updated. If `node` was the last remaining node depending on a node `n` then `n`'s result @@ -295,14 +297,14 @@ def _relax_dependencies( ) del intermediate_results[relevant_node.uid] - def get_nodes(self) -> List[EONode]: + def get_nodes(self) -> list[EONode]: """Returns an ordered list of all nodes within this workflow, ordered in the execution order. :return: List of all nodes within the workflow. The order of nodes is the same as the order of execution. """ return self._nodes[:] - def get_node_with_uid(self, uid: Optional[str], fail_if_missing: bool = False) -> Optional[EONode]: + def get_node_with_uid(self, uid: str | None, fail_if_missing: bool = False) -> EONode | None: """Returns the node with the given uid, if it exists in the workflow.""" if uid in self._uid_dict: return self._uid_dict[uid] @@ -318,7 +320,7 @@ def get_dot(self): # type: ignore[no-untyped-def] # cannot type without extra d visualization = self._get_visualization() return visualization.get_dot() - def dependency_graph(self, filename: Optional[str] = None): # type: ignore[no-untyped-def] # same as get_dot + def dependency_graph(self, filename: str | None = None): # type: ignore[no-untyped-def] # same as get_dot """Visualize the computational graph. :param filename: Filename of the output image together with file extension.
Supported formats: `png`, `jpg`, @@ -344,11 +346,11 @@ def _get_visualization(self): # type: ignore[no-untyped-def] # cannot type with class WorkflowResults: """An object containing results of an EOWorkflow execution.""" - outputs: Dict[str, object] + outputs: dict[str, object] start_time: dt.datetime end_time: dt.datetime - stats: Dict[str, NodeStats] - error_node_uid: Optional[str] = field(init=False, default=None) + stats: dict[str, NodeStats] + error_node_uid: str | None = field(init=False, default=None) def __post_init__(self) -> None: """Checks if there is any node that failed during the workflow execution.""" @@ -361,7 +363,7 @@ def workflow_failed(self) -> bool: """Informs if the EOWorkflow execution failed.""" return self.error_node_uid is not None - def drop_outputs(self) -> "WorkflowResults": + def drop_outputs(self) -> WorkflowResults: """Creates a new WorkflowResults object without outputs which can take a lot of memory.""" new_params = { param.name: {} if param.name == "outputs" else getattr(self, param.name) diff --git a/core/eolearn/core/eoworkflow_tasks.py b/core/eolearn/core/eoworkflow_tasks.py index 5e7436bb1..6472e9799 100644 --- a/core/eolearn/core/eoworkflow_tasks.py +++ b/core/eolearn/core/eoworkflow_tasks.py @@ -8,8 +8,6 @@ """ from __future__ import annotations -from typing import Optional - from .eodata import EOPatch from .eotask import EOTask from .types import FeaturesSpecification @@ -19,13 +17,13 @@ class InputTask(EOTask): """Introduces data into an EOWorkflow, where the data can be specified at initialization or at execution.""" - def __init__(self, value: Optional[object] = None): + def __init__(self, value: object | None = None): """ :param value: Default value that the task should provide as a result. Can be overridden in execution arguments """ self.value = value - def execute(self, *, value: Optional[object] = None) -> object: + def execute(self, *, value: object | None = None) -> object: """ :param value: A value that the task should provide as its result. If not set uses the value from initialization :return: Directly returns `value` @@ -36,7 +34,7 @@ def execute(self, *, value: Optional[object] = None) -> object: class OutputTask(EOTask): """Stores data as an output of `EOWorkflow` results.""" - def __init__(self, name: Optional[str] = None, features: FeaturesSpecification = ...): + def __init__(self, name: str | None = None, features: FeaturesSpecification = ...): """ :param name: A name under which the data will be saved in `WorkflowResults`, auto-generated if `None` :param features: A collection of features to be kept if the data is an `EOPatch` diff --git a/core/eolearn/core/extra/ray.py b/core/eolearn/core/extra/ray.py index af9016c17..2b641c3d0 100644 --- a/core/eolearn/core/extra/ray.py +++ b/core/eolearn/core/extra/ray.py @@ -8,7 +8,9 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. 
""" -from typing import Any, Callable, Collection, Generator, Iterable, List, Optional, Tuple, TypeVar, cast +from __future__ import annotations + +from typing import Any, Callable, Collection, Generator, Iterable, List, TypeVar, cast try: import ray @@ -27,7 +29,7 @@ class RayExecutor(EOExecutor): """A special type of `EOExecutor` that works with Ray framework""" - def run(self, **tqdm_kwargs: Any) -> List[WorkflowResults]: # type: ignore + def run(self, **tqdm_kwargs: Any) -> list[WorkflowResults]: # type: ignore[override] """Runs the executor using a Ray cluster Before calling this method make sure to initialize a Ray cluster using `ray.init`. @@ -43,8 +45,8 @@ def run(self, **tqdm_kwargs: Any) -> List[WorkflowResults]: # type: ignore @classmethod def _run_execution( - cls, processing_args: List[_ProcessingData], run_params: _ExecutionRunParams - ) -> List[WorkflowResults]: + cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams + ) -> list[WorkflowResults]: """Runs ray execution""" futures = [_ray_workflow_executor.remote(workflow_args) for workflow_args in processing_args] return join_ray_futures(futures, **run_params.tqdm_kwargs) @@ -64,7 +66,7 @@ def _ray_workflow_executor(workflow_args: _ProcessingData) -> WorkflowResults: def parallelize_with_ray( function: Callable[[_InputType], _OutputType], *params: Iterable[_InputType], **tqdm_kwargs: Any -) -> List[_OutputType]: +) -> list[_OutputType]: """Parallelizes function execution with Ray. Note that this function will automatically connect to a Ray cluster, if a connection wouldn't exist yet. But it @@ -83,7 +85,7 @@ def parallelize_with_ray( return join_ray_futures(futures, **tqdm_kwargs) -def join_ray_futures(futures: List[ray.ObjectRef], **tqdm_kwargs: Any) -> List[Any]: +def join_ray_futures(futures: list[ray.ObjectRef], **tqdm_kwargs: Any) -> list[Any]: """Resolves futures, monitors progress, and returns a list of results. :param futures: A list of futures to be joined. Note that this list will be reduced into an empty list as a side @@ -93,7 +95,7 @@ def join_ray_futures(futures: List[ray.ObjectRef], **tqdm_kwargs: Any) -> List[A :param tqdm_kwargs: Keyword arguments that will be propagated to `tqdm` progress bar. :return: A list of results in the order that corresponds with the order of the given input `futures`. """ - results: List[Optional[Any]] = [None] * len(futures) + results: list[Any | None] = [None] * len(futures) for position, result in join_ray_futures_iter(futures, **tqdm_kwargs): results[position] = result @@ -101,8 +103,8 @@ def join_ray_futures(futures: List[ray.ObjectRef], **tqdm_kwargs: Any) -> List[A def join_ray_futures_iter( - futures: List[ray.ObjectRef], update_interval: float = 0.5, **tqdm_kwargs: Any -) -> Generator[Tuple[int, Any], None, None]: + futures: list[ray.ObjectRef], update_interval: float = 0.5, **tqdm_kwargs: Any +) -> Generator[tuple[int, Any], None, None]: """Resolves futures, monitors progress, and serves as an iterator over results. :param futures: A list of futures to be joined. 
Note that this list will be reduced into an empty list as a side @@ -117,7 +119,7 @@ def join_ray_futures_iter( def _ray_wait_function( remaining_futures: Collection[ray.ObjectRef], - ) -> Tuple[Collection[ray.ObjectRef], Collection[ray.ObjectRef]]: + ) -> tuple[Collection[ray.ObjectRef], Collection[ray.ObjectRef]]: return ray.wait(remaining_futures, num_returns=len(remaining_futures), timeout=float(update_interval)) return _base_join_futures_iter(_ray_wait_function, ray.get, futures, **tqdm_kwargs) diff --git a/core/eolearn/core/graph.py b/core/eolearn/core/graph.py index c37b59f61..7dff93c8a 100644 --- a/core/eolearn/core/graph.py +++ b/core/eolearn/core/graph.py @@ -7,10 +7,11 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ +from __future__ import annotations import collections import copy -from typing import DefaultDict, Dict, Generic, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar +from typing import Generic, Iterator, Sequence, TypeVar _T = TypeVar("_T") @@ -27,7 +28,7 @@ class DirectedGraph(Generic[_T]): :param adjacency_dict: A dictionary mapping vertices to lists of neighbors """ - def __init__(self, adjacency_dict: Optional[Dict[_T, List[_T]]] = None): + def __init__(self, adjacency_dict: dict[_T, list[_T]] | None = None): self._adj_dict = ( collections.defaultdict(list, adjacency_dict) if adjacency_dict else collections.defaultdict(list) ) @@ -49,8 +50,8 @@ def __iter__(self) -> Iterator[_T]: """Returns iterator over the vertices of the graph.""" return iter(self._vertices) - def _make_indegrees_dict(self) -> DefaultDict[_T, int]: - indegrees: DefaultDict[_T, int] = collections.defaultdict(int) + def _make_indegrees_dict(self) -> collections.defaultdict[_T, int]: + indegrees: collections.defaultdict[_T, int] = collections.defaultdict(int) for u_vertex in self._adj_dict: for v_vertex in self._adj_dict[u_vertex]: @@ -58,7 +59,7 @@ def _make_indegrees_dict(self) -> DefaultDict[_T, int]: return indegrees - def get_indegrees(self) -> Dict[_T, int]: + def get_indegrees(self) -> dict[_T, int]: """Returns a dictionary containing in-degrees of vertices of the graph.""" return dict(self._indegrees) @@ -71,7 +72,7 @@ def get_indegree(self, vertex: _T) -> int: """ return self._indegrees[vertex] - def get_outdegrees(self) -> Dict[_T, int]: + def get_outdegrees(self) -> dict[_T, int]: """ :return: dictionary of out-degrees, see get_outdegree """ @@ -86,13 +87,13 @@ def get_outdegree(self, vertex: _T) -> int: """ return len(self._adj_dict[vertex]) - def get_adj_dict(self) -> Dict[_T, list]: + def get_adj_dict(self) -> dict[_T, list]: """ :return: adj_dict """ return {vertex: copy.copy(neighbours) for vertex, neighbours in self._adj_dict.items()} - def get_vertices(self) -> Set[_T]: + def get_vertices(self) -> set[_T]: """Returns the set of vertices of the graph.""" return set(self._vertices) @@ -160,12 +161,12 @@ def is_edge(self, u_vertex: _T, v_vertex: _T) -> bool: """True if `u_vertex -> v_vertex` is an edge of the graph. False otherwise.""" return v_vertex in self._adj_dict[u_vertex] - def get_neighbors(self, vertex: _T) -> List[_T]: + def get_neighbors(self, vertex: _T) -> list[_T]: """Returns the set of successor vertices of `vertex`.""" return copy.copy(self._adj_dict[vertex]) @staticmethod - def from_edges(edges: Sequence[Tuple[_T, _T]]) -> "DirectedGraph[_T]": + def from_edges(edges: Sequence[tuple[_T, _T]]) -> DirectedGraph[_T]: """Return DirectedGraph created from edges. 
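As a quick illustration of the `DirectedGraph` helpers retyped in this file, a sketch with made-up vertex names; the class is internal, so the import path `eolearn.core.graph` is assumed from the file path in this diff:

from eolearn.core.graph import DirectedGraph

# A three-vertex chain: "load" -> "process" -> "save".
graph = DirectedGraph.from_edges([("load", "process"), ("process", "save")])

assert graph.is_edge("load", "process")
assert graph.get_indegree("process") == 1 and graph.get_outdegree("process") == 1
assert graph.get_neighbors("load") == ["process"]

# All dependencies of a vertex precede it in the topological ordering,
# so a simple chain comes out in order.
assert graph.topologically_ordered_vertices() == ["load", "process", "save"]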
:param edges: Pairs of objects that describe all the edges of the graph """ @@ -175,7 +176,7 @@ def from_edges(edges: Sequence[Tuple[_T, _T]]) -> "DirectedGraph[_T]": return dag @staticmethod - def _is_cyclic(graph: "DirectedGraph") -> bool: + def _is_cyclic(graph: DirectedGraph) -> bool: """True if the directed graph contains a cycle. False otherwise. The algorithm is naive, running in O(V^2) time, and not intended for serious use! For production purposes on @@ -196,7 +197,7 @@ def _is_cyclic(graph: "DirectedGraph") -> bool: stack.append(v) return False - def topologically_ordered_vertices(self) -> List[_T]: + def topologically_ordered_vertices(self) -> list[_T]: """Computes an ordering `<` of vertices so that for any two vertices `v` and `v'` we have that if `v` depends on `v'` then `v' < v`. In words, all dependencies of a vertex precede the vertex in this ordering. diff --git a/core/eolearn/core/types.py b/core/eolearn/core/types.py index f2caa980a..7aabdef55 100644 --- a/core/eolearn/core/types.py +++ b/core/eolearn/core/types.py @@ -9,9 +9,7 @@ import sys # pylint: disable=unused-import -from typing import Dict, Iterable, Optional, Sequence, Tuple, Union - -from typing_extensions import Literal +from typing import Dict, Iterable, Literal, Optional, Sequence, Tuple, Union from .constants import FeatureType @@ -19,7 +17,7 @@ from types import EllipsisType # pylint: disable=ungrouped-imports from typing import TypeAlias else: - import builtins # noqa: F401 + import builtins # noqa: F401, RUF100 from typing_extensions import TypeAlias diff --git a/core/eolearn/core/utils/common.py b/core/eolearn/core/utils/common.py index b853a007f..b2c425f9c 100644 --- a/core/eolearn/core/utils/common.py +++ b/core/eolearn/core/utils/common.py @@ -8,7 +8,7 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ import uuid -from typing import Callable, Sequence, Tuple, Union, cast +from typing import Callable, Mapping, Sequence, Tuple, Union, cast import geopandas as gpd import numpy as np @@ -58,7 +58,7 @@ def deep_eq(fst_obj: object, snd_obj: object) -> bool: return len(fst_obj) == len(snd_obj) and all(map(deep_eq, fst_obj, snd_obj)) - if isinstance(fst_obj, dict): + if isinstance(fst_obj, (dict, Mapping)): snd_obj = cast(dict, snd_obj) if fst_obj.keys() != snd_obj.keys(): diff --git a/core/eolearn/core/utils/fs.py b/core/eolearn/core/utils/fs.py index 055ffd384..4e40ca0b5 100644 --- a/core/eolearn/core/utils/fs.py +++ b/core/eolearn/core/utils/fs.py @@ -22,6 +22,9 @@ from sentinelhub import SHConfig +# because we access internals when pickling FS +# ruff: noqa: SLF001 + def get_filesystem( path: Union[str, Path], create: bool = False, config: Optional[SHConfig] = None, **kwargs: Any diff --git a/core/eolearn/core/utils/parallelize.py b/core/eolearn/core/utils/parallelize.py index 9a1952920..783a249b8 100644 --- a/core/eolearn/core/utils/parallelize.py +++ b/core/eolearn/core/utils/parallelize.py @@ -7,18 +7,20 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree.
""" +from __future__ import annotations + import concurrent.futures import multiprocessing from concurrent.futures import FIRST_COMPLETED, Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, List, Optional, Tuple, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, List, TypeVar, cast from tqdm.auto import tqdm if TYPE_CHECKING: from threading import Lock - MULTIPROCESSING_LOCK: Optional[Lock] = None + MULTIPROCESSING_LOCK: Lock | None = None else: MULTIPROCESSING_LOCK = None @@ -37,7 +39,7 @@ class _ProcessingType(Enum): RAY = "ray" -def _decide_processing_type(workers: Optional[int], multiprocess: bool) -> _ProcessingType: +def _decide_processing_type(workers: int | None, multiprocess: bool) -> _ProcessingType: """Decides processing type according to given parameters. :param workers: A number of workers to be used (either threads or processes). If a single worker is given it will @@ -55,10 +57,10 @@ def _decide_processing_type(workers: Optional[int], multiprocess: bool) -> _Proc def parallelize( function: Callable[..., _OutputType], *params: Iterable[Any], - workers: Optional[int], + workers: int | None, multiprocess: bool = True, **tqdm_kwargs: Any, -) -> List[_OutputType]: +) -> list[_OutputType]: """Parallelizes the function on given parameters using the specified number of workers. :param function: A function to be parallelized. @@ -118,7 +120,7 @@ def submit_and_monitor_execution( function: Callable[..., _OutputType], *params: Iterable[Any], **tqdm_kwargs: Any, -) -> List[_OutputType]: +) -> list[_OutputType]: """Performs the execution parallelization and monitors the process using a progress bar. :param executor: An object that performs parallelization. @@ -130,7 +132,7 @@ def submit_and_monitor_execution( return join_futures(futures, **tqdm_kwargs) -def join_futures(futures: List[Future], **tqdm_kwargs: Any) -> List[Any]: +def join_futures(futures: list[Future], **tqdm_kwargs: Any) -> list[Any]: """Resolves futures, monitors progress, and returns a list of results. :param futures: A list of futures to be joined. Note that this list will be reduced into an empty list as a side @@ -140,7 +142,7 @@ def join_futures(futures: List[Future], **tqdm_kwargs: Any) -> List[Any]: :param tqdm_kwargs: Keyword arguments that will be propagated to `tqdm` progress bar. :return: A list of results in the order that corresponds with the order of the given input `futures`. """ - results: List[Optional[Any]] = [None] * len(futures) + results: list[Any | None] = [None] * len(futures) for position, result in join_futures_iter(futures, **tqdm_kwargs): results[position] = result @@ -148,8 +150,8 @@ def join_futures(futures: List[Future], **tqdm_kwargs: Any) -> List[Any]: def join_futures_iter( - futures: List[Future], update_interval: float = 0.5, **tqdm_kwargs: Any -) -> Generator[Tuple[int, Any], None, None]: + futures: list[Future], update_interval: float = 0.5, **tqdm_kwargs: Any +) -> Generator[tuple[int, Any], None, None]: """Resolves futures, monitors progress, and serves as an iterator over results. :param futures: A list of futures to be joined. Note that this list will be reduced into an empty list as a side @@ -162,7 +164,7 @@ def join_futures_iter( in the original list to which `result` belongs to. 
""" - def _wait_function(remaining_futures: Collection[Future]) -> Tuple[Collection[Future], Collection[Future]]: + def _wait_function(remaining_futures: Collection[Future]) -> tuple[Collection[Future], Collection[Future]]: done, not_done = concurrent.futures.wait( remaining_futures, timeout=float(update_interval), return_when=FIRST_COMPLETED ) @@ -175,11 +177,11 @@ def _get_result(future: Future) -> Any: def _base_join_futures_iter( - wait_function: Callable[[Collection[_FutureType]], Tuple[Collection[_FutureType], Collection[_FutureType]]], + wait_function: Callable[[Collection[_FutureType]], tuple[Collection[_FutureType], Collection[_FutureType]]], get_result_function: Callable[[_FutureType], _OutputType], - futures: List[_FutureType], + futures: list[_FutureType], **tqdm_kwargs: Any, -) -> Generator[Tuple[int, _OutputType], None, None]: +) -> Generator[tuple[int, _OutputType], None, None]: """A generalized utility function that resolves futures, monitors progress, and serves as an iterator over results.""" if not isinstance(futures, list): @@ -198,7 +200,7 @@ def _base_join_futures_iter( yield result_position, result -def _make_copy_and_empty_given(items: List[_T]) -> List[_T]: +def _make_copy_and_empty_given(items: list[_T]) -> list[_T]: """Removes items from the given list and returns its copy. The side effect of removing items is intentional.""" items_copy = items[:] while items: diff --git a/core/eolearn/core/utils/parsing.py b/core/eolearn/core/utils/parsing.py index 08023e90d..06e93bb4f 100644 --- a/core/eolearn/core/utils/parsing.py +++ b/core/eolearn/core/utils/parsing.py @@ -9,7 +9,7 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Sequence, Tuple, Union, cast +from typing import TYPE_CHECKING, Callable, Iterable, Sequence, Tuple, Union, cast from ..constants import FeatureType from ..types import ( @@ -90,7 +90,7 @@ class FeatureParser: def __init__( self, features: FeaturesSpecification, - allowed_feature_types: Union[Iterable[FeatureType], Callable[[FeatureType], bool], EllipsisType] = ..., + allowed_feature_types: Iterable[FeatureType] | Callable[[FeatureType], bool] | EllipsisType = ..., ): """ :param features: A collection of features in one of the supported formats @@ -106,7 +106,7 @@ def __init__( ) self._feature_specs = self._parse_features(features) - def _parse_features(self, features: FeaturesSpecification) -> List[_ParserFeaturesSpec]: + def _parse_features(self, features: FeaturesSpecification) -> list[_ParserFeaturesSpec]: """This method parses and validates input, returning a list of `(ftype, old_name, new_name)` triples. Due to typing issues the all-features requests are transformed from `(ftype, ...)` to `(ftype, None, None)`. 
@@ -135,10 +135,10 @@ def _parse_features(self, features: FeaturesSpecification) -> List[_ParserFeatur def _parse_dict( self, features: DictFeatureSpec, - ) -> List[_ParserFeaturesSpec]: + ) -> list[_ParserFeaturesSpec]: """Implements parsing and validation in case the input is a dictionary.""" - feature_specs: List[_ParserFeaturesSpec] = [] + feature_specs: list[_ParserFeaturesSpec] = [] for feature_type, feature_names in features.items(): feature_type = self._parse_feature_type(feature_type, message_about_position="keys of the dictionary") @@ -159,11 +159,11 @@ def _parse_dict( def _parse_sequence( self, - features: Union[SingleFeatureSpec, SequenceFeatureSpec], - ) -> List[_ParserFeaturesSpec]: + features: SingleFeatureSpec | SequenceFeatureSpec, + ) -> list[_ParserFeaturesSpec]: """Implements parsing and validation in case the input is a tuple describing a single feature or a sequence.""" - feature_specs: List[_ParserFeaturesSpec] = [] + feature_specs: list[_ParserFeaturesSpec] = [] # Check for possible singleton if 2 <= len(features) <= 3: @@ -200,7 +200,7 @@ def _parse_singleton(self, feature: Sequence) -> FeatureRenameSpec: parsed_name = self._parse_feature_name(feature_type, feature_name) return (feature_type, *parsed_name) - def _parse_feature_type(self, feature_type: Union[str, FeatureType], *, message_about_position: str) -> FeatureType: + def _parse_feature_type(self, feature_type: str | FeatureType, *, message_about_position: str) -> FeatureType: """Tries to extract a feature type if possible, fails otherwise. The parameter `message_about_position` is used for more informative error messages. @@ -219,7 +219,7 @@ def _parse_feature_type(self, feature_type: Union[str, FeatureType], *, message_ return feature_type @staticmethod - def _parse_feature_name(feature_type: FeatureType, name: object) -> Tuple[str, str]: + def _parse_feature_name(feature_type: FeatureType, name: object) -> tuple[str, str]: """Parses input in places where a feature name is expected, handling the cases of a name and renaming pair.""" if isinstance(name, str): return name, name @@ -247,14 +247,14 @@ def _fail_for_noname_features(feature_type: FeatureType, specification: object) f" {specification} instead." ) - def get_feature_specifications(self) -> List[Tuple[FeatureType, Union[str, EllipsisType]]]: + def get_feature_specifications(self) -> list[tuple[FeatureType, str | EllipsisType]]: """Returns the feature specifications in a more streamlined fashion. Requests for all features, e.g. `(FeatureType.DATA, ...)`, are returned directly. """ return [(ftype, ... if fname is None else fname) for ftype, fname, _ in self._feature_specs] - def get_features(self, eopatch: Optional[EOPatch] = None) -> List[FeatureSpec]: + def get_features(self, eopatch: EOPatch | None = None) -> list[FeatureSpec]: """Returns a list of `(feature_type, feature_name)` pairs. For features that specify renaming, the new name of the feature is ignored. @@ -267,7 +267,7 @@ def get_features(self, eopatch: Optional[EOPatch] = None) -> List[FeatureSpec]: renamed_features = self.get_renamed_features(eopatch) return [feature[:2] for feature in renamed_features] # pattern unpacking messes with typechecking - def get_renamed_features(self, eopatch: Optional[EOPatch] = None) -> List[FeatureRenameSpec]: + def get_renamed_features(self, eopatch: EOPatch | None = None) -> list[FeatureRenameSpec]: """Returns a list of `(feature_type, old_name, new_name)` triples. For features without a specified renaming the new name is equal to the old one. 
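And a sketch of the renaming behaviour described above, again with hypothetical names:

from eolearn.core import FeatureType
from eolearn.core.utils.parsing import FeatureParser

parser = FeatureParser([(FeatureType.DATA, "bands"), (FeatureType.MASK, "CLM", "CLM_NEW")])

# `get_features` drops the new names ...
assert parser.get_features() == [(FeatureType.DATA, "bands"), (FeatureType.MASK, "CLM")]

# ... while `get_renamed_features` yields triples, defaulting to the old name.
assert parser.get_renamed_features() == [
    (FeatureType.DATA, "bands", "bands"),
    (FeatureType.MASK, "CLM", "CLM_NEW"),
]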
@@ -278,7 +278,7 @@ def get_renamed_features(self, eopatch: Optional[EOPatch] = None) -> List[Featur If `eopatch` is not provided the method fails if an all-feature request is in the specification. """ - parsed_features: List[FeatureRenameSpec] = [] + parsed_features: list[FeatureRenameSpec] = [] for feature_spec in self._feature_specs: ftype, old_name, new_name = feature_spec @@ -303,9 +303,9 @@ def get_renamed_features(self, eopatch: Optional[EOPatch] = None) -> List[Featur def parse_feature( feature: SingleFeatureSpec, - eopatch: Optional[EOPatch] = None, - allowed_feature_types: Union[EllipsisType, Iterable[FeatureType], Callable[[FeatureType], bool]] = ..., -) -> Tuple[FeatureType, Optional[str]]: + eopatch: EOPatch | None = None, + allowed_feature_types: EllipsisType | Iterable[FeatureType] | Callable[[FeatureType], bool] = ..., +) -> tuple[FeatureType, str | None]: """Parses input describing a single feature into a `(feature_type, feature_name)` pair. See :class:`FeatureParser` for viable inputs. @@ -319,8 +319,8 @@ def parse_feature( def parse_renamed_feature( feature: SingleFeatureSpec, - eopatch: Optional[EOPatch] = None, - allowed_feature_types: Union[EllipsisType, Iterable[FeatureType], Callable[[FeatureType], bool]] = ..., + eopatch: EOPatch | None = None, + allowed_feature_types: EllipsisType | Iterable[FeatureType] | Callable[[FeatureType], bool] = ..., ) -> FeatureRenameSpec: """Parses input describing a single feature into a `(feature_type, old_name, new_name)` triple. @@ -335,9 +335,9 @@ def parse_renamed_feature( def parse_features( features: FeaturesSpecification, - eopatch: Optional[EOPatch] = None, - allowed_feature_types: Union[EllipsisType, Iterable[FeatureType], Callable[[FeatureType], bool]] = ..., -) -> List[FeatureSpec]: + eopatch: EOPatch | None = None, + allowed_feature_types: EllipsisType | Iterable[FeatureType] | Callable[[FeatureType], bool] = ..., +) -> list[FeatureSpec]: """Parses input describing features into a list of `(feature_type, feature_name)` pairs. See :class:`FeatureParser` for viable inputs. @@ -347,9 +347,9 @@ def parse_features( def parse_renamed_features( features: FeaturesSpecification, - eopatch: Optional[EOPatch] = None, - allowed_feature_types: Union[EllipsisType, Iterable[FeatureType], Callable[[FeatureType], bool]] = ..., -) -> List[FeatureRenameSpec]: + eopatch: EOPatch | None = None, + allowed_feature_types: EllipsisType | Iterable[FeatureType] | Callable[[FeatureType], bool] = ..., +) -> list[FeatureRenameSpec]: """Parses input describing features into a list of `(feature_type, old_name, new_name)` triples. See :class:`FeatureParser` for viable inputs. diff --git a/core/eolearn/core/utils/raster.py b/core/eolearn/core/utils/raster.py index 945b57423..0d5d3b996 100644 --- a/core/eolearn/core/utils/raster.py +++ b/core/eolearn/core/utils/raster.py @@ -6,10 +6,9 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. 
""" -from typing import Tuple +from typing import Literal, Tuple import numpy as np -from typing_extensions import Literal def fast_nanpercentile(data: np.ndarray, percentile: float, *, method: str = "linear") -> np.ndarray: @@ -57,7 +56,7 @@ def fast_nanpercentile(data: np.ndarray, percentile: float, *, method: str = "li return combined_data -def constant_pad( +def constant_pad( # noqa: C901 array: np.ndarray, multiple_of: Tuple[int, int], up_down_rule: Literal["even", "up", "down"] = "even", diff --git a/core/eolearn/core/utils/testing.py b/core/eolearn/core/utils/testing.py index 5b7cd243a..e0b95c3a7 100644 --- a/core/eolearn/core/utils/testing.py +++ b/core/eolearn/core/utils/testing.py @@ -7,6 +7,7 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ import datetime as dt +import string from dataclasses import dataclass, field from typing import Any, List, Optional, Tuple @@ -50,18 +51,24 @@ def generate_eopatch( ) -> EOPatch: """A class for generating EOPatches with dummy data.""" config = config if config is not None else PatchGeneratorConfig() - supported_feature_types = [ftype for ftype in FeatureType if ftype.is_array()] - parsed_features = FeatureParser(features or [], supported_feature_types).get_features() - rng = np.random.default_rng(seed) + parsed_features = FeatureParser( + features or [], lambda feature_type: feature_type.is_array() or feature_type == FeatureType.META_INFO + ).get_features() + + rng = np.random.default_rng(seed) timestamps = timestamps if timestamps is not None else config.timestamps patch = EOPatch(bbox=bbox, timestamps=timestamps) # fill eopatch with random data # note: the patch generation functionality could be extended by generating extra random features for ftype, fname in parsed_features: - shape = _get_feature_shape(rng, ftype, timestamps, config) - patch[(ftype, fname)] = _generate_feature_data(rng, ftype, shape, config) + if ftype == FeatureType.META_INFO: + patch[(ftype, fname)] = "".join(rng.choice(list(string.ascii_letters), 20)) + else: + shape = _get_feature_shape(rng, ftype, timestamps, config) + patch[(ftype, fname)] = _generate_feature_data(rng, ftype, shape, config) + return patch diff --git a/core/eolearn/core/utils/types.py b/core/eolearn/core/utils/types.py index 81650552e..9efc8bdae 100644 --- a/core/eolearn/core/utils/types.py +++ b/core/eolearn/core/utils/types.py @@ -2,7 +2,7 @@ from warnings import warn from ..exceptions import EODeprecationWarning -from ..types import * # noqa # pylint: disable=wildcard-import,unused-wildcard-import +from ..types import * # noqa: 403 # pylint: disable=wildcard-import,unused-wildcard-import warn( "The module `eolearn.core.utils.types` is deprecated, use `eolearn.core.types` instead.", diff --git a/core/eolearn/core/utils/vector_io.py b/core/eolearn/core/utils/vector_io.py deleted file mode 100644 index a7d46b6c9..000000000 --- a/core/eolearn/core/utils/vector_io.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -A module implementing utilities for working with geopackage files - -Copyright (c) 2017- Sinergise and contributors -For the full list of contributors, see the CREDITS file in the root directory of this source tree. - -This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. 
-""" -# pylint: skip-file -import numpy as np -from geopandas.io.file import _geometry_types - - -def infer_schema(df): # type: ignore[no-untyped-def] - """This function is copied over from GeoPandas, part of geopandas.io.file module, with the sole purpose of - disabling the `if df.empty` check that prevents saving an empty dataframe. Will be removed after GeoPandas - version 0.11 fixes the problem. - """ - - from collections import OrderedDict - - # TODO: test pandas string type and boolean type once released - types = {"Int64": "int", "string": "str", "boolean": "bool"} - - def convert_type(column, in_type): # type: ignore[no-untyped-def] - if in_type == object: - return "str" - if in_type.name.startswith("datetime64"): - # numpy datetime type regardless of frequency - return "datetime" - if str(in_type) in types: # noqa - out_type = types[str(in_type)] - else: - out_type = type(np.zeros(1, in_type).item()).__name__ - if out_type == "long": - out_type = "int" - return out_type - - properties = OrderedDict( - [ - (col, convert_type(col, _type)) - for col, _type in zip(df.columns, df.dtypes) - if col != df._geometry_column_name - ] - ) - - # NOTE: commented out by eo-learn team - # if df.empty: - # raise ValueError("Cannot write empty DataFrame to file.") - - # Since https://github.com/Toblerity/Fiona/issues/446 resolution, - # Fiona allows a list of geometry types - geom_types = _geometry_types(df) - - schema = {"geometry": geom_types, "properties": properties} - - return schema diff --git a/core/eolearn/tests/test_constants.py b/core/eolearn/tests/test_constants.py index 5ab092831..8f4bfee68 100644 --- a/core/eolearn/tests/test_constants.py +++ b/core/eolearn/tests/test_constants.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( - "old_ftype, new_ftype", + ("old_ftype", "new_ftype"), [ (FeatureType.TIMESTAMP, FeatureType.TIMESTAMPS), (FeatureType["TIMESTAMP"], FeatureType["TIMESTAMPS"]), diff --git a/core/eolearn/tests/test_core_tasks.py b/core/eolearn/tests/test_core_tasks.py index 87cc9073b..96ebc3a66 100644 --- a/core/eolearn/tests/test_core_tasks.py +++ b/core/eolearn/tests/test_core_tasks.py @@ -47,6 +47,7 @@ from eolearn.core.utils.testing import PatchGeneratorConfig, assert_feature_data_equal, generate_eopatch DUMMY_BBOX = BBox((0, 0, 1, 1), CRS(3857)) +# ruff: noqa: NPY002 @pytest.fixture(name="patch") @@ -57,11 +58,10 @@ def patch_fixture() -> EOPatch: FeatureType.MASK: ["CLM"], FeatureType.MASK_TIMELESS: ["mask", "LULC", "RANDOM_UINT8"], FeatureType.SCALAR: ["values", "CLOUD_COVERAGE"], + FeatureType.META_INFO: ["something"], } ) patch.data["CLP_S2C"] = np.zeros_like(patch.data["CLP"]) - - patch.meta_info["something"] = "beep boop" return patch @@ -104,11 +104,13 @@ def test_load_task(test_eopatch_path: str) -> None: partial_load = LoadTask(test_eopatch_path, features=[FeatureType.BBOX, FeatureType.MASK_TIMELESS]) partial_patch = partial_load.execute(eopatch_folder=".") - assert FeatureType.BBOX in partial_patch and FeatureType.TIMESTAMPS not in partial_patch + assert FeatureType.BBOX in partial_patch + assert FeatureType.TIMESTAMPS not in partial_patch load_more = LoadTask(test_eopatch_path, features=[FeatureType.TIMESTAMPS]) upgraded_partial_patch = load_more.execute(partial_patch, eopatch_folder=".") - assert FeatureType.BBOX in upgraded_partial_patch and FeatureType.TIMESTAMPS in upgraded_partial_patch + assert FeatureType.BBOX in upgraded_partial_patch + assert FeatureType.TIMESTAMPS in upgraded_partial_patch assert FeatureType.DATA not in upgraded_partial_patch @@ -123,7 
+125,7 @@ def test_io_task_pickling(filesystem: FS, task_class: Type[EOTask]) -> None: @pytest.mark.parametrize( - "feature, feature_data", + ("feature", "feature_data"), [ ((FeatureType.MASK, "CLOUD MASK"), np.arange(10).reshape(5, 2, 1, 1)), ((FeatureType.META_INFO, "something_else"), np.random.rand(10, 1)), @@ -198,7 +200,7 @@ def test_duplicate_feature_fails(patch: EOPatch) -> None: @pytest.mark.parametrize( - "init_val, shape, feature_spec", + ("init_val", "shape", "feature_spec"), [ (8, (5, 2, 6, 3), (FeatureType.MASK, "test")), (9, (1, 4, 3), (FeatureType.MASK_TIMELESS, "test")), @@ -211,11 +213,11 @@ def test_initialize_feature( expected_data = init_val * np.ones(shape) patch = InitializeFeatureTask(feature_spec, shape=shape, init_value=init_val)(patch) - assert all([np.array_equal(patch[features], expected_data) for features in parse_features(feature_spec)]) + assert all(np.array_equal(patch[features], expected_data) for features in parse_features(feature_spec)) @pytest.mark.parametrize( - "init_val, shape, feature_spec", + ("init_val", "shape", "feature_spec"), [ (3, (FeatureType.DATA, "bands"), {FeatureType.MASK: ["F1", "F2", "F3"]}), ], @@ -226,7 +228,7 @@ def test_initialize_feature_with_spec( expected_data = init_val * np.ones(patch[shape].shape) patch = InitializeFeatureTask(feature_spec, shape=shape, init_value=init_val)(patch) - assert all([np.array_equal(patch[features], expected_data) for features in parse_features(feature_spec)]) + assert all(np.array_equal(patch[features], expected_data) for features in parse_features(feature_spec)) def test_initialize_feature_fails(patch: EOPatch) -> None: @@ -258,7 +260,7 @@ def test_move_feature(features: FeatureSpec, deep: bool, patch: EOPatch) -> None @pytest.mark.parametrize( - "features_to_merge, feature, axis", + ("features_to_merge", "feature", "axis"), [ ([(FeatureType.DATA, "bands")], (FeatureType.DATA, "merged"), 0), ([(FeatureType.DATA, "bands"), (FeatureType.DATA, "CLP")], (FeatureType.DATA, "merged"), -1), @@ -282,7 +284,7 @@ def test_merge_features(axis: int, features_to_merge: List[FeatureSpec], feature @pytest.mark.parametrize( - "input_features, output_feature, zip_function, kwargs", + ("input_features", "output_feature", "zip_function", "kwargs"), [ ({FeatureType.DATA: ["CLP", "bands"]}, (FeatureType.DATA, "ziped"), np.maximum, {}), ({FeatureType.DATA: ["CLP", "bands"]}, (FeatureType.DATA, "ziped"), lambda a, b: a + b, {}), @@ -320,7 +322,7 @@ def test_zip_features_fails(patch: EOPatch) -> None: @pytest.mark.parametrize( - "input_features, output_features, map_function, kwargs", + ("input_features", "output_features", "map_function", "kwargs"), [ ({FeatureType.DATA: ["CLP", "bands"]}, {FeatureType.DATA: ["CLP_+3", "bands_+3"]}, lambda x: x + 3, {}), ( @@ -362,7 +364,7 @@ def test_map_features( assert_array_equal(mapped_patch[out_feature], expected_output) -@pytest.mark.parametrize("input_features, map_function", [({FeatureType.DATA: ["CLP", "bands"]}, lambda x: x + 3)]) +@pytest.mark.parametrize(("input_features", "map_function"), [({FeatureType.DATA: ["CLP", "bands"]}, lambda x: x + 3)]) def test_map_features_overwrite(input_features: FeaturesSpecification, map_function: Callable, patch: EOPatch) -> None: original_patch = patch.copy(deep=True, features=input_features) patch = MapFeatureTask(input_features, input_features, map_function)(patch) @@ -381,13 +383,13 @@ def test_map_features_fails(patch: EOPatch) -> None: @pytest.mark.parametrize( - "input_feature, kwargs", + ("input_feature", "kwargs"), [ 
((FeatureType.DATA, "bands"), {"axis": -1, "name": "fun_name", "bands": [4, 3, 2]}), ], ) def test_map_kwargs_passing(input_feature: FeatureSpec, kwargs: Dict[str, Any], patch: EOPatch) -> None: - def kwargs_map(data, *, some=3, **kwargs) -> tuple: + def kwargs_map(_, *, some=3, **kwargs) -> tuple: return some, kwargs mapped_patch = MapFeatureTask(input_feature, (FeatureType.META_INFO, "kwargs"), kwargs_map, **kwargs)(patch) @@ -397,7 +399,7 @@ def kwargs_map(data, *, some=3, **kwargs) -> tuple: @pytest.mark.parametrize( - "feature, task_input", + ("feature", "task_input"), [ ((FeatureType.DATA, "bands"), {(FeatureType.DATA, "EXPLODED_BANDS"): [2, 4, 6]}), ((FeatureType.DATA, "bands"), {(FeatureType.DATA, "EXPLODED_BANDS"): [2]}), diff --git a/core/eolearn/tests/test_eodata.py b/core/eolearn/tests/test_eodata.py index ebef336f0..a160f33b0 100644 --- a/core/eolearn/tests/test_eodata.py +++ b/core/eolearn/tests/test_eodata.py @@ -21,20 +21,19 @@ from eolearn.core.utils.testing import assert_feature_data_equal, generate_eopatch DUMMY_BBOX = BBox((0, 0, 1, 1), CRS(3857)) +# ruff: noqa: NPY002, SLF001 @pytest.fixture(name="mini_eopatch") def mini_eopatch_fixture() -> EOPatch: - eop = generate_eopatch( + return generate_eopatch( { FeatureType.DATA: ["A", "B"], FeatureType.MASK: ["C", "D"], FeatureType.MASK_TIMELESS: ["E"], + FeatureType.META_INFO: ["beep"], } ) - eop.meta_info["beep"] = "boop" - - return eop def test_numpy_feature_types() -> None: @@ -127,7 +126,8 @@ def test_invalid_characters(): def test_repr(test_eopatch_path: str) -> None: test_eopatch = EOPatch.load(test_eopatch_path) repr_str = repr(test_eopatch) - assert repr_str.startswith("EOPatch(") and repr_str.endswith(")") + assert repr_str.startswith("EOPatch(") + assert repr_str.endswith(")") assert len(repr_str) > 100 assert repr(EOPatch(bbox=DUMMY_BBOX)) == "EOPatch(\n bbox=BBox(((0.0, 0.0), (1.0, 1.0)), crs=CRS('3857'))\n)" @@ -136,9 +136,8 @@ def test_repr(test_eopatch_path: str) -> None: def test_repr_no_crs(test_eopatch: EOPatch) -> None: test_eopatch.vector_timeless["LULC"].crs = None repr_str = test_eopatch.__repr__() - assert ( - isinstance(repr_str, str) and len(repr_str) > 100 - ), "EOPatch __repr__ must return non-empty string even in case of missing crs" + assert isinstance(repr_str, str) + assert len(repr_str) > 100, "EOPatch __repr__ must return non-empty string even in case of missing crs" def test_add_feature() -> None: @@ -225,19 +224,19 @@ def test_deep_copy(test_eopatch: EOPatch) -> None: assert test_eopatch != eopatch_copy -@pytest.mark.parametrize("features", (..., [(FeatureType.MASK, "CLM")])) +@pytest.mark.parametrize("features", [..., [(FeatureType.MASK, "CLM")]]) def test_copy_lazy_loaded_patch(test_eopatch_path: str, features: FeaturesSpecification) -> None: # shallow copy original_eopatch = EOPatch.load(test_eopatch_path, lazy_loading=True) copied_eopatch = original_eopatch.copy(features=features) - original_data = original_eopatch.mask.__getitem__("CLM", load=False) + original_data = original_eopatch.mask._get_unloaded("CLM") assert isinstance(original_data, FeatureIO), "Shallow copying loads the data." 
- copied_data = copied_eopatch.mask.__getitem__("CLM", load=False) + copied_data = copied_eopatch.mask._get_unloaded("CLM") assert original_data is copied_data original_mask = original_eopatch.mask["CLM"] - assert copied_eopatch.mask.__getitem__("CLM", load=False).loaded_value is not None + assert copied_eopatch.mask._get_unloaded("CLM").loaded_value is not None copied_mask = copied_eopatch.mask["CLM"] assert original_mask is copied_mask @@ -245,16 +244,17 @@ def test_copy_lazy_loaded_patch(test_eopatch_path: str, features: FeaturesSpecif original_eopatch = EOPatch.load(test_eopatch_path, lazy_loading=True) copied_eopatch = original_eopatch.copy(features=features, deep=True) - original_data = original_eopatch.mask.__getitem__("CLM", load=False) + original_data = original_eopatch.mask._get_unloaded("CLM") assert isinstance(original_data, FeatureIO), "Deep copying loads the data of source." - copied_data = copied_eopatch.mask.__getitem__("CLM", load=False) + copied_data = copied_eopatch.mask._get_unloaded("CLM") assert isinstance(copied_data, FeatureIO), "Deep copying loads the data of target." assert original_data is not copied_data, "Deep copying only does a shallow copy of FeatureIO objects." mask1 = original_eopatch.mask["CLM"] - assert copied_eopatch.mask.__getitem__("CLM", load=False).loaded_value is None + assert copied_eopatch.mask._get_unloaded("CLM").loaded_value is None mask2 = copied_eopatch.mask["CLM"] - assert np.array_equal(mask1, mask2) and mask1 is not mask2, "Data no longer matches after deep copying." + assert np.array_equal(mask1, mask2), "Data no longer matches after deep copying." + assert mask1 is not mask2, "Data was not deep copied." def test_copy_features(test_eopatch: EOPatch) -> None: @@ -266,12 +266,12 @@ def test_copy_features(test_eopatch: EOPatch) -> None: @pytest.mark.parametrize( - "ftype, fname", + ("ftype", "fname"), [ - [FeatureType.DATA, "BANDS-S2-L1C"], - [FeatureType.MASK, "CLM"], - [FeatureType.BBOX, ...], - [FeatureType.TIMESTAMPS, None], + (FeatureType.DATA, "BANDS-S2-L1C"), + (FeatureType.MASK, "CLM"), + (FeatureType.BBOX, ...), + (FeatureType.TIMESTAMPS, None), ], ) def test_contains(ftype: FeatureType, fname: str, test_eopatch: EOPatch) -> None: @@ -318,7 +318,7 @@ def test_equals() -> None: assert eop1 != eop2 -@pytest.fixture(scope="function", name="eopatch_spatial_dim") +@pytest.fixture(name="eopatch_spatial_dim") def eopatch_spatial_dim_fixture() -> EOPatch: patch = EOPatch(bbox=DUMMY_BBOX) patch.data["A"] = np.zeros((1, 2, 3, 4)) @@ -328,11 +328,11 @@ def eopatch_spatial_dim_fixture() -> EOPatch: @pytest.mark.parametrize( - "feature, expected_dim", + ("feature", "expected_dim"), [ - [(FeatureType.DATA, "A"), (2, 3)], - [(FeatureType.MASK, "B"), (3, 2)], - [(FeatureType.MASK_TIMELESS, "C"), (4, 5)], + ((FeatureType.DATA, "A"), (2, 3)), + ((FeatureType.MASK, "B"), (3, 2)), + ((FeatureType.MASK_TIMELESS, "C"), (4, 5)), ], ) def test_get_spatial_dimension( @@ -342,7 +342,7 @@ def test_get_spatial_dimension( @pytest.mark.parametrize( - "patch, expected_features", + ("patch", "expected_features"), [ ( pytest.lazy_fixture("mini_eopatch"), diff --git a/core/eolearn/tests/test_eodata_io.py b/core/eolearn/tests/test_eodata_io.py index e2168f2e0..3ad5d6c33 100644 --- a/core/eolearn/tests/test_eodata_io.py +++ b/core/eolearn/tests/test_eodata_io.py @@ -7,6 +7,7 @@ import datetime import os import tempfile +import warnings from typing import Any, Type import fs @@ -21,7 +22,7 @@ from sentinelhub import CRS, BBox -from eolearn.core import EOPatch, 
FeatureType, LoadTask, OverwritePermission, SaveTask +from eolearn.core import EOPatch, FeatureType, LoadTask, OverwritePermission, SaveTask, merge_eopatches from eolearn.core.constants import TIMESTAMP_COLUMN from eolearn.core.eodata_io import ( FeatureIO, @@ -42,6 +43,13 @@ DUMMY_BBOX = BBox((0, 0, 1, 1), CRS.WGS84) +@pytest.fixture(name="_silence_warnings") +def _silence_warnings_fixture(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=EODeprecationWarning) + yield + + @pytest.fixture(name="eopatch") def eopatch_fixture(): eopatch = generate_eopatch( @@ -50,10 +58,9 @@ def eopatch_fixture(): FeatureType.MASK_TIMELESS: ["mask"], FeatureType.SCALAR: ["my scalar with spaces"], FeatureType.SCALAR_TIMELESS: ["my timeless scalar with spaces"], + FeatureType.META_INFO: ["something", "something-else"], } ) - eopatch.meta_info["something"] = "nothing" - eopatch.meta_info["something-else"] = "nothing" eopatch.vector["my-df"] = GeoDataFrame( { "values": [1, 2], @@ -91,6 +98,7 @@ def test_saving_in_empty_folder(eopatch, fs_loader): @mock_s3 @pytest.mark.parametrize("fs_loader", FS_LOADERS) +@pytest.mark.usefixtures("_silence_warnings") def test_saving_in_non_empty_folder(eopatch, fs_loader): with fs_loader() as temp_fs: empty_file = "foo.txt" @@ -107,6 +115,7 @@ def test_saving_in_non_empty_folder(eopatch, fs_loader): @mock_s3 @pytest.mark.parametrize("fs_loader", FS_LOADERS) +@pytest.mark.usefixtures("_silence_warnings") def test_overwriting_non_empty_folder(eopatch, fs_loader): with fs_loader() as temp_fs: eopatch.save("/", filesystem=temp_fs) @@ -120,13 +129,13 @@ def test_overwriting_non_empty_folder(eopatch, fs_loader): add_eopatch.save("/", filesystem=temp_fs, overwrite_permission=OverwritePermission.ADD_ONLY) new_eopatch = EOPatch.load("/", filesystem=temp_fs, lazy_loading=False) - assert new_eopatch == eopatch + add_eopatch + assert new_eopatch == merge_eopatches(eopatch, add_eopatch) @mock_s3 @pytest.mark.parametrize("fs_loader", FS_LOADERS) @pytest.mark.parametrize( - "save_features, load_features", + ("save_features", "load_features"), [ (..., ...), ([(FeatureType.DATA, ...), FeatureType.TIMESTAMPS], [(FeatureType.DATA, ...), FeatureType.TIMESTAMPS]), @@ -161,7 +170,7 @@ def test_save_add_only_features(eopatch, fs_loader): ] with fs_loader() as temp_fs: - eopatch.save("/", filesystem=temp_fs, features=features, overwrite_permission=0) + eopatch.save("/", filesystem=temp_fs, features=features, overwrite_permission=OverwritePermission.ADD_ONLY) @mock_s3 @@ -174,6 +183,7 @@ def test_bbox_always_saved(eopatch, fs_loader): @mock_s3 @pytest.mark.parametrize("fs_loader", FS_LOADERS) +@pytest.mark.usefixtures("_silence_warnings") def test_overwrite_failure(fs_loader): eopatch = EOPatch(bbox=DUMMY_BBOX) mask = np.arange(3 * 3 * 2).reshape(3, 3, 2) @@ -184,11 +194,19 @@ def test_overwrite_failure(fs_loader): eopatch.save("/", filesystem=temp_fs) with fs_loader() as temp_fs: - eopatch.save("/", filesystem=temp_fs, features=[(FeatureType.MASK_TIMELESS, "mask")], overwrite_permission=2) + eopatch.save( + "/", + filesystem=temp_fs, + features=[(FeatureType.MASK_TIMELESS, "mask")], + overwrite_permission=OverwritePermission.OVERWRITE_PATCH, + ) with pytest.raises(IOError): eopatch.save( - "/", filesystem=temp_fs, features=[(FeatureType.MASK_TIMELESS, "Mask")], overwrite_permission=0 + "/", + filesystem=temp_fs, + features=[(FeatureType.MASK_TIMELESS, "Mask")], + overwrite_permission=OverwritePermission.ADD_ONLY, ) @@ -270,8 +288,12 @@ def 
test_cleanup_different_compression(fs_loader, eopatch): with fs_loader() as temp_fs: temp_fs.makedir(folder) - save_compressed_task = SaveTask(folder, filesystem=temp_fs, compress_level=9, overwrite_permission=1) - save_noncompressed_task = SaveTask(folder, filesystem=temp_fs, compress_level=0, overwrite_permission=1) + save_compressed_task = SaveTask( + folder, filesystem=temp_fs, compress_level=9, overwrite_permission="OVERWRITE_FEATURES" + ) + save_noncompressed_task = SaveTask( + folder, filesystem=temp_fs, compress_level=0, overwrite_permission="OVERWRITE_FEATURES" + ) bbox_path = fs.path.join(folder, patch_folder, "bbox.geojson") compressed_bbox_path = bbox_path + ".gz" mask_timeless_path = fs.path.join(folder, patch_folder, "mask_timeless", "mask.npy") @@ -294,6 +316,7 @@ def test_cleanup_different_compression(fs_loader, eopatch): @mock_s3 @pytest.mark.parametrize("fs_loader", FS_LOADERS) @pytest.mark.parametrize("folder_name", ["/", "foo", "foo/bar"]) +@pytest.mark.usefixtures("_silence_warnings") def test_lazy_loading_plus_overwrite_patch(fs_loader, folder_name, eopatch): with fs_loader() as temp_fs: eopatch.save(folder_name, filesystem=temp_fs) @@ -308,7 +331,7 @@ def test_lazy_loading_plus_overwrite_patch(fs_loader, folder_name, eopatch): @pytest.mark.parametrize( - "constructor, data", + ("constructor", "data"), [ (FeatureIONumpy, np.zeros(20)), (FeatureIONumpy, np.zeros((2, 3, 3, 2), dtype=np.int16)), diff --git a/core/eolearn/tests/test_eodata_merge.py b/core/eolearn/tests/test_eodata_merge.py index 6b48262d5..c6b8fa18a 100644 --- a/core/eolearn/tests/test_eodata_merge.py +++ b/core/eolearn/tests/test_eodata_merge.py @@ -12,7 +12,7 @@ from sentinelhub import CRS, BBox -from eolearn.core import EOPatch, FeatureType +from eolearn.core import EOPatch, FeatureType, merge_eopatches from eolearn.core.constants import TIMESTAMP_COLUMN from eolearn.core.eodata_io import FeatureIO from eolearn.core.exceptions import EORuntimeWarning @@ -33,11 +33,11 @@ def test_time_dependent_merge(): timestamps=[all_timestamps[3], all_timestamps[1], all_timestamps[2], all_timestamps[4], all_timestamps[3]], ) - eop = eop1.merge(eop2) + eop = merge_eopatches(eop1, eop2) expected_eop = EOPatch(bbox=DUMMY_BBOX, data={"bands": np.ones((6, 4, 5, 2))}, timestamps=all_timestamps) assert eop == expected_eop - eop = eop1.merge(eop2, time_dependent_op="concatenate") + eop = merge_eopatches(eop1, eop2, time_dependent_op="concatenate") expected_eop = EOPatch( bbox=DUMMY_BBOX, data={"bands": np.ones((8, 4, 5, 2))}, @@ -51,9 +51,9 @@ def test_time_dependent_merge(): eop2.data["bands"][1, ...] = 4 with pytest.raises(ValueError): - eop1.merge(eop2) + merge_eopatches(eop1, eop2) - eop = eop1.merge(eop2, time_dependent_op="mean") + eop = merge_eopatches(eop1, eop2, time_dependent_op="mean") expected_eop = EOPatch(bbox=DUMMY_BBOX, data={"bands": np.ones((6, 4, 5, 2))}, timestamps=all_timestamps) expected_eop.data["bands"][1, ...] = 4 expected_eop.data["bands"][3, ...] 
= 3 @@ -72,23 +72,25 @@ def test_time_dependent_merge_with_missing_features(): ) eop2 = EOPatch(bbox=DUMMY_BBOX, timestamps=timestamps[:4]) - eop = eop1.merge(eop2) + eop = merge_eopatches(eop1, eop2) assert eop == eop1 - eop = eop2.merge(eop1, eop1, eop2, time_dependent_op="min") + eop = merge_eopatches(eop2, eop1, eop1, eop2, time_dependent_op="min") assert eop == eop1 - eop = eop1.merge() + eop = merge_eopatches(eop1) assert eop == eop1 def test_failed_time_dependent_merge(): eop1 = EOPatch(bbox=DUMMY_BBOX, data={"bands": np.ones((6, 4, 5, 2))}) with pytest.raises(ValueError): - eop1.merge() + merge_eopatches( + eop1, + ) eop2 = EOPatch(bbox=DUMMY_BBOX, data={"bands": np.ones((1, 4, 5, 2))}, timestamps=[dt.datetime(2020, 1, 1)]) with pytest.raises(ValueError): - eop2.merge(eop1) + merge_eopatches(eop2, eop1) def test_timeless_merge(): @@ -102,9 +104,9 @@ def test_timeless_merge(): ) with pytest.raises(ValueError): - eop1.merge(eop2) + merge_eopatches(eop1, eop2) - eop = eop1.merge(eop2, timeless_op="concatenate") + eop = merge_eopatches(eop1, eop2, timeless_op="concatenate") expected_eop = EOPatch( bbox=DUMMY_BBOX, mask_timeless={ @@ -116,7 +118,7 @@ def test_timeless_merge(): expected_eop.mask_timeless["mask"][..., 5:] = 4 assert eop == expected_eop - eop = eop1.merge(eop2, eop2, timeless_op="min") + eop = merge_eopatches(eop1, eop2, eop2, timeless_op="min") expected_eop = EOPatch( bbox=DUMMY_BBOX, mask_timeless={ @@ -130,7 +132,7 @@ def test_timeless_merge(): def test_vector_merge(): bbox = BBox((1, 2, 3, 4), CRS.WGS84) - df = GeoDataFrame( + dummy_gdf = GeoDataFrame( { "values": [1, 2], TIMESTAMP_COLUMN: [dt.datetime(2017, 1, 1, 10, 4, 7), dt.datetime(2017, 1, 4, 10, 14, 5)], @@ -139,28 +141,27 @@ def test_vector_merge(): crs=bbox.crs.pyproj_crs(), ) - eop1 = EOPatch(bbox=bbox, vector_timeless={"vectors": df}) + eop1 = EOPatch(bbox=bbox, vector_timeless={"vectors": dummy_gdf}) - for eop in [eop1.merge(eop1), eop1 + eop1]: - assert eop == eop1 + assert eop1 == merge_eopatches(eop1, eop1) eop2 = eop1.__deepcopy__() eop2.vector_timeless["vectors"].crs = CRS.POP_WEB.pyproj_crs() with pytest.raises(ValueError): - eop1.merge(eop2) + merge_eopatches(eop1, eop2) def test_meta_info_merge(): eop1 = EOPatch(bbox=DUMMY_BBOX, meta_info={"a": 1, "b": 2}) eop2 = EOPatch(bbox=DUMMY_BBOX, meta_info={"a": 1, "c": 5}) - eop = eop1.merge(eop2) + eop = merge_eopatches(eop1, eop2) expected_eop = EOPatch(bbox=DUMMY_BBOX, meta_info={"a": 1, "b": 2, "c": 5}) assert eop == expected_eop eop2.meta_info["a"] = 3 with pytest.warns(EORuntimeWarning): - eop = eop1.merge(eop2) + eop = merge_eopatches(eop1, eop2) assert eop == expected_eop @@ -168,18 +169,18 @@ def test_bbox_merge(): eop1 = EOPatch(bbox=BBox((1, 2, 3, 4), CRS.WGS84)) eop2 = EOPatch(bbox=BBox((1, 2, 3, 4), CRS.POP_WEB)) - eop = eop1.merge(eop1) + eop = merge_eopatches(eop1, eop1) assert eop == eop1 with pytest.raises(ValueError): - eop1.merge(eop2) + merge_eopatches(eop1, eop2) def test_lazy_loading(test_eopatch_path): eop1 = EOPatch.load(test_eopatch_path, lazy_loading=True) eop2 = EOPatch.load(test_eopatch_path, lazy_loading=True) - eop = eop1.merge(eop2, features=[(FeatureType.MASK, ...)]) + eop = merge_eopatches(eop1, eop2, features=[(FeatureType.MASK, ...)]) assert isinstance(eop.mask.get("CLM"), np.ndarray) assert isinstance(eop1.mask.get("CLM"), np.ndarray) - assert isinstance(eop1.mask_timeless.get("LULC"), FeatureIO) + assert isinstance(eop1.mask_timeless._get_unloaded("LULC"), FeatureIO) # noqa: SLF001 diff --git 
a/core/eolearn/tests/test_eoexecutor.py b/core/eolearn/tests/test_eoexecutor.py index 69949896d..789026591 100644 --- a/core/eolearn/tests/test_eoexecutor.py +++ b/core/eolearn/tests/test_eoexecutor.py @@ -69,27 +69,19 @@ def test_nodes_fixture(): example = EONode(ExampleTask()) foo = EONode(FooTask(), inputs=[example, example]) output = EONode(OutputTask("output"), inputs=[foo]) - nodes = {"example": example, "foo": foo, "output": output} - return nodes + return {"example": example, "foo": foo, "output": output} @pytest.fixture(name="workflow") def workflow_fixture(test_nodes): - workflow = EOWorkflow(list(test_nodes.values())) - return workflow + return EOWorkflow(list(test_nodes.values())) @pytest.fixture(name="execution_kwargs") def execution_kwargs_fixture(test_nodes): example_node = test_nodes["example"] - execution_kwargs = [ - {example_node: {"arg1": 1}}, - {}, - {example_node: {"arg1": 3, "arg3": 10}}, - {example_node: {"arg1": None}}, - ] - return execution_kwargs + return [{example_node: {"arg1": 1}}, {}, {example_node: {"arg1": 3, "arg3": 10}}, {example_node: {"arg1": None}}] class DummyFilesystemFileHandler(FileHandler): @@ -220,7 +212,7 @@ def test_with_lock(num_workers): handler.close() logger.removeHandler(handler) - with open(fp.name, "r") as log_file: + with open(fp.name) as log_file: lines = log_file.read().strip("\n ").split("\n") assert len(lines) == 2 * num_workers @@ -243,7 +235,7 @@ def test_without_lock(num_workers): handler.close() logger.removeHandler(handler) - with open(fp.name, "r") as log_file: + with open(fp.name) as log_file: lines = log_file.read().strip("\n ").split("\n") assert len(lines) == 2 * num_workers diff --git a/core/eolearn/tests/test_eonode.py b/core/eolearn/tests/test_eonode.py index 069f92d2c..182be2e82 100644 --- a/core/eolearn/tests/test_eonode.py +++ b/core/eolearn/tests/test_eonode.py @@ -4,7 +4,6 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ -import time from eolearn.core import EONode, EOTask, OutputTask, linearly_connect_tasks @@ -26,31 +25,35 @@ def execute(self, x, *, d=1): def test_nodes_different_uids(): uids = set() + task = Inc() for _ in range(5000): - node = EONode(Inc()) + node = EONode(task) uids.add(node.uid) assert len(uids) == 5000, "Different nodes should have different uids." def test_hashing(): - _ = {EONode(Inc()): "Can be hashed!"} + """Tests that nodes are hashable. If this test is slow, then hashing of large workflows is slow as well, + most likely due to structural hashing (which should be avoided). + """ + task1 = Inc() + task2 = DivideTask() - linear = EONode(Inc()) + _ = {EONode(task1): "Can be hashed!"} + + many_nodes = {} for _ in range(5000): - linear = EONode(Inc(), inputs=[linear]) + many_nodes[EONode(task1)] = "We should all be different!" + assert len(many_nodes) == 5000, "Hash clashes happen." - branch_1, branch_2 = EONode(Inc()), EONode(Inc()) + branch_1, branch_2 = EONode(task1), EONode(task1) for _ in range(500): - branch_1 = EONode(DivideTask(), inputs=(branch_1, branch_2)) - branch_2 = EONode(DivideTask(), inputs=(branch_2, EONode(Inc()))) + branch_1 = EONode(task2, inputs=(branch_1, branch_2)) + branch_2 = EONode(task2, inputs=(branch_2, EONode(task1))) - t_start = time.time() - linear.__hash__() branch_1.__hash__() branch_2.__hash__() - t_end = time.time() - assert t_end - t_start < 5, "Assert hashing slows down for large workflows!"
def test_get_dependencies(): diff --git a/core/eolearn/tests/test_eoworkflow.py b/core/eolearn/tests/test_eoworkflow.py index c4202ae47..9de5a801e 100644 --- a/core/eolearn/tests/test_eoworkflow.py +++ b/core/eolearn/tests/test_eoworkflow.py @@ -27,7 +27,7 @@ from eolearn.core.eoworkflow import NodeStats -class CustomException(ValueError): +class CustomExceptionError(ValueError): pass @@ -48,7 +48,7 @@ def execute(self, x, *, d=1): class ExceptionTask(EOTask): def execute(self, *_, **__): - raise CustomException + raise CustomExceptionError def test_workflow_arguments(): @@ -272,7 +272,7 @@ def test_exception_handling(): increase_node = EONode(IncTask(), inputs=[exception_node]) workflow = EOWorkflow([input_node, exception_node, increase_node]) - with pytest.raises(CustomException): + with pytest.raises(CustomExceptionError): workflow.execute() results = workflow.execute(raise_errors=False) @@ -288,7 +288,7 @@ def test_exception_handling(): assert node_stats.node_name == node.name if node is exception_node: - assert isinstance(node_stats.exception, CustomException) + assert isinstance(node_stats.exception, CustomExceptionError) assert node_stats.exception_traceback.startswith("Traceback") else: assert node_stats.exception is None diff --git a/core/eolearn/tests/test_extra/test_ray.py b/core/eolearn/tests/test_extra/test_ray.py index b96504cf4..bf1e8c3ea 100644 --- a/core/eolearn/tests/test_extra/test_ray.py +++ b/core/eolearn/tests/test_extra/test_ray.py @@ -19,6 +19,8 @@ from eolearn.core.eoworkflow_tasks import OutputTask from eolearn.core.extra.ray import RayExecutor, join_ray_futures, join_ray_futures_iter, parallelize_with_ray +# ruff: noqa: ARG001 + class ExampleTask(EOTask): def execute(self, *_, **kwargs): @@ -51,8 +53,8 @@ def filter(self, record): return record.levelno >= logging.WARNING -@pytest.fixture(name="simple_cluster", scope="module") -def simple_cluster_fixture(): +@pytest.fixture(name="_simple_cluster", scope="module") +def _simple_cluster_fixture(): ray.init(log_to_driver=False) yield ray.shutdown() @@ -63,27 +65,19 @@ def test_nodes_fixture(): example = EONode(ExampleTask()) foo = EONode(FooTask(), inputs=[example, example]) output = EONode(OutputTask("output"), inputs=[foo]) - nodes = {"example": example, "foo": foo, "output": output} - return nodes + return {"example": example, "foo": foo, "output": output} @pytest.fixture(name="workflow") def workflow_fixture(test_nodes): - workflow = EOWorkflow(list(test_nodes.values())) - return workflow + return EOWorkflow(list(test_nodes.values())) @pytest.fixture(name="execution_kwargs") def execution_kwargs_fixture(test_nodes): example_node = test_nodes["example"] - execution_kwargs = [ - {example_node: {"arg1": 1}}, - {}, - {example_node: {"arg1": 3, "arg3": 10}}, - {example_node: {"arg1": None}}, - ] - return execution_kwargs + return [{example_node: {"arg1": 1}}, {}, {example_node: {"arg1": 3, "arg3": 10}}, {example_node: {"arg1": None}}] def test_fail_without_ray(workflow, execution_kwargs): @@ -98,7 +92,8 @@ def test_fail_without_ray(workflow, execution_kwargs): @pytest.mark.parametrize("filter_logs", [True, False]) @pytest.mark.parametrize("execution_names", [None, [4, "x", "y", "z"]]) -def test_read_logs(filter_logs, execution_names, workflow, execution_kwargs, simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_read_logs(filter_logs, execution_names, workflow, execution_kwargs): with tempfile.TemporaryDirectory() as tmp_dir_name: executor = RayExecutor( workflow, @@ -129,7 +124,8 @@ def 
test_read_logs(filter_logs, execution_names, workflow, execution_kwargs, sim assert line_count == expected_line_count -def test_execution_results(workflow, execution_kwargs, simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_execution_results(workflow, execution_kwargs): with tempfile.TemporaryDirectory() as tmp_dir_name: executor = RayExecutor(workflow, execution_kwargs, logs_folder=tmp_dir_name) executor.run(desc="Test Ray") @@ -140,7 +136,8 @@ def test_execution_results(workflow, execution_kwargs, simple_cluster): assert isinstance(time_stat, datetime.datetime) -def test_execution_errors(workflow, execution_kwargs, simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_execution_errors(workflow, execution_kwargs): with tempfile.TemporaryDirectory() as tmp_dir_name: executor = RayExecutor(workflow, execution_kwargs, logs_folder=tmp_dir_name) executor.run() @@ -155,7 +152,8 @@ def test_execution_errors(workflow, execution_kwargs, simple_cluster): assert executor.get_failed_executions() == [3] -def test_execution_results2(workflow, execution_kwargs, simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_execution_results2(workflow, execution_kwargs): executor = RayExecutor(workflow, execution_kwargs) results = executor.run() @@ -166,7 +164,8 @@ def test_execution_results2(workflow, execution_kwargs, simple_cluster): assert workflow_results.outputs["output"] == 42 -def test_keyboard_interrupt(simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_keyboard_interrupt(): exception_node = EONode(KeyboardExceptionTask()) workflow = EOWorkflow([exception_node]) execution_kwargs = [] @@ -177,7 +176,8 @@ def test_keyboard_interrupt(simple_cluster): RayExecutor(workflow, execution_kwargs).run() -def test_reruns(workflow, execution_kwargs, simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_reruns(workflow, execution_kwargs): executor = RayExecutor(workflow, execution_kwargs) for _ in range(100): executor.run() @@ -190,7 +190,8 @@ def test_reruns(workflow, execution_kwargs, simple_cluster): executor.run() -def test_run_after_interrupt(workflow, execution_kwargs, simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_run_after_interrupt(workflow, execution_kwargs): foo_node = EONode(FooTask()) exception_node = EONode(KeyboardExceptionTask(), inputs=[foo_node]) exception_workflow = EOWorkflow([foo_node, exception_node]) @@ -205,7 +206,8 @@ def test_run_after_interrupt(workflow, execution_kwargs, simple_cluster): assert [res.outputs for res in result_before_exception] == [res.outputs for res in result_after_exception] -def test_mix_with_eoexecutor(workflow, execution_kwargs, simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_mix_with_eoexecutor(workflow, execution_kwargs): rayexecutor = RayExecutor(workflow, execution_kwargs) eoexecutor = EOExecutor(workflow, execution_kwargs) for _ in range(10): @@ -234,7 +236,8 @@ def plus_one(value): return value + 1 -def test_join_ray_futures(simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_join_ray_futures(): futures = [plus_one.remote(value) for value in range(5)] results = join_ray_futures(futures) @@ -242,7 +245,8 @@ def test_join_ray_futures(simple_cluster): assert futures == [] -def test_join_ray_futures_iter(simple_cluster): +@pytest.mark.usefixtures("_simple_cluster") +def test_join_ray_futures_iter(): futures = [plus_one.remote(value) for value in range(5)] results = [] diff --git 
a/core/eolearn/tests/test_graph.py b/core/eolearn/tests/test_graph.py index 9cf3a89f0..3988f3fc5 100644 --- a/core/eolearn/tests/test_graph.py +++ b/core/eolearn/tests/test_graph.py @@ -124,9 +124,9 @@ def test_del_vertex(): def test_resolve_dependencies(edges): graph = DirectedGraph.from_edges(edges) - if DirectedGraph._is_cyclic(graph): + if DirectedGraph._is_cyclic(graph): # noqa: SLF001 with pytest.raises(CyclicDependencyError): graph.topologically_ordered_vertices() else: vertex_position = {vertex: i for i, vertex in enumerate(graph.topologically_ordered_vertices())} - assert functools.reduce(lambda P, Q: P and Q, [vertex_position[u] < vertex_position[v] for u, v in edges]) + assert functools.reduce(lambda p, q: p and q, [vertex_position[u] < vertex_position[v] for u, v in edges]) diff --git a/core/eolearn/tests/test_utils/test_common.py b/core/eolearn/tests/test_utils/test_common.py index 76dc0a55c..db743b56f 100644 --- a/core/eolearn/tests/test_utils/test_common.py +++ b/core/eolearn/tests/test_utils/test_common.py @@ -25,11 +25,8 @@ (bytes, False), (complex, False), (np.number, False), - (np.int, True), (np.byte, True), (np.bool_, True), - (np.bool, True), - (np.bool8, True), (np.integer, True), (np.dtype("uint16"), True), (np.int8, True), @@ -43,7 +40,7 @@ ] -@pytest.mark.parametrize("number_type, is_discrete", DTYPE_TEST_CASES) +@pytest.mark.parametrize(("number_type", "is_discrete"), DTYPE_TEST_CASES) def test_is_discrete_type(number_type, is_discrete): """Checks the given type and its numpy dtype against the expected answer.""" assert is_discrete_type(number_type) is is_discrete diff --git a/core/eolearn/tests/test_utils/test_fs.py b/core/eolearn/tests/test_utils/test_fs.py index f0f02cd2a..df7eedc4c 100644 --- a/core/eolearn/tests/test_utils/test_fs.py +++ b/core/eolearn/tests/test_utils/test_fs.py @@ -102,7 +102,7 @@ def test_get_aws_credentials(mocked_copy): @pytest.mark.parametrize( - "filesystem, compare_params", + ("filesystem", "compare_params"), [ (OSFS("."), ["root_path"]), (TempFS(identifier="test"), ["identifier", "_temp_dir"]), @@ -119,7 +119,7 @@ def test_filesystem_serialization(filesystem: FS, compare_params: List[str]): unpickled_filesystem = unpickle_fs(pickled_filesystem) assert filesystem is not unpickled_filesystem - assert isinstance(unpickled_filesystem._lock, RLock) + assert isinstance(unpickled_filesystem._lock, RLock) # noqa: SLF001 for param in compare_params: assert getattr(filesystem, param) == getattr(unpickled_filesystem, param) @@ -172,7 +172,7 @@ def test_join_path(path_parts, expected_path): @pytest.mark.parametrize( - "filesystem, path, expected_full_path", + ("filesystem", "path", "expected_full_path"), [ (OSFS("/tmp"), "my/folder", "/tmp/my/folder"), (S3FS(bucket_name="data", dir_path="/folder"), "/sub/folder", "s3://data/folder/sub/folder"), diff --git a/core/eolearn/tests/test_utils/test_parallelize.py b/core/eolearn/tests/test_utils/test_parallelize.py index d5262de19..e3030c9d1 100644 --- a/core/eolearn/tests/test_utils/test_parallelize.py +++ b/core/eolearn/tests/test_utils/test_parallelize.py @@ -21,7 +21,7 @@ @pytest.mark.parametrize( - "workers, multiprocess, expected_type", + ("workers", "multiprocess", "expected_type"), [ (1, False, _ProcessingType.SINGLE_PROCESS), (1, True, _ProcessingType.SINGLE_PROCESS), @@ -41,7 +41,7 @@ def test_execute_with_mp_lock(): @pytest.mark.parametrize( - "workers, multiprocess", + ("workers", "multiprocess"), [ (1, True), (3, False), diff --git a/core/eolearn/tests/test_utils/test_parsing.py
b/core/eolearn/tests/test_utils/test_parsing.py index 6c9f8588b..f1a75f888 100644 --- a/core/eolearn/tests/test_utils/test_parsing.py +++ b/core/eolearn/tests/test_utils/test_parsing.py @@ -10,7 +10,7 @@ @dataclass class ParserTestCase: - input: FeaturesSpecification + parser_input: FeaturesSpecification features: List[FeatureSpec] renaming: List[FeatureRenameSpec] specifications: Optional[List[Tuple[FeatureType, Union[str, EllipsisType]]]] = None @@ -24,30 +24,30 @@ def get_test_case_description(test_case: ParserTestCase) -> str: @pytest.mark.parametrize( "test_case", [ - ParserTestCase(input=[], features=[], renaming=[], specifications=[], description="Empty input"), + ParserTestCase(parser_input=[], features=[], renaming=[], specifications=[], description="Empty input"), ParserTestCase( - input=(FeatureType.DATA, "bands"), + parser_input=(FeatureType.DATA, "bands"), features=[(FeatureType.DATA, "bands")], renaming=[(FeatureType.DATA, "bands", "bands")], specifications=[(FeatureType.DATA, "bands")], description="Singleton feature", ), ParserTestCase( - input=FeatureType.BBOX, + parser_input=FeatureType.BBOX, features=[(FeatureType.BBOX, None)], renaming=[(FeatureType.BBOX, None, None)], specifications=[(FeatureType.BBOX, ...)], description="BBox feature", ), ParserTestCase( - input=(FeatureType.MASK, "CLM", "new_CLM"), + parser_input=(FeatureType.MASK, "CLM", "new_CLM"), features=[(FeatureType.MASK, "CLM")], renaming=[(FeatureType.MASK, "CLM", "new_CLM")], specifications=[(FeatureType.MASK, "CLM")], description="Renamed feature", ), ParserTestCase( - input=[FeatureType.BBOX, (FeatureType.DATA, "bands"), (FeatureType.VECTOR_TIMELESS, "geoms")], + parser_input=[FeatureType.BBOX, (FeatureType.DATA, "bands"), (FeatureType.VECTOR_TIMELESS, "geoms")], features=[(FeatureType.BBOX, None), (FeatureType.DATA, "bands"), (FeatureType.VECTOR_TIMELESS, "geoms")], renaming=[ (FeatureType.BBOX, None, None), @@ -62,7 +62,7 @@ def get_test_case_description(test_case: ParserTestCase) -> str: description="List of inputs", ), ParserTestCase( - input=((FeatureType.TIMESTAMPS, ...), (FeatureType.MASK, "CLM"), (FeatureType.SCALAR, "a", "b")), + parser_input=((FeatureType.TIMESTAMPS, ...), (FeatureType.MASK, "CLM"), (FeatureType.SCALAR, "a", "b")), features=[(FeatureType.TIMESTAMPS, None), (FeatureType.MASK, "CLM"), (FeatureType.SCALAR, "a")], renaming=[ (FeatureType.TIMESTAMPS, None, None), @@ -73,7 +73,7 @@ def get_test_case_description(test_case: ParserTestCase) -> str: description="Tuple of inputs with rename", ), ParserTestCase( - input={ + parser_input={ FeatureType.DATA: ["bands_S2", ("bands_l8", "BANDS_L8")], FeatureType.MASK_TIMELESS: [], FeatureType.BBOX: ..., @@ -104,24 +104,24 @@ def get_test_case_description(test_case: ParserTestCase) -> str: ) def test_feature_parser_no_eopatch(test_case: ParserTestCase): """Test that input is parsed according to our expectations. 
No EOPatch provided.""" - parser = FeatureParser(test_case.input) + parser = FeatureParser(test_case.parser_input) assert parser.get_features() == test_case.features assert parser.get_renamed_features() == test_case.renaming assert parser.get_feature_specifications() == test_case.specifications @pytest.mark.parametrize( - "test_input, specifications", + ("test_input", "specifications"), [ - [(FeatureType.DATA, ...), [(FeatureType.DATA, ...)]], - [ + ((FeatureType.DATA, ...), [(FeatureType.DATA, ...)]), + ( [FeatureType.BBOX, (FeatureType.MASK, "CLM"), FeatureType.DATA], [(FeatureType.BBOX, ...), (FeatureType.MASK, "CLM"), (FeatureType.DATA, ...)], - ], - [ + ), + ( {FeatureType.BBOX: None, FeatureType.MASK: ["CLM"], FeatureType.DATA: ...}, [(FeatureType.BBOX, ...), (FeatureType.MASK, "CLM"), (FeatureType.DATA, ...)], - ], + ), ], ) def test_feature_parser_no_eopatch_failure( @@ -137,17 +137,17 @@ def test_feature_parser_no_eopatch_failure( @pytest.mark.parametrize( - "test_input, allowed_types", + ("test_input", "allowed_types"), [ - [ + ( ( (FeatureType.DATA, "bands", "new_bands"), (FeatureType.MASK, "IS_VALID", "new_IS_VALID"), (FeatureType.MASK, "CLM", "new_CLM"), ), (FeatureType.MASK,), - ], - [ + ), + ( { FeatureType.MASK: ["CLM", "IS_VALID"], FeatureType.DATA: [("bands", "new_bands")], @@ -157,7 +157,7 @@ def test_feature_parser_no_eopatch_failure( FeatureType.MASK, FeatureType.DATA, ), - ], + ), ], ) def test_allowed_feature_types_iterable(test_input: FeaturesSpecification, allowed_types: Iterable[FeatureType]): @@ -168,31 +168,34 @@ def test_allowed_feature_types_iterable(test_input: FeaturesSpecification, allow @pytest.fixture(name="eopatch", scope="module") def eopatch_fixture(): - patch = generate_eopatch( - {FeatureType.DATA: ["data", "CLP"], FeatureType.MASK: ["data", "IS_VALID"], FeatureType.MASK_TIMELESS: ["LULC"]} + return generate_eopatch( + { + FeatureType.DATA: ["data", "CLP"], + FeatureType.MASK: ["data", "IS_VALID"], + FeatureType.MASK_TIMELESS: ["LULC"], + FeatureType.META_INFO: ["something"], + } ) - patch.meta_info = {"something": "else"} - return patch @pytest.mark.parametrize( - "test_input, allowed_types", + ("test_input", "allowed_types"), [ - [ + ( ( (FeatureType.DATA, "bands", "new_bands"), (FeatureType.MASK, "IS_VALID", "new_IS_VALID"), (FeatureType.MASK, "CLM", "new_CLM"), ), lambda x: x == FeatureType.MASK, - ], - [ + ), + ( { FeatureType.META_INFO: ["something"], FeatureType.DATA: [("bands", "new_bands")], }, lambda ftype: not ftype.is_meta(), - ], + ), ], ) def test_allowed_feature_types_callable( @@ -224,7 +227,7 @@ def test_all_features_allowed_feature_types( "test_case", [ ParserTestCase( - input=..., + parser_input=..., features=[ (FeatureType.BBOX, None), (FeatureType.DATA, "data"), @@ -248,13 +251,13 @@ def test_all_features_allowed_feature_types( description="Get-all", ), ParserTestCase( - input=(FeatureType.DATA, ...), + parser_input=(FeatureType.DATA, ...), features=[(FeatureType.DATA, "data"), (FeatureType.DATA, "CLP")], renaming=[(FeatureType.DATA, "data", "data"), (FeatureType.DATA, "CLP", "CLP")], description="Get-all for a feature type", ), ParserTestCase( - input=[ + parser_input=[ FeatureType.BBOX, FeatureType.MASK, (FeatureType.META_INFO, ...), @@ -277,7 +280,7 @@ def test_all_features_allowed_feature_types( description="Sequence with ellipsis", ), ParserTestCase( - input={ + parser_input={ FeatureType.DATA: ["data", ("CLP", "new_CLP")], FeatureType.MASK_TIMELESS: ..., }, @@ -290,14 +293,17 @@ def test_all_features_allowed_feature_types( 
description="Dictionary with ellipsis", ), ParserTestCase( - input={FeatureType.VECTOR: ...}, features=[], renaming=[], description="Request all of an empty feature" + parser_input={FeatureType.VECTOR: ...}, + features=[], + renaming=[], + description="Request all of an empty feature", ), ], ids=get_test_case_description, ) def test_feature_parser_with_eopatch(test_case: ParserTestCase, eopatch: EOPatch): """Test that input is parsed according to our expectations. EOPatch provided.""" - parser = FeatureParser(test_case.input) + parser = FeatureParser(test_case.parser_input) assert parser.get_features(eopatch) == test_case.features, f"{parser.get_features(eopatch)}" assert parser.get_renamed_features(eopatch) == test_case.renaming diff --git a/core/eolearn/tests/test_utils/test_raster.py b/core/eolearn/tests/test_utils/test_raster.py index 01c4d34bb..00d532504 100644 --- a/core/eolearn/tests/test_utils/test_raster.py +++ b/core/eolearn/tests/test_utils/test_raster.py @@ -5,15 +5,16 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ import warnings -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import numpy as np import pytest from numpy.testing import assert_array_equal -from typing_extensions import Literal from eolearn.core.utils.raster import constant_pad, fast_nanpercentile +# ruff: noqa: NPY002 + @pytest.mark.parametrize("size", [0, 5]) @pytest.mark.parametrize("percentile", [0, 1.5, 50, 80.99, 100]) diff --git a/core/eolearn/tests/test_utils/test_testing.py b/core/eolearn/tests/test_utils/test_testing.py index 0d62c19f7..18e1166f5 100644 --- a/core/eolearn/tests/test_utils/test_testing.py +++ b/core/eolearn/tests/test_utils/test_testing.py @@ -61,6 +61,7 @@ def test_generate_eopatch_config(config: Dict[str, Any]) -> None: FeatureType.MASK_TIMELESS: "mask_timeless", FeatureType.SCALAR_TIMELESS: "scalar_timeless", FeatureType.LABEL_TIMELESS: "label_timeless", + FeatureType.META_INFO: "meta_info", }, ], ) @@ -145,11 +146,16 @@ def test_generate_eopatch_data(test_case: GenerateTestCase) -> None: @pytest.mark.parametrize( "feature", [ - (FeatureType.META_INFO, "meta_info"), (FeatureType.VECTOR, "vector"), (FeatureType.VECTOR_TIMELESS, "vector_timeless"), + {FeatureType.VECTOR_TIMELESS: ["vector_timeless"], FeatureType.META_INFO: ["test_meta"]}, ], ) -def test_generate_eopatch_fails(feature: FeatureSpec) -> None: +def test_generate_eopatch_fails(feature: FeaturesSpecification) -> None: with pytest.raises(ValueError): generate_eopatch(feature) + + +def test_generate_meta_data() -> None: + patch = generate_eopatch((FeatureType.META_INFO, "test_meta")) + assert isinstance(patch.meta_info["test_meta"], str) diff --git a/core/requirements.txt b/core/requirements.txt index a30c2e690..43f6bd519 100644 --- a/core/requirements.txt +++ b/core/requirements.txt @@ -1,10 +1,9 @@ -attrs>=19.2.0 boto3 fs fs-s3fs -geopandas>=0.8.1 +geopandas>=0.11.0 numpy>=1.20.0 python-dateutil -sentinelhub>=3.8.1 +sentinelhub>=3.9.0 tqdm>=4.27 typing-extensions diff --git a/core/setup.py b/core/setup.py index f9908c8ff..0fc14abb0 100644 --- a/core/setup.py +++ b/core/setup.py @@ -7,9 +7,7 @@ def get_long_description(): this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - return long_description + return f.read() def parse_requirements(file): @@ -29,7 +27,7 @@ def get_version(): setup( 
name="eo-learn-core", - python_requires=">=3.7", + python_requires=">=3.8", version=get_version(), description="Core Machine Learning Framework at Sinergise", long_description=get_long_description(), @@ -61,10 +59,10 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing", diff --git a/coregistration/MANIFEST.in b/coregistration/MANIFEST.in index 1b163e951..7d9a0dcc1 100644 --- a/coregistration/MANIFEST.in +++ b/coregistration/MANIFEST.in @@ -1,4 +1,5 @@ include requirements*.txt include LICENSE include README.md +include eolearn/coregistration/py.typed exclude eolearn/tests/* diff --git a/coregistration/eolearn/coregistration/__init__.py b/coregistration/eolearn/coregistration/__init__.py index 42f12c071..3c80f69e9 100644 --- a/coregistration/eolearn/coregistration/__init__.py +++ b/coregistration/eolearn/coregistration/__init__.py @@ -4,4 +4,4 @@ from .coregistration import ECCRegistrationTask, get_gradient -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/coregistration/eolearn/coregistration/coregistration.py b/coregistration/eolearn/coregistration/coregistration.py index d3d1e8abd..bb81d02dc 100644 --- a/coregistration/eolearn/coregistration/coregistration.py +++ b/coregistration/eolearn/coregistration/coregistration.py @@ -10,7 +10,6 @@ import logging import warnings -from typing import Optional, Tuple import cv2 import numpy as np @@ -37,10 +36,10 @@ class ECCRegistrationTask(EOTask): def __init__( self, - registration_feature: Tuple[FeatureType, str], - reference_feature: Tuple[FeatureType, str], + registration_feature: tuple[FeatureType, str], + reference_feature: tuple[FeatureType, str], channel: int, - valid_mask_feature: Optional[Tuple[FeatureType, str]] = None, + valid_mask_feature: tuple[FeatureType, str] | None = None, apply_to_features: FeaturesSpecification = ..., interpolation_mode: int = cv2.INTER_LINEAR, warp_mode: int = cv2.MOTION_TRANSLATION, @@ -98,7 +97,7 @@ def register( self, src: np.ndarray, trg: np.ndarray, - valid_mask: Optional[np.ndarray] = None, + valid_mask: np.ndarray | None = None, warp_mode: int = cv2.MOTION_TRANSLATION, ) -> np.ndarray: """Method that estimates the transformation between source and target image""" @@ -160,7 +159,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: new_eopatch[FeatureType.META_INFO, "warp_matrices"] = warp_matrices return new_eopatch - def warp(self, img: np.ndarray, warp_matrix: np.ndarray, shape: Tuple[int, int], flags: int) -> np.ndarray: + def warp(self, img: np.ndarray, warp_matrix: np.ndarray, shape: tuple[int, int], flags: int) -> np.ndarray: """Transform the target image with the estimated transformation matrix""" if warp_matrix.shape == (3, 3): return cv2.warpPerspective( @@ -216,6 +215,5 @@ def get_gradient(src: np.ndarray) -> np.ndarray: grad_x = cv2.Sobel(src, cv2.CV_32F, 1, 0, ksize=3) grad_y = cv2.Sobel(src, cv2.CV_32F, 0, 1, ksize=3) - # Combine the two gradients - grad = cv2.addWeighted(np.absolute(grad_x), 0.5, np.absolute(grad_y), 0.5, 0) - return grad + # Combine and return the two gradients + return cv2.addWeighted(np.absolute(grad_x), 0.5, np.absolute(grad_y), 0.5, 0) diff --git 
a/coregistration/eolearn/coregistration/py.typed b/coregistration/eolearn/coregistration/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/coregistration/setup.py b/coregistration/setup.py index 110ed8d52..474639df9 100644 --- a/coregistration/setup.py +++ b/coregistration/setup.py @@ -7,9 +7,7 @@ def get_long_description(): this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - return long_description + return f.read() def parse_requirements(file): @@ -29,7 +27,7 @@ def get_version(): setup( name="eo-learn-coregistration", - python_requires=">=3.7", + python_requires=">=3.8", version=get_version(), description="A collection of image co-registration EOTasks", long_description=get_long_description(), @@ -59,10 +57,10 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing", diff --git a/docs/environment.yml b/docs/environment.yml index 792036a4f..8808b0b25 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -18,5 +18,5 @@ dependencies: - ./../geometry - ./../io[METEOBLUE] - ./../mask - - ./../ml_tools + - ./../ml_tools[TDIGEST] - ./../visualization diff --git a/docs/source/conf.py b/docs/source/conf.py index 3f530cf2c..00247a5d0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # noqa: UP009 # # Configuration file for the Sphinx documentation builder. # @@ -27,14 +27,14 @@ import eolearn.io import eolearn.mask import eolearn.ml_tools -import eolearn.visualization # noqa +import eolearn.visualization # noqa: F401 from eolearn.core import EOTask # -- Project information ----------------------------------------------------- # General information about the project. project = "eo-learn" -copyright = "2017-, Sinergise" +copyright = "2017-, Sinergise" # noqa: A001 author = "Sinergise EO research team" doc_title = "eo-learn Documentation" @@ -77,7 +77,7 @@ "SingleFeatureSpec": "eolearn.core.types.SingleFeatureSpec", } -# Both the class’ and the __init__ method’s docstring are concatenated and inserted. +# Both the class' and the __init__ method's docstring are concatenated and inserted. autoclass_content = "both" # Content is in the same order as in module @@ -230,7 +230,7 @@ # When Sphinx documents class signature it prioritizes __new__ method over __init__ method. The following hack puts # EOTask.__new__ method to the blacklist so that __init__ method signature will be taken instead. This seems the # cleanest way even though a private object is accessed. 
-sphinx.ext.autodoc._CLASS_NEW_BLACKLIST.append("{0.__module__}.{0.__qualname__}".format(EOTask.__new__)) +sphinx.ext.autodoc._CLASS_NEW_BLACKLIST.append("{0.__module__}.{0.__qualname__}".format(EOTask.__new__)) # noqa: SLF001 EXAMPLES_FOLDER = "./examples" @@ -242,7 +242,7 @@ def copy_documentation_examples(source_folder, target_folder): files_to_include = ["core/images/eopatch.png"] for rst_file in ["examples.rst", "index.rst"]: - with open(rst_file, "r") as fp: + with open(rst_file) as fp: content = fp.read() for line in content.split("\n"): @@ -271,7 +271,7 @@ def process_readme(): """Function which will process README.md file and divide it into INTRO.md and INSTALL.md, which will be used in documentation """ - with open("../../README.md", "r") as file: + with open("../../README.md") as file: readme = file.read() readme = readme.replace("# eo-learn", "# Introduction").replace("docs/source/", "") @@ -289,13 +289,7 @@ def process_readme(): chapters = ["\n".join(chapter) for chapter in chapters] - intro = "\n".join( - [ - chapter - for chapter in chapters - if not (chapter.startswith("## Install") or chapter.startswith("## Documentation")) - ] - ) + intro = "\n".join([chapter for chapter in chapters if not (chapter.startswith(("## Install", "## Documentation")))]) install = "\n".join([chapter for chapter in chapters if chapter.startswith("## Install")]) intro = intro.replace("./CONTRIBUTING.md", "contribute.html") diff --git a/examples/io/SentinelHubIO.ipynb b/examples/io/SentinelHubIO.ipynb index 53c6bfb1c..1a1466d47 100644 --- a/examples/io/SentinelHubIO.ipynb +++ b/examples/io/SentinelHubIO.ipynb @@ -711,7 +711,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.8.10" }, "toc": { "base_numbering": 1, diff --git a/examples/land-cover-map/SI_LULC_pipeline.ipynb b/examples/land-cover-map/SI_LULC_pipeline.ipynb index 4adc0664d..c34ed3f5e 100644 --- a/examples/land-cover-map/SI_LULC_pipeline.ipynb +++ b/examples/land-cover-map/SI_LULC_pipeline.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -103,6 +104,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -121,6 +123,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -181,6 +184,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -213,17 +217,17 @@ "ID = 616\n", "\n", "# Obtain surrounding 5x5 patches\n", - "patchIDs = []\n", - "for idx, (bbox, info) in enumerate(zip(bbox_list, info_list)):\n", + "patch_ids = []\n", + "for idx, info in enumerate(info_list):\n", " if abs(info[\"index_x\"] - info_list[ID][\"index_x\"]) <= 2 and abs(info[\"index_y\"] - info_list[ID][\"index_y\"]) <= 2:\n", - " patchIDs.append(idx)\n", + " patch_ids.append(idx)\n", "\n", "# Check if final size is 5x5\n", - "if len(patchIDs) != 5 * 5:\n", + "if len(patch_ids) != 5 * 5:\n", " print(\"Warning!
Use a different central patch ID, this one is on the border.\")\n", "\n", "# Change the order of the patches (useful for plotting)\n", - "patchIDs = np.transpose(np.fliplr(np.array(patchIDs).reshape(5, 5))).ravel()\n", + "patch_ids = np.transpose(np.fliplr(np.array(patch_ids).reshape(5, 5))).ravel()\n", "\n", "# Save to shapefile\n", "shapefile_name = \"grid_slovenia_500x500.gpkg\"\n", @@ -231,6 +235,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -267,12 +272,13 @@ " ax.text(geo.centroid.x, geo.centroid.y, info[\"index\"], ha=\"center\", va=\"center\")\n", "\n", "# Mark bboxes of selected area\n", - "bbox_gdf[bbox_gdf.index.isin(patchIDs)].plot(ax=ax, facecolor=\"g\", edgecolor=\"r\", alpha=0.5)\n", + "bbox_gdf[bbox_gdf.index.isin(patch_ids)].plot(ax=ax, facecolor=\"g\", edgecolor=\"r\", alpha=0.5)\n", "\n", "plt.axis(\"off\");" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -297,6 +303,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -338,6 +345,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -395,6 +403,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -426,6 +435,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -484,6 +494,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -535,6 +546,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -554,6 +566,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -567,7 +580,123 @@ "outputs": [ { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\nSentinelHubInputTask\n\nSentinelHubInputTask\n\n\n\nNormalizedDifferenceIndexTask\n\nNormalizedDifferenceIndexTask\n\n\n\nSentinelHubInputTask->NormalizedDifferenceIndexTask\n\n\n\n\n\nNormalizedDifferenceIndexTask_1\n\nNormalizedDifferenceIndexTask_1\n\n\n\nNormalizedDifferenceIndexTask->NormalizedDifferenceIndexTask_1\n\n\n\n\n\nNormalizedDifferenceIndexTask_2\n\nNormalizedDifferenceIndexTask_2\n\n\n\nNormalizedDifferenceIndexTask_1->NormalizedDifferenceIndexTask_2\n\n\n\n\n\nSentinelHubValidDataTask\n\nSentinelHubValidDataTask\n\n\n\nNormalizedDifferenceIndexTask_2->SentinelHubValidDataTask\n\n\n\n\n\nAddValidCountTask\n\nAddValidCountTask\n\n\n\nSentinelHubValidDataTask->AddValidCountTask\n\n\n\n\n\nVectorImportTask\n\nVectorImportTask\n\n\n\nAddValidCountTask->VectorImportTask\n\n\n\n\n\nVectorToRasterTask\n\nVectorToRasterTask\n\n\n\nVectorImportTask->VectorToRasterTask\n\n\n\n\n\nSaveTask\n\nSaveTask\n\n\n\nVectorToRasterTask->SaveTask\n\n\n\n\n\n", + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "SentinelHubInputTask\n", + "\n", + "SentinelHubInputTask\n", + "\n", + "\n", + "\n", + "NormalizedDifferenceIndexTask\n", + "\n", + "NormalizedDifferenceIndexTask\n", + "\n", + "\n", + "\n", + "SentinelHubInputTask->NormalizedDifferenceIndexTask\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "NormalizedDifferenceIndexTask_1\n", + "\n", + "NormalizedDifferenceIndexTask_1\n", + "\n", + "\n", + "\n", + "NormalizedDifferenceIndexTask->NormalizedDifferenceIndexTask_1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "NormalizedDifferenceIndexTask_2\n", + "\n", + "NormalizedDifferenceIndexTask_2\n", + "\n", + "\n", + "\n", + "NormalizedDifferenceIndexTask_1->NormalizedDifferenceIndexTask_2\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "SentinelHubValidDataTask\n", + "\n", + "SentinelHubValidDataTask\n", + "\n", + "\n", + "\n", + "NormalizedDifferenceIndexTask_2->SentinelHubValidDataTask\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "AddValidCountTask\n", + "\n", + "AddValidCountTask\n", + "\n", + "\n", + "\n", + "SentinelHubValidDataTask->AddValidCountTask\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "VectorImportTask\n", + "\n", + "VectorImportTask\n", + "\n", + "\n", + "\n", + "AddValidCountTask->VectorImportTask\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "VectorToRasterTask\n", + "\n", + "VectorToRasterTask\n", + "\n", + "\n", + "\n", + "VectorImportTask->VectorToRasterTask\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "SaveTask\n", + "\n", + "SaveTask\n", + "\n", + "\n", + "\n", + "VectorToRasterTask->SaveTask\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], "text/plain": [ "" ] @@ -589,6 +718,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -612,7 +742,7 @@ "input_node = workflow_nodes[0]\n", "save_node = workflow_nodes[-1]\n", "execution_args = []\n", - "for idx, bbox in enumerate(bbox_list[patchIDs]):\n", + "for idx, bbox in enumerate(bbox_list[patch_ids]):\n", " execution_args.append(\n", " {\n", " input_node: {\"bbox\": bbox, \"time_interval\": time_interval},\n", @@ -635,6 +765,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -697,6 +828,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -734,7 +866,7 @@ "\n", "date = datetime.datetime(2019, 7, 1)\n", "\n", - "for i in tqdm(range(len(patchIDs))):\n", + "for i in tqdm(range(len(patch_ids))):\n", " eopatch_path = os.path.join(EOPATCH_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(eopatch_path, lazy_loading=True)\n", "\n", @@ -752,6 +884,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -786,7 +919,7 @@ "source": [ "fig, axs = plt.subplots(nrows=5, ncols=5, figsize=(20, 25))\n", "\n", - "for i in tqdm(range(len(patchIDs))):\n", + "for i in tqdm(range(len(patch_ids))):\n", " eopatch_path = os.path.join(EOPATCH_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(eopatch_path, lazy_loading=True)\n", "\n", @@ -807,6 +940,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -841,7 +975,7 @@ "source": [ "# Calculate min and max counts of valid data per pixel\n", "vmin, vmax = None, None\n", - "for i in range(len(patchIDs)):\n", + "for i in range(len(patch_ids)):\n", " eopatch_path = os.path.join(EOPATCH_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(eopatch_path, lazy_loading=True)\n", " data = eopatch.mask_timeless[\"VALID_COUNT\"].squeeze()\n", @@ -850,7 +984,7 @@ "\n", "fig, axs = plt.subplots(nrows=5, ncols=5, figsize=(20, 25))\n", "\n", - "for i in tqdm(range(len(patchIDs))):\n", + "for i in tqdm(range(len(patch_ids))):\n", " eopatch_path = os.path.join(EOPATCH_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(eopatch_path, lazy_loading=True)\n", " ax = axs[i // 5][i % 5]\n", @@ -868,6 +1002,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -903,7 +1038,6 @@ } ], "source": [ - "eID = 16\n", "eopatch = EOPatch.load(os.path.join(EOPATCH_FOLDER, f\"eopatch_{i}\"), lazy_loading=True)\n", "\n", "ndvi = eopatch.data[\"NDVI\"]\n", @@ -932,6 +1066,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -968,7 +1103,7 @@ "source": [ "fig, axs = plt.subplots(nrows=5, 
ncols=5, figsize=(20, 25))\n", "\n", - "for i in tqdm(range(len(patchIDs))):\n", + "for i in tqdm(range(len(patch_ids))):\n", " eopatch_path = os.path.join(EOPATCH_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(eopatch_path, lazy_loading=True)\n", " ndvi = eopatch.data[\"NDVI\"]\n", @@ -991,6 +1126,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1029,7 +1165,7 @@ "source": [ "fig, axs = plt.subplots(nrows=5, ncols=5, figsize=(20, 25))\n", "\n", - "for i in tqdm(range(len(patchIDs))):\n", + "for i in tqdm(range(len(patch_ids))):\n", " eopatch_path = os.path.join(EOPATCH_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(eopatch_path, lazy_loading=True)\n", " clp = eopatch.data[\"CLP\"].astype(float) / 255\n", @@ -1052,6 +1188,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1077,6 +1214,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1159,6 +1297,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1190,7 +1329,7 @@ "%%time\n", "\n", "execution_args = []\n", - "for idx in range(len(patchIDs)):\n", + "for idx in range(len(patch_ids)):\n", " execution_args.append(\n", " {\n", " workflow_nodes[0]: {\"eopatch_folder\": f\"eopatch_{idx}\"}, # load\n", @@ -1213,6 +1352,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1238,7 +1378,7 @@ "# Load sampled eopatches\n", "sampled_eopatches = []\n", "\n", - "for i in range(len(patchIDs)):\n", + "for i in range(len(patch_ids)):\n", " sample_path = os.path.join(EOPATCH_SAMPLES_FOLDER, f\"eopatch_{i}\")\n", " sampled_eopatches.append(EOPatch.load(sample_path, lazy_loading=True))" ] @@ -1250,10 +1390,10 @@ "outputs": [], "source": [ "# Definition of the train and test patch IDs, take 80 % for train\n", - "test_ID = [0, 8, 16, 19, 20]\n", - "test_eopatches = [sampled_eopatches[i] for i in test_ID]\n", - "train_ID = [i for i in range(len(patchIDs)) if i not in test_ID]\n", - "train_eopatches = [sampled_eopatches[i] for i in train_ID]\n", + "test_ids = [0, 8, 16, 19, 20]\n", + "test_eopatches = [sampled_eopatches[i] for i in test_ids]\n", + "train_ids = [i for i in range(len(patch_ids)) if i not in test_ids]\n", + "train_eopatches = [sampled_eopatches[i] for i in train_ids]\n", "\n", "# Set the features and the labels for train and test sets\n", "features_train = np.concatenate([eopatch.data[\"FEATURES_SAMPLED\"] for eopatch in train_eopatches], axis=1)\n", @@ -1294,6 +1434,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1343,6 +1484,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1350,6 +1492,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1382,6 +1525,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1436,6 +1580,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1448,7 +1593,6 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the plotting function\n", "def plot_confusion_matrix(\n", " confusion_matrix,\n", " classes,\n", @@ -1457,7 +1601,6 @@ " cmap=plt.cm.Blues,\n", " ylabel=\"True label\",\n", " xlabel=\"Predicted label\",\n", - " filename=None,\n", "):\n", " \"\"\"\n", " This function prints and plots the confusion matrix.\n", @@ -1539,6 +1682,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ 
-1577,6 +1721,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1598,7 +1743,7 @@ "\n", "fpr, tpr, roc_auc = {}, {}, {}\n", "\n", - "for idx, lbl in enumerate(class_labels):\n", + "for idx, _ in enumerate(class_labels):\n", " fpr[idx], tpr[idx], _ = metrics.roc_curve(labels_binarized[:, idx], scores_test[:, idx])\n", " roc_auc[idx] = metrics.auc(fpr[idx], tpr[idx])" ] @@ -1649,6 +1794,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1729,7 +1875,7 @@ "\n", "time_id = np.where(feature_importances == np.max(feature_importances))[0][0]\n", "\n", - "for i in tqdm(range(len(patchIDs))):\n", + "for i in tqdm(range(len(patch_ids))):\n", " sample_path = os.path.join(EOPATCH_SAMPLES_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(sample_path, lazy_loading=True)\n", " ax = axs[i // 5][i % 5]\n", @@ -1743,6 +1889,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1754,6 +1901,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1799,6 +1947,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1830,6 +1979,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1854,7 +2004,7 @@ "source": [ "# Create a list of execution arguments for each patch\n", "execution_args = []\n", - "for i in range(len(patchIDs)):\n", + "for i in range(len(patch_ids)):\n", " execution_args.append(\n", " {\n", " workflow_nodes[0]: {\"eopatch_folder\": f\"eopatch_{i}\"},\n", @@ -1899,6 +2049,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1933,7 +2084,7 @@ "source": [ "fig, axs = plt.subplots(nrows=5, ncols=5, figsize=(20, 25))\n", "\n", - "for i in tqdm(range(len(patchIDs))):\n", + "for i in tqdm(range(len(patch_ids))):\n", " eopatch_path = os.path.join(EOPATCH_SAMPLES_FOLDER, f\"eopatch_{i}\")\n", " eopatch = EOPatch.load(eopatch_path, lazy_loading=True)\n", " ax = axs[i // 5][i % 5]\n", @@ -1953,6 +2104,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1985,7 +2137,7 @@ "# Draw the Reference map\n", "fig = plt.figure(figsize=(20, 20))\n", "\n", - "idx = np.random.choice(range(len(patchIDs)))\n", + "idx = np.random.choice(range(len(patch_ids)))\n", "inspect_size = 100\n", "\n", "eopatch = EOPatch.load(os.path.join(EOPATCH_SAMPLES_FOLDER, f\"eopatch_{idx}\"), lazy_loading=True)\n", diff --git a/examples/mask/ValidDataMask.ipynb b/examples/mask/ValidDataMask.ipynb index 4a7877833..6f800c2f8 100644 --- a/examples/mask/ValidDataMask.ipynb +++ b/examples/mask/ValidDataMask.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -8,6 +9,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -34,6 +36,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -51,6 +54,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -58,6 +62,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -99,6 +104,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -117,6 +123,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -168,7 +175,6 @@ "metadata": {}, "outputs": [], "source": [ - "# helper function\n", "def plot_timestamp(eop, timestamp_idx):\n", " fig, ax = 
plt.subplots(nrows=2, ncols=2, figsize=(20, 12))\n", " ax[0][0].imshow(np.clip(2.5 * eop.data[\"trueColorBands\"][timestamp_idx][..., [2, 1, 0]], 0, 1))\n", @@ -206,6 +212,7 @@ "# In this case part of the image is outside 'orbit', and part is covered with (small) clouds.\n", "# `VALID_DATA` is constructed from both.\n", "\n", + "\n", "plot_timestamp(eopatch, 5)" ] }, diff --git a/examples/water-monitor/WaterMonitorWorkflow.ipynb b/examples/water-monitor/WaterMonitorWorkflow.ipynb index 35dd40b26..34a22e78d 100644 --- a/examples/water-monitor/WaterMonitorWorkflow.ipynb +++ b/examples/water-monitor/WaterMonitorWorkflow.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -24,35 +25,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Imports\n", - "\n", - "### eo-learn imports" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from eolearn.core import EOTask, EOWorkflow, FeatureType, OutputTask, linearly_connect_tasks\n", - "\n", - "# filtering of scenes\n", - "from eolearn.features import NormalizedDifferenceIndexTask, SimpleFilterTask\n", - "\n", - "# burning the vectorised polygon to raster\n", - "from eolearn.geometry import VectorToRasterTask\n", - "from eolearn.io import SentinelHubInputTask\n", - "\n", - "# We'll use Sentinel-2 imagery (Level-1C) provided through Sentinel Hub\n", - "# If you don't know what `Level 1C` means, don't worry. It doesn't matter." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Other imports " + "## Imports" ] }, { @@ -79,10 +52,21 @@ "from skimage.filters import threshold_otsu\n", "\n", "# sentinelhub-py package\n", - "from sentinelhub import CRS, BBox, DataCollection" + "from sentinelhub import CRS, BBox, DataCollection\n", + "\n", + "# eo-learn core building blocks\n", + "from eolearn.core import EOTask, EOWorkflow, FeatureType, OutputTask, linearly_connect_tasks\n", + "\n", + "# filtering of scenes\n", + "from eolearn.features import NormalizedDifferenceIndexTask, SimpleFilterTask\n", + "\n", + "# burning the vectorised polygon to raster\n", + "from eolearn.geometry import VectorToRasterTask\n", + "from eolearn.io import SentinelHubInputTask" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -103,6 +87,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -118,7 +103,9 @@ "outputs": [ { "data": { - "image/svg+xml": "", + "image/svg+xml": [ + "" + ], "text/plain": [ "" ] @@ -152,6 +139,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -159,6 +147,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -185,6 +174,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -240,6 +230,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -247,6 +238,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -271,6 +263,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -302,6 +295,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -331,6 +325,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -375,6 +370,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -405,6 +401,7 @@ ] }, { + "attachments": {}, "cell_type": 
"markdown", "metadata": {}, "source": [ @@ -441,6 +438,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -496,6 +494,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ diff --git a/features/MANIFEST.in b/features/MANIFEST.in index 1b163e951..9f43ca148 100644 --- a/features/MANIFEST.in +++ b/features/MANIFEST.in @@ -1,4 +1,5 @@ include requirements*.txt include LICENSE include README.md +include eolearn/features/py.typed exclude eolearn/tests/* diff --git a/features/eolearn/features/__init__.py b/features/eolearn/features/__init__.py index 235dc3a01..d9ff3335b 100644 --- a/features/eolearn/features/__init__.py +++ b/features/eolearn/features/__init__.py @@ -38,4 +38,4 @@ AddSpatioTemporalFeaturesTask, ) -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/features/eolearn/features/bands_extraction.py b/features/eolearn/features/bands_extraction.py index fb9ea8b79..d4071797a 100644 --- a/features/eolearn/features/bands_extraction.py +++ b/features/eolearn/features/bands_extraction.py @@ -8,8 +8,6 @@ """ from __future__ import annotations -from typing import List, Optional, Tuple - import numpy as np from eolearn.core import MapFeatureTask @@ -25,7 +23,7 @@ class EuclideanNormTask(MapFeatureTask): """ def __init__( - self, input_feature: SingleFeatureSpec, output_feature: SingleFeatureSpec, bands: Optional[List[int]] = None + self, input_feature: SingleFeatureSpec, output_feature: SingleFeatureSpec, bands: list[int] | None = None ): """ :param input_feature: A source feature from which to take the subset of bands. @@ -57,7 +55,7 @@ def __init__( self, input_feature: SingleFeatureSpec, output_feature: SingleFeatureSpec, - bands: Tuple[int, int], + bands: tuple[int, int], acorvi_constant: float = 0, undefined_value: float = np.nan, ): diff --git a/features/eolearn/features/blob.py b/features/eolearn/features/blob.py index 01fa0c2b8..225744231 100644 --- a/features/eolearn/features/blob.py +++ b/features/eolearn/features/blob.py @@ -8,13 +8,17 @@ """ from __future__ import annotations +import itertools as it from math import sqrt from typing import Any, Callable import numpy as np import skimage.feature -from eolearn.core import EOPatch, EOTask +from sentinelhub.exceptions import deprecated_class + +from eolearn.core import EOPatch, EOTask, FeatureType +from eolearn.core.exceptions import EODeprecationWarning from eolearn.core.types import SingleFeatureSpec @@ -31,47 +35,33 @@ class BlobTask(EOTask): The output is a `FeatureType.DATA` where the radius of each blob is stored in his center. ie : If blob[date, i, j, 0] = 5 then a blob of radius 5 is present at the coordinate (i, j) - The task uses skimage.feature.blob_log or skimage.feature.blob_dog or skimage.feature.blob_doh to extract the blobs. + The task uses `skimage.feature.blob_log`, `skimage.feature.blob_dog` or `skimage.feature.blob_doh` for extraction. The input image must be in [-1,1] range. - :param feature: A feature that will be used and a new feature name where data will be saved. If new name is not - specified it will be saved with name '_BLOB' - - Example: (FeatureType.DATA, 'bands') or (FeatureType.DATA, 'bands', 'blob') - - :param blob_object: Name of the blob method to use - :param blob_parameters: List of parameters to be passed to the blob function. Below is a list of such parameters. - :param min_sigma: The minimum standard deviation for Gaussian Kernel. 
Keep this low to detect smaller blobs - :param max_sigma: The maximum standard deviation for Gaussian Kernel. Keep this high to detect larger blobs - :param threshold: The absolute lower bound for scale space maxima. Local maxima smaller than thresh are ignored. - Reduce this to detect blobs with less intensity - :param overlap: A value between 0 and 1. If the area of two blobs overlaps by a fraction greater than threshold, - the smaller blob is eliminated - :param num_sigma: For ‘Log’ and ‘DoH’: The number of intermediate values of standard deviations to consider between - min_sigma and max_sigma - :param log_scale: For ‘Log’ and ‘DoH’: If set intermediate values of standard deviations are interpolated using a - logarithmic scale to the base 10. If not, linear interpolation is used - :param sigma_ratio: For ‘DoG’: The ratio between the standard deviation of Gaussian Kernels used for computing the - Difference of Gaussians + :param feature: A feature that will be used and a new feature name where data will be saved, e.g. + `(FeatureType.DATA, 'bands', 'blob')`. + :param blob_object: Callable that calculates the blob + :param blob_parameters: Parameters to be passed to the blob function. Consult documentation of `blob_object` + for available parameters. """ def __init__(self, feature: SingleFeatureSpec, blob_object: Callable, **blob_parameters: Any): - self.feature_parser = self.get_feature_parser(feature) + self.feature_parser = self.get_feature_parser(feature, allowed_feature_types=[FeatureType.DATA]) self.blob_object = blob_object self.blob_parameters = blob_parameters def _compute_blob(self, data: np.ndarray) -> np.ndarray: - result = np.zeros(data.shape, dtype=float) - for time in range(data.shape[0]): - for band in range(data.shape[-1]): - image = data[time, :, :, band] - res = np.asarray(self.blob_object(image, **self.blob_parameters)) - x_coord = res[:, 0].astype(int) - y_coord = res[:, 1].astype(int) - radius = res[:, 2] * sqrt(2) - result[time, x_coord, y_coord, band] = radius + result = np.zeros(data.shape, dtype=np.float32) + num_time, _, _, num_bands = data.shape + for time_idx, band_idx in it.product(range(num_time), range(num_bands)): + image = data[time_idx, :, :, band_idx] + blob = self.blob_object(image, **self.blob_parameters) + x_coord = blob[:, 0].astype(int) + y_coord = blob[:, 1].astype(int) + radius = blob[:, 2] * sqrt(2) + result[time_idx, x_coord, y_coord, band_idx] = radius return result def execute(self, eopatch: EOPatch) -> EOPatch: @@ -80,14 +70,13 @@ def execute(self, eopatch: EOPatch) -> EOPatch: :param eopatch: Input eopatch :return: EOPatch instance with new key holding the blob image. 
""" - for feature_type, feature_name, new_feature_name in self.feature_parser.get_renamed_features(eopatch): - eopatch[feature_type][new_feature_name] = self._compute_blob( - eopatch[feature_type][feature_name].astype(np.float64) - ).astype(np.float32) + for ftype, fname, new_fname in self.feature_parser.get_renamed_features(eopatch): + eopatch[ftype, new_fname] = self._compute_blob(eopatch[ftype, fname].astype(np.float64)) return eopatch +@deprecated_class(EODeprecationWarning, "Use `BlobTask` with `blob_object=skimage.feature.blob_dog`.") class DoGBlobTask(BlobTask): """Task to compute blobs with Difference of Gaussian (DoG) method""" @@ -114,6 +103,7 @@ def __init__( ) +@deprecated_class(EODeprecationWarning, "Use `BlobTask` with `blob_object=skimage.feature.blob_doh`.") class DoHBlobTask(BlobTask): """Task to compute blobs with Determinant of the Hessian (DoH) method""" @@ -142,6 +132,7 @@ def __init__( ) +@deprecated_class(EODeprecationWarning, "Use `BlobTask` with `blob_object=skimage.feature.blob_log`.") class LoGBlobTask(BlobTask): """Task to compute blobs with Laplacian of Gaussian (LoG) method""" diff --git a/features/eolearn/features/clustering.py b/features/eolearn/features/clustering.py index ee11b6ce4..d66f6ef51 100644 --- a/features/eolearn/features/clustering.py +++ b/features/eolearn/features/clustering.py @@ -8,12 +8,11 @@ """ from __future__ import annotations -from typing import Callable, List, Optional, Union, cast +from typing import Callable, Literal import numpy as np from sklearn.cluster import AgglomerativeClustering from sklearn.feature_extraction.image import grid_to_graph -from typing_extensions import Literal from eolearn.core import EOPatch, EOTask, FeatureType from eolearn.core.types import FeatureSpec @@ -32,13 +31,13 @@ def __init__( self, features: FeatureSpec, new_feature_name: str, - distance_threshold: Optional[float] = None, - n_clusters: Optional[int] = None, + distance_threshold: float | None = None, + n_clusters: int | None = None, affinity: Literal["euclidean", "l1", "l2", "manhattan", "cosine"] = "cosine", linkage: Literal["ward", "complete", "average", "single"] = "single", remove_small: int = 0, - connectivity: Union[None, np.ndarray, Callable] = None, - mask_name: Optional[str] = None, + connectivity: None | np.ndarray | Callable = None, + mask_name: str | None = None, ): """Class constructor @@ -62,17 +61,13 @@ def __init__( adjacent pixels connected. 
:param mask_name: An optional mask feature used for exclusion of the area from clustering """ - self.features_parser = self.get_feature_parser(features) + self.features_parser = self.get_feature_parser(features, allowed_feature_types=[FeatureType.DATA_TIMELESS]) self.distance_threshold = distance_threshold self.affinity = affinity self.linkage = linkage self.new_feature_name = new_feature_name self.n_clusters = n_clusters - self.compute_full_tree: Union[Literal["auto"], bool] = "auto" - if distance_threshold is not None: - self.compute_full_tree = True - if remove_small < 0: - raise ValueError("remove_small argument should be non-negative") + self.compute_full_tree: Literal["auto"] | bool = "auto" if distance_threshold is None else True self.remove_small = remove_small self.connectivity = connectivity self.mask_name = mask_name @@ -86,19 +81,16 @@ def execute(self, eopatch: EOPatch) -> EOPatch: data = np.concatenate([eopatch[feature] for feature in relevant_features], axis=2) # Reshapes the data, because AgglomerativeClustering method only takes one dimensional arrays of vectors - org_shape = data.shape - data = np.reshape(data, (-1, org_shape[-1])) - org_length = len(data) + height, width, num_channels = data.shape + data = np.reshape(data, (-1, num_channels)) - graph_args = {"n_x": org_shape[0], "n_y": org_shape[1]} - locations = None + graph_args = {"n_x": height, "n_y": width} # All connections to masked pixels are removed if self.mask_name is not None: mask = eopatch.mask_timeless[self.mask_name].squeeze() graph_args["mask"] = mask - locations = [i for i, elem in enumerate(np.ravel(mask)) if elem == 0] - data = np.delete(data, locations, axis=0) + data = data[np.ravel(mask) != 0] # If connectivity is not set, it uses pixel-to-pixel connections if not self.connectivity: @@ -114,28 +106,18 @@ def execute(self, eopatch: EOPatch) -> EOPatch: ) model.fit(data) - trimmed_labels = model.labels_ + result = model.labels_ if self.remove_small > 0: - # Counts how many pixels covers each cluster - labels = np.zeros(model.n_clusters_) - for i in trimmed_labels: - labels[i] += 1 - - # Sets to -1 all pixels corresponding to too small clusters - for i, no_lab in enumerate(labels): - if no_lab < self.remove_small: - trimmed_labels[trimmed_labels == i] = -1 + for label, count in zip(*np.unique(result, return_counts=True)): + if count < self.remove_small: + result[result == label] = -1 # Transforms data back to original shape and setting all masked regions to -1 if self.mask_name is not None: - locations = cast(List[int], locations) # set because mask_name is not None - new_data = [-1] * org_length - for i, val in zip(np.delete(np.arange(org_length), locations), trimmed_labels): - new_data[i] = val - trimmed_labels = new_data - - trimmed_labels = np.reshape(trimmed_labels, org_shape[:-1]) + unmasked_result = np.full(height * width, -1) + unmasked_result[np.ravel(mask) != 0] = result + result = unmasked_result - eopatch[FeatureType.DATA_TIMELESS, self.new_feature_name] = trimmed_labels[..., np.newaxis] + eopatch[FeatureType.DATA_TIMELESS, self.new_feature_name] = np.reshape(result, (height, width, 1)) return eopatch diff --git a/features/eolearn/features/doubly_logistic_approximation.py b/features/eolearn/features/doubly_logistic_approximation.py index 5c2a20fbb..b7b989b00 100644 --- a/features/eolearn/features/doubly_logistic_approximation.py +++ b/features/eolearn/features/doubly_logistic_approximation.py @@ -9,7 +9,6 @@ from __future__ import annotations import itertools as it -from typing import 
List, Optional import numpy as np from scipy.optimize import curve_fit @@ -39,8 +38,8 @@ def __init__( self, feature: SingleFeatureSpec, new_feature: SingleFeatureSpec = (FeatureType.DATA_TIMELESS, "DOUBLY_LOGISTIC_PARAM"), - initial_parameters: Optional[List[float]] = None, - valid_mask: Optional[SingleFeatureSpec] = None, + initial_parameters: list[float] | None = None, + valid_mask: SingleFeatureSpec | None = None, ): self.initial_parameters = initial_parameters self.feature = self.parse_feature(feature) diff --git a/features/eolearn/features/feature_manipulation.py b/features/eolearn/features/feature_manipulation.py index 517f882e5..7e895a88c 100644 --- a/features/eolearn/features/feature_manipulation.py +++ b/features/eolearn/features/feature_manipulation.py @@ -12,11 +12,10 @@ import datetime as dt import logging from functools import partial -from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast +from typing import Any, Callable, Iterable, Literal, cast import numpy as np from geopandas import GeoDataFrame -from typing_extensions import Literal from sentinelhub import bbox_to_dimensions @@ -41,7 +40,7 @@ class SimpleFilterTask(EOTask): def __init__( self, feature: SingleFeatureSpec, - filter_func: Union[Callable[[np.ndarray], bool], Callable[[dt.datetime], bool]], + filter_func: Callable[[np.ndarray], bool] | Callable[[dt.datetime], bool], filter_features: FeaturesSpecification = ..., ): """ @@ -56,12 +55,12 @@ def __init__( self.filter_func = filter_func self.filter_features_parser = self.get_feature_parser(filter_features) - def _get_filtered_indices(self, feature_data: Iterable) -> List[int]: + def _get_filtered_indices(self, feature_data: Iterable) -> list[int]: """Get valid time indices from either a numpy array or a list of timestamps.""" return [idx for idx, img in enumerate(feature_data) if self.filter_func(img)] @staticmethod - def _filter_vector_feature(gdf: GeoDataFrame, good_idxs: List[int], timestamps: List[dt.datetime]) -> GeoDataFrame: + def _filter_vector_feature(gdf: GeoDataFrame, good_idxs: list[int], timestamps: list[dt.datetime]) -> GeoDataFrame: """Filters rows that don't match with the timestamps that will be kept.""" timestamps_to_keep = {timestamps[idx] for idx in good_idxs} return gdf[gdf[TIMESTAMP_COLUMN].isin(timestamps_to_keep)] @@ -220,10 +219,10 @@ class LinearFunctionTask(MapFeatureTask): def __init__( self, input_features: FeaturesSpecification, - output_features: Optional[FeaturesSpecification] = None, + output_features: FeaturesSpecification | None = None, slope: float = 1, intercept: float = 0, - dtype: Union[str, type, np.dtype, None] = None, + dtype: str | type | np.dtype | None = None, ): """ :param input_features: Feature or features on which the function is used. 
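As a quick aside, a minimal usage sketch of `LinearFunctionTask` with the new-style annotations (the feature names and the scaling constants are illustrative assumptions, not taken from this diff):

import numpy as np

from eolearn.core import FeatureType
from eolearn.features import LinearFunctionTask

# Applies y = slope * x + intercept to every pixel and casts the result to float32;
# here: rescaling raw digital numbers to reflectance with an assumed factor of 1e-4.
rescale_task = LinearFunctionTask(
    input_features=(FeatureType.DATA, "BANDS"),
    output_features=(FeatureType.DATA, "BANDS_RESCALED"),
    slope=1e-4,
    intercept=0.0,
    dtype=np.float32,
)
# eopatch = rescale_task.execute(eopatch)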
@@ -284,7 +283,7 @@ def __init__( ) def execute(self, eopatch: EOPatch) -> EOPatch: - resize_fun_kwargs: Dict[str, Any] + resize_fun_kwargs: dict[str, Any] if self.resize_type == ResizeParam.RESOLUTION: if not eopatch.bbox: raise ValueError("Resolution-specified resizing can only be done on EOPatches with a defined BBox.") diff --git a/features/eolearn/features/haralick.py b/features/eolearn/features/haralick.py index 7c419dcbc..a5c60c792 100644 --- a/features/eolearn/features/haralick.py +++ b/features/eolearn/features/haralick.py @@ -37,7 +37,7 @@ class HaralickTask(EOTask): "sum_entropy", "difference_variance", "difference_entropy", - } + }.union(AVAILABLE_TEXTURES_SKIMAGE) def __init__( self, @@ -50,10 +50,8 @@ def __init__( stride: int = 1, ): """ - :param feature: A feature that will be used and a new feature name where data will be saved. If new name is not - specified it will be saved with name '_HARALICK'. - - Example: `(FeatureType.DATA, 'bands')` or `(FeatureType.DATA, 'bands', 'haralick_values')` + :param feature: A feature that will be used and a new feature name where data will be saved, e.g. + `(FeatureType.DATA, 'bands', 'haralick_values')` :param texture_feature: Type of Haralick textural feature to be calculated :param distance: Distance between pairs of pixels used for GLCM :param angle: Angle between pairs of pixels used for GLCM in radians, e.g. angle=np.pi/4 @@ -64,11 +62,8 @@ def __init__( self.feature_parser = self.get_feature_parser(feature) self.texture_feature = texture_feature - if self.texture_feature not in self.AVAILABLE_TEXTURES.union(self.AVAILABLE_TEXTURES_SKIMAGE): - raise ValueError( - "Haralick texture feature must be one of these: " - f"{self.AVAILABLE_TEXTURES.union(self.AVAILABLE_TEXTURES_SKIMAGE)}" - ) + if self.texture_feature not in self.AVAILABLE_TEXTURES: + raise ValueError(f"Haralick texture feature must be one of these: {self.AVAILABLE_TEXTURES}") self.distance = distance self.angle = angle @@ -79,99 +74,81 @@ def __init__( raise ValueError("Window size must be an odd number") self.stride = stride - if self.stride >= self.window_size + 1: + if self.stride > self.window_size: warnings.warn( - "Haralick stride is superior to the window size; some pixel values will be ignored", EOUserWarning + "Haralick stride is larger than window size; some pixel values will be ignored", EOUserWarning ) - def _custom_texture(self, glcm: np.ndarray) -> np.ndarray: - # Sum of square: Variance + def _custom_texture(self, glcm: np.ndarray) -> np.ndarray: # pylint: disable=too-many-return-statements if self.texture_feature == "sum_of_square_variance": i_raw = np.empty_like(glcm) i_raw[...] = np.arange(glcm.shape[0]) i_raw = np.transpose(i_raw) i_minus_mean = (i_raw - glcm.mean()) ** 2 - res = np.apply_over_axes(np.sum, i_minus_mean * glcm, axes=(0, 1))[0][0] - elif self.texture_feature == "inverse_difference_moment": - # np.meshgrid + return np.apply_over_axes(np.sum, i_minus_mean * glcm, axes=(0, 1))[0][0] + if self.texture_feature == "inverse_difference_moment": j_cols = np.empty_like(glcm) j_cols[...] 
= np.arange(glcm.shape[1]) i_minus_j = ((j_cols - np.transpose(j_cols)) ** 2) + 1 - res = np.apply_over_axes(np.sum, glcm / i_minus_j, axes=(0, 1))[0][0] - elif self.texture_feature == "sum_average": - # Slow - tuple_array = np.array(list(it.product(list(range(self.levels)), list(range(self.levels)))), dtype=(int, 2)) - index = np.array([list(map(tuple, tuple_array[tuple_array.sum(axis=1) == x])) for x in range(self.levels)]) - p_x_y = np.array([glcm[tuple(np.moveaxis(index[y], -1, 0))].sum() for y in range(len(index))]) - res = np.array(p_x_y * np.array(range(len(index)))).sum() - elif self.texture_feature == "sum_variance": - # Slow - tuple_array = np.array(list(it.product(list(range(self.levels)), list(range(self.levels)))), dtype=(int, 2)) - index = np.array([list(map(tuple, tuple_array[tuple_array.sum(axis=1) == x])) for x in range(self.levels)]) - p_x_y = np.array([glcm[tuple(np.moveaxis(index[y], -1, 0))].sum() for y in range(len(index))]) - sum_average = np.array(p_x_y * np.array(range(len(index)))).sum() - res = ((np.array(range(len(index))) - sum_average) ** 2).sum() - elif self.texture_feature == "sum_entropy": - # Slow - tuple_array = np.array(list(it.product(list(range(self.levels)), list(range(self.levels)))), dtype=(int, 2)) - index = np.array([list(map(tuple, tuple_array[tuple_array.sum(axis=1) == x])) for x in range(self.levels)]) - p_x_y = np.array([glcm[tuple(np.moveaxis(index[y], -1, 0))].sum() for y in range(len(index))]) - res = (p_x_y * np.log(p_x_y + np.finfo(float).eps)).sum() * -1.0 - elif self.texture_feature == "difference_variance": - # Slow - tuple_array = np.array( - list(it.product(list(range(self.levels)), list(np.asarray(range(self.levels)) * -1))), dtype=(int, 2) - ) - index = np.array( - [list(map(tuple, tuple_array[np.abs(tuple_array.sum(axis=1)) == x])) for x in range(self.levels)] - ) - p_x_y = np.array([glcm[tuple(np.moveaxis(index[y], -1, 0))].sum() for y in range(len(index))]) - sum_average = np.array(p_x_y * np.array(range(len(index)))).sum() - res = ((np.array(range(len(index))) - sum_average) ** 2).sum() - else: - # self.texture_feature == 'difference_entropy': - # Slow - tuple_array = np.array( - list(it.product(list(range(self.levels)), list(np.asarray(range(self.levels)) * -1))), dtype=(int, 2) - ) - index = np.array( - [list(map(tuple, tuple_array[np.abs(tuple_array.sum(axis=1)) == x])) for x in range(self.levels)] - ) - p_x_y = np.array([glcm[tuple(np.moveaxis(index[y], -1, 0))].sum() for y in range(len(index))]) - res = (p_x_y * np.log(p_x_y + np.finfo(float).eps)).sum() * -1.0 - return res + return np.apply_over_axes(np.sum, glcm / i_minus_j, axes=(0, 1))[0][0] + if self.texture_feature == "sum_average": + p_x_y = self._get_pxy(glcm) + return np.array(p_x_y * np.arange(len(p_x_y))).sum() + if self.texture_feature == "sum_variance": + p_x_y = self._get_pxy(glcm) + sum_average = np.array(p_x_y * np.arange(len(p_x_y))).sum() + return ((np.arange(len(p_x_y)) - sum_average) ** 2).sum() + if self.texture_feature == "sum_entropy": + p_x_y = self._get_pxy(glcm) + return (p_x_y * np.log(p_x_y + np.finfo(float).eps)).sum() * -1.0 + if self.texture_feature == "difference_variance": + p_x_y = self._get_pxy_for_diff(glcm) + sum_average = np.array(p_x_y * np.arange(len(p_x_y))).sum() + return ((np.arange(len(p_x_y)) - sum_average) ** 2).sum() + + # self.texture_feature == 'difference_entropy': + p_x_y = self._get_pxy_for_diff(glcm) + return (p_x_y * np.log(p_x_y + np.finfo(float).eps)).sum() * -1.0 + + def _get_pxy(self, glcm: np.ndarray) -> 
np.ndarray: + tuple_array = np.array(list(it.product(range(self.levels), range(self.levels)))) + index = [tuple_array[tuple_array.sum(axis=1) == x] for x in range(self.levels)] + return np.array([glcm[tuple(np.moveaxis(idx, -1, 0))].sum() for idx in index]) + + def _get_pxy_for_diff(self, glcm: np.ndarray) -> np.ndarray: + tuple_array = np.array(list(it.product(range(self.levels), np.asarray(range(self.levels)) * -1))) + index = [tuple_array[np.abs(tuple_array.sum(axis=1)) == x] for x in range(self.levels)] + return np.array([glcm[tuple(np.moveaxis(idx, -1, 0))].sum() for idx in index]) def _calculate_haralick(self, data: np.ndarray) -> np.ndarray: result = np.empty(data.shape, dtype=float) + num_times, _, _, num_bands = data.shape # For each date and each band - for time in range(data.shape[0]): - for band in range(data.shape[3]): - image = data[time, :, :, band] - image_min, image_max = np.min(image), np.max(image) - coef = (image_max - image_min) / self.levels - digitized_image = np.digitize(image, np.array([image_min + k * coef for k in range(self.levels - 1)])) - - # Padding the image to handle borders - pad = self.window_size // 2 - digitized_image = np.pad(digitized_image, ((pad, pad), (pad, pad)), "edge") - # Sliding window - for i in range(0, image.shape[0], self.stride): - for j in range(0, image.shape[1], self.stride): - window = digitized_image[i : i + self.window_size, j : j + self.window_size] - glcm = skimage.feature.graycomatrix( - window, [self.distance], [self.angle], levels=self.levels, normed=True, symmetric=True - ) - - if self.texture_feature in self.AVAILABLE_TEXTURES_SKIMAGE: - res = skimage.feature.graycoprops(glcm, self.texture_feature)[0][0] - else: - res = self._custom_texture(glcm[:, :, 0, 0]) - - result[time, i, j, band] = res + for time, band in it.product(range(num_times), range(num_bands)): + image = data[time, :, :, band] + image_min, image_max = np.min(image), np.max(image) + coef = (image_max - image_min) / self.levels + digitized_image = np.digitize(image, np.array([image_min + k * coef for k in range(self.levels - 1)])) + + # Padding the image to handle borders + pad = self.window_size // 2 + digitized_image = np.pad(digitized_image, ((pad, pad), (pad, pad)), "edge") + # Sliding window + for i, j in it.product(range(0, image.shape[0], self.stride), range(0, image.shape[1], self.stride)): + window = digitized_image[i : i + self.window_size, j : j + self.window_size] + glcm = skimage.feature.graycomatrix( + window, [self.distance], [self.angle], levels=self.levels, normed=True, symmetric=True + ) + + if self.texture_feature in self.AVAILABLE_TEXTURES_SKIMAGE: + result[time, i, j, band] = skimage.feature.graycoprops(glcm, self.texture_feature)[0][0] + else: + result[time, i, j, band] = self._custom_texture(glcm[:, :, 0, 0]) + return result def execute(self, eopatch: EOPatch) -> EOPatch: - for feature_type, feature_name, new_feature_name in self.feature_parser.get_renamed_features(eopatch): - eopatch[feature_type, new_feature_name] = self._calculate_haralick(eopatch[feature_type, feature_name]) + for ftype, fname, new_fname in self.feature_parser.get_renamed_features(eopatch): + eopatch[ftype, new_fname] = self._calculate_haralick(eopatch[ftype, fname]) return eopatch diff --git a/features/eolearn/features/hog.py b/features/eolearn/features/hog.py index ff32264c8..88c14dd9d 100644 --- a/features/eolearn/features/hog.py +++ b/features/eolearn/features/hog.py @@ -8,7 +8,7 @@ """ from __future__ import annotations -from typing import Optional, Tuple +import 
itertools as it import numpy as np import skimage.feature @@ -31,24 +31,22 @@ def __init__( self, feature: SingleFeatureSpec, orientations: int = 9, - pixels_per_cell: Tuple[int, int] = (8, 8), - cells_per_block: Tuple[int, int] = (3, 3), + pixels_per_cell: tuple[int, int] = (8, 8), + cells_per_block: tuple[int, int] = (3, 3), visualize: bool = True, hog_feature_vector: bool = False, block_norm: str = "L2-Hys", - visualize_feature_name: Optional[str] = None, + visualize_feature_name: str | None = None, ): """ - :param feature: A feature that will be used and a new feature name where data will be saved. If new name is not - specified it will be saved with name '_HOG' - - Example: `(FeatureType.DATA, 'bands')` or `(FeatureType.DATA, 'bands', 'hog')` + :param feature: A feature that will be used and a new feature name where data will be saved, e.g. + `(FeatureType.DATA, 'bands', 'hog')`. :param orientations: Number of direction to use for the oriented gradient :param pixels_per_cell: Number of pixels in a cell, provided as a pair of integers. :param cells_per_block: Number of cells in a block, provided as a pair of integers. :param visualize: Produce a visualization for the HOG in an image :param visualize_feature_name: Name of the visualization feature to be added to the eopatch (if empty and - visualize is True, it becomes “new_name”_VIZU) + visualize is True, it becomes “new_name”_VISU) """ self.feature_parser = self.get_feature_parser(feature, allowed_feature_types=[FeatureType.DATA]) @@ -60,23 +58,23 @@ def __init__( self.hog_feature_vector = hog_feature_vector self.visualize_name = visualize_feature_name - def _compute_hog(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - results_im = np.empty( + def _compute_hog(self, data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: # pylint: disable=too-many-locals + num_times, height, width, num_bands = data.shape + is_multichannel = num_bands != 1 + hog_result = np.empty( ( - data.shape[0], - (int(data.shape[1] // self.pixels_per_cell[0]) - self.cells_per_block[0] + 1) * self.cells_per_block[0], - (int(data.shape[2] // self.pixels_per_cell[1]) - self.cells_per_block[1] + 1) * self.cells_per_block[1], + num_times, + ((height // self.pixels_per_cell[0]) - self.cells_per_block[0] + 1) * self.cells_per_block[0], + ((width // self.pixels_per_cell[1]) - self.cells_per_block[1] + 1) * self.cells_per_block[1], self.n_orientations, ), - dtype=float, + dtype=np.float32, ) if self.visualize: - im_visu = np.empty(data.shape[0:3] + (1,)) - for time in range(data.shape[0]): - is_multichannel = data.shape[-1] != 1 - image = data[time] if is_multichannel else data[time, :, :, 0] - res, image = skimage.feature.hog( - image, + hog_visualization = np.empty((num_times, height, width, 1)) + for time in range(num_times): + output, image = skimage.feature.hog( + data[time] if is_multichannel else data[time, :, :, 0], orientations=self.n_orientations, pixels_per_cell=self.pixels_per_cell, visualize=self.visualize, @@ -85,29 +83,31 @@ def _compute_hog(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: feature_vector=self.hog_feature_vector, channel_axis=-1 if is_multichannel else None, ) + + block_rows, block_cols, cell_rows, cell_cols, angles = output.shape + for block_row, block_col in it.product(range(block_rows), range(block_cols)): + for cell_row, cell_col in it.product(range(cell_rows), range(cell_cols)): + row = block_row * self.cells_per_block[0] + cell_row + col = block_col * self.cells_per_block[1] + cell_col + for angle in range(angles): + 
hog_result[time, row, col, angle] = output[block_row, block_col, cell_row, cell_col, angle] + if self.visualize: - im_visu[time, :, :, 0] = image - for block_row in range(res.shape[0]): - for block_col in range(res.shape[1]): - for cell_row in range(res.shape[2]): - for cell_col in range(res.shape[3]): - row = block_row * self.cells_per_block[0] + cell_row - col = block_col * self.cells_per_block[1] + cell_col - for angle in range(res.shape[4]): - results_im[time, row, col, angle] = res[block_row, block_col, cell_row, cell_col, angle] - return results_im, im_visu + hog_visualization[time, :, :, 0] = image + + return hog_result, hog_visualization def execute(self, eopatch: EOPatch) -> EOPatch: """Execute computation of HoG features on input eopatch :param eopatch: Input eopatch - :return: EOPatch instance with new keys holding the HoG features and HoG image for visualisation. + :return: EOPatch instance with new keys holding the HoG features and HoG image for visualization. """ for feature_type, feature_name, new_feature_name in self.feature_parser.get_renamed_features(eopatch): - result_im, im_visu = self._compute_hog(eopatch[feature_type, feature_name]) - eopatch[feature_type, new_feature_name] = result_im + hog_result, hog_visualization = self._compute_hog(eopatch[feature_type, feature_name]) + eopatch[feature_type, new_feature_name] = hog_result if self.visualize: visualize_name = self.visualize_name or f"{new_feature_name}_VISU" - eopatch[feature_type, visualize_name] = im_visu + eopatch[feature_type, visualize_name] = hog_visualization return eopatch diff --git a/features/eolearn/features/interpolation.py b/features/eolearn/features/interpolation.py index 5eb6e84f6..b8eb8bee6 100644 --- a/features/eolearn/features/interpolation.py +++ b/features/eolearn/features/interpolation.py @@ -11,8 +11,9 @@ import datetime as dt import inspect import warnings +from collections import defaultdict from functools import partial -from typing import Any, Callable, List, Optional, Set, Tuple, Union, cast +from typing import Any, Callable, Iterable, List, Tuple, Union, cast import dateutil import numpy as np @@ -120,9 +121,9 @@ def __init__( interpolation_object: Callable, *, resample_range: ResampleRangeType = None, - result_interval: Optional[Tuple[float, float]] = None, - mask_feature: Optional[SingleFeatureSpec] = None, - copy_features: Optional[FeaturesSpecification] = None, + result_interval: tuple[float, float] | None = None, + mask_feature: SingleFeatureSpec | None = None, + copy_features: FeaturesSpecification | None = None, unknown_value: float = np.nan, filling_factor: int = 10, scale_time: int = 3600, @@ -137,18 +138,16 @@ def __init__( self.resample_range = resample_range self.result_interval = result_interval - self.mask_feature_parser = ( - None - if mask_feature is None - else self.get_feature_parser( + self.mask_feature_parser = None + if mask_feature is not None: + self.mask_feature_parser = self.get_feature_parser( mask_feature, allowed_feature_types={FeatureType.MASK, FeatureType.MASK_TIMELESS, FeatureType.LABEL} ) - ) if resample_range is None and copy_features is not None: self.copy_features = None warnings.warn( - 'Argument "copy_features" will be ignored if "resample_range" is None. Nothing to copy.', EOUserWarning + "If `resample_range` is None the task is done in-place. 
Ignoring `copy_features`.", EOUserWarning ) else: self.copy_features_parser = None if copy_features is None else self.get_feature_parser(copy_features) @@ -159,26 +158,15 @@ def __init__( self.filling_factor = filling_factor self.interpolate_pixel_wise = interpolate_pixel_wise - self._resampled_times = None - @staticmethod def _mask_feature_data(feature_data: np.ndarray, mask: np.ndarray, mask_type: FeatureType) -> np.ndarray: - """Masks values of data feature with a given mask of given mask type. The masking is done by assigning - `numpy.nan` value. - - :param feature_data: Data array which will be masked - :param mask: Mask array - :return: Masked data array - """ - - if mask_type.is_spatial() and feature_data.shape[1:3] != mask.shape[-3:-1]: - raise ValueError( - f"Spatial dimensions of interpolation and mask feature do not match: {feature_data.shape} {mask.shape}" - ) + """Masks values of the data feature (in-place) with a given mask by assigning `numpy.nan` value to masked fields.""" - if mask_type.is_temporal() and feature_data.shape[0] != mask.shape[0]: + spatial_dim_wrong = mask_type.is_spatial() and feature_data.shape[1:3] != mask.shape[-3:-1] + temporal_dim_wrong = mask_type.is_temporal() and feature_data.shape[0] != mask.shape[0] + if spatial_dim_wrong or temporal_dim_wrong: raise ValueError( - f"Time dimension of interpolation and mask feature do not match: {feature_data.shape} {mask.shape}" + f"Dimensions of interpolation data {feature_data.shape} and mask {mask.shape} do not match." ) # This allows masking each channel differently but causes some complications while masking with label @@ -202,11 +190,11 @@ def _mask_feature_data(feature_data: np.ndarray, mask: np.ndarray, mask_type: Fe def _get_start_end_nans(data: np.ndarray) -> np.ndarray: """Find NaN values in data that either start or end the time-series - Function to return a binary array of same size as data where `True` values correspond to NaN values present at + Function returns an array of the same size as data where `True` corresponds to NaN values present at beginning or end of time-series. NaNs internal to the time-series are not included in the binary mask. - :param data: Array of observations of size TxNOBS - :return: Binary array of shape TxNOBS. `True` values indicate NaNs present at beginning or end of time-series + :param data: Array of observations of size t x num_obs + :return: Array of shape t x num_obs. `True` values indicate NaNs present at beginning or end of time-series """ # find NaNs that start a time-series start_nan = np.isnan(data) @@ -220,109 +208,67 @@ def _get_start_end_nans(data: np.ndarray) -> np.ndarray: return np.logical_or(start_nan, end_nan) @staticmethod - def _get_unique_times(data: np.ndarray, times: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def _get_unique_times(data: np.ndarray, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Replace duplicate acquisitions which have same values on the chosen timescale with their average. The average is calculated with numpy.nanmean, meaning that NaN values are ignored when calculating the average.
- :param data: Array in a shape of t x nobs, where nobs = h x w x n + :param data: Array in a shape of t x num_obs, where num_obs = h x w x n :param times: Array of reference times relative to the first timestamp :return: cleaned versions of data input """ - seen = set() - duplication_list = [] - for idx, item in enumerate(times): - if item in seen: - duplication_list.append(idx) - else: - seen.add(item) - duplicated_indices = np.array(duplication_list, dtype=int) + time_groups = defaultdict(list) + for idx, time in enumerate(times): + time_groups[time].append(data[idx]) - duplicated_times = np.unique(times[duplicated_indices]) + clean_times = np.array(sorted(time_groups)) + clean_data = np.full((len(clean_times), *data.shape[1:]), np.nan) + for idx, time in enumerate(clean_times): + # np.nanmean complains about all-NaN rows, so we mask them out, which makes this more complicated + data_for_time = np.array(time_groups[time]) + nan_mask = np.all(np.isnan(data_for_time), axis=0) + clean_data[idx, ~nan_mask] = np.nanmean(data_for_time[:, ~nan_mask], axis=0) - for time in duplicated_times: - indices = np.where(times == time)[0] - nan_mask = np.all(np.isnan(data[indices]), axis=0) - data[indices[0], ~nan_mask] = np.nanmean(data[indices][:, ~nan_mask], axis=0) - - times = np.delete(times, duplicated_indices, axis=0) - data = np.delete(data, duplicated_indices, axis=0) - - return data, times + return clean_data, clean_times def _copy_old_features(self, new_eopatch: EOPatch, old_eopatch: EOPatch) -> EOPatch: - """Copy features from old EOPatch - - :param new_eopatch: New EOPatch container where the old features will be copied to - :param old_eopatch: Old EOPatch container where the old features are located - """ + """Copy features from old EOPatch into new_eopatch""" if self.copy_features_parser is not None: - existing_features: Set[Tuple[FeatureType, Optional[str]]] = set(self.parse_features(..., new_eopatch)) - - renamed_features = self.copy_features_parser.get_renamed_features(old_eopatch) - for copy_feature_type, copy_feature_name, copy_new_feature_name in renamed_features: - new_feature = copy_feature_type, copy_new_feature_name - - if new_feature in existing_features: - raise ValueError( - f"Feature {copy_new_feature_name} of {copy_feature_type} already exists in the " - "new EOPatch! Use a different name!" - ) - existing_features.add(new_feature) + for ftype, fname, new_fname in self.copy_features_parser.get_renamed_features(old_eopatch): + if (ftype, new_fname) in new_eopatch: + raise ValueError(f"Feature {new_fname} of {ftype} already exists in the new EOPatch!") - new_eopatch[copy_feature_type][copy_new_feature_name] = old_eopatch[copy_feature_type][ - copy_feature_name - ] + new_eopatch[ftype, new_fname] = old_eopatch[ftype, fname] return new_eopatch def interpolate_data(self, data: np.ndarray, times: np.ndarray, resampled_times: np.ndarray) -> np.ndarray: """Interpolates data feature - :param data: Array in a shape of t x nobs, where nobs = h x w x n + :param data: Array in a shape of t x num_obs, where num_obs = h x w x n :param times: Array of reference times relative to the first timestamp :param resampled_times: Array of reference times relative to the first timestamp in initial timestamp array.
:return: Array of interpolated values """ # pylint: disable=too-many-locals - # get size of 2d array t x nobs - nobs = data.shape[-1] - if self.interpolate_pixel_wise: - # initialise array of interpolated values - new_data = ( - data if self.resample_range is None else np.full((len(resampled_times), nobs), np.nan, dtype=data.dtype) - ) - - # Interpolate for each pixel, could be easily parallelized - for obs in range(nobs): - valid = ~np.isnan(data[:, obs]) + num_obs = data.shape[-1] + new_data = np.full((len(resampled_times), num_obs), np.nan, dtype=data.dtype) - obs_interpolating_func = self.get_interpolation_function(times[valid], data[valid, obs]) + if self.interpolate_pixel_wise: + for idx in range(num_obs): + tseries = data[:, idx] + valid = ~np.isnan(tseries) + obs_interpolating_func = self.get_interpolation_function(times[valid], tseries[valid]) - new_data[:, obs] = obs_interpolating_func(resampled_times[:, np.newaxis]) + new_data[:, idx] = obs_interpolating_func(resampled_times[:, np.newaxis]) - # return interpolated values return new_data - # mask representing overlap between reference and resampled times - time_mask = (resampled_times >= np.min(times)) & (resampled_times <= np.max(times)) - - # define time values as linear monotonically increasing over the observations - const = int(self.filling_factor * (np.max(times) - np.min(times))) - temp_values = times[:, np.newaxis] + const * np.arange(nobs)[np.newaxis, :].astype(np.float64) - res_temp_values = resampled_times[:, np.newaxis] + const * np.arange(nobs)[np.newaxis, :].astype(np.float64) - - # initialise array of interpolated values - new_data = np.full((len(resampled_times), nobs), np.nan, dtype=data.dtype) - # array defining index correspondence between reference times and resampled times + min_time, max_time = np.min(resampled_times), np.max(resampled_times) ori2res = np.array( [ - ( - np.abs(resampled_times - o).argmin() - if np.min(resampled_times) <= o <= np.max(resampled_times) - else None - ) - for o in times + np.abs(resampled_times - orig_time).argmin() if min_time <= orig_time <= max_time else None + for orig_time in times ] ) @@ -331,25 +277,28 @@ def interpolate_data(self, data: np.ndarray, times: np.ndarray, resampled_times: nan_row_res_indices = np.array([index for index in ori2res[row_nans] if index is not None], dtype=np.int32) nan_col_res_indices = np.array([index is not None for index in ori2res[row_nans]], dtype=bool) + # define time values as linear monotonically increasing over the observations + const = int(self.filling_factor * (np.max(times) - np.min(times))) + temp_values = times[:, np.newaxis] + const * np.arange(num_obs)[np.newaxis, :].astype(np.float64) + res_temp_values = resampled_times[:, np.newaxis] + const * np.arange(num_obs)[np.newaxis, :].astype(np.float64) + if nan_row_res_indices.size: # mask out from output values the starting/ending NaNs res_temp_values[nan_row_res_indices, col_nans[nan_col_res_indices]] = np.nan # if temporal values outside the reference dates are required (extrapolation) masked them to NaN + time_mask = (resampled_times >= np.min(times)) & (resampled_times <= np.max(times)) res_temp_values[~time_mask, :] = np.nan - # build 1d array for interpolation. Spline functions require monotonically increasing values of x, - # so .T is used + # build 1d array for interpolation. 
Spline functions require monotonically increasing values of x, so .T is used input_x = temp_values.T[~np.isnan(data).T] input_y = data.T[~np.isnan(data).T] # build interpolation function if len(input_x) > 1: interp_func = self.get_interpolation_function(input_x, input_y) + valid = ~np.isnan(res_temp_values) + new_data[valid] = interp_func(res_temp_values[valid]) - # interpolate non-NaN values in resampled time values - new_data[~np.isnan(res_temp_values)] = interp_func(res_temp_values[~np.isnan(res_temp_values)]) - - # return interpolated values return new_data def get_interpolation_function(self, times: np.ndarray, series: np.ndarray) -> Callable: @@ -363,13 +312,8 @@ def get_interpolation_function(self, times: np.ndarray, series: np.ndarray) -> C return partial(self.interpolation_object, xp=times, fp=series, left=np.nan, right=np.nan) return self.interpolation_object(times, series, **self.interpolation_parameters) - def get_resampled_timestamp(self, timestamps: List[dt.datetime]) -> List[dt.datetime]: - """Takes a list of timestamps and generates new list of timestamps according to ``resample_range`` - - :param timestamp: list of timestamps - :return: new list of timestamps - """ - days: List[dt.datetime] + def get_resampled_timestamp(self, timestamps: list[dt.datetime]) -> list[dt.datetime]: + """Takes a list of timestamps and generates new list of timestamps according to `resample_range`""" if self.resample_range is None: return timestamps @@ -377,26 +321,24 @@ def get_resampled_timestamp(self, timestamps: List[dt.datetime]) -> List[dt.date raise ValueError(f"Invalid resample_range {self.resample_range}, expected tuple") if tuple(map(type, self.resample_range)) == (str, str, int): - resample_range = cast(Tuple[str, str, int], self.resample_range) - start_date = dateutil.parser.parse(resample_range[0]) - end_date = dateutil.parser.parse(resample_range[1]) - step = dt.timedelta(days=resample_range[2]) + start_str, end_str, step_size = cast(Tuple[str, str, int], self.resample_range) + start_date, end_date = dateutil.parser.parse(start_str), dateutil.parser.parse(end_str) + step = dt.timedelta(days=step_size) days = [start_date] while days[-1] + step < end_date: days.append(days[-1] + step) - elif np.all([isinstance(date, str) for date in self.resample_range]): - days = [dateutil.parser.parse(date) for date in cast(List[str], self.resample_range)] - elif np.all([isinstance(date, dt.datetime) for date in self.resample_range]): - days = list(cast(List[dt.datetime], self.resample_range)) - else: - raise ValueError("Invalid format in {self.resample_range}, expected strings or datetimes") + return days + + if isinstance(self.resample_range, (list, tuple)): + dates = cast(Iterable[Union[str, dt.datetime]], self.resample_range) + return [dateutil.parser.parse(date) if isinstance(date, str) else date for date in dates] - return days + raise ValueError(f"Invalid format in {self.resample_range}, expected {ResampleRangeType}") @staticmethod def _get_eopatch_time_series( - eopatch: EOPatch, ref_date: Optional[dt.datetime] = None, scale_time: int = 1 + eopatch: EOPatch, ref_date: dt.datetime | None = None, scale_time: int = 1 ) -> np.ndarray: """Returns a numpy array with seconds passed between the reference date and the timestamp of each image. 
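An aside on the `(str, str, int)` branch of `get_resampled_timestamp` above: it builds a regular date grid from the start date towards the end date, stopping before the end. A minimal standalone sketch of that logic follows; the function name is illustrative and not part of this changeset.

```python
from __future__ import annotations

import datetime as dt

import dateutil.parser


def resample_days(start: str, end: str, step_days: int) -> list[dt.datetime]:
    """Build a regular date grid from `start` towards `end`, excluding dates past the end."""
    start_date, end_date = dateutil.parser.parse(start), dateutil.parser.parse(end)
    step = dt.timedelta(days=step_days)
    days = [start_date]
    while days[-1] + step < end_date:
        days.append(days[-1] + step)
    return days


# ("2015-01-01", "2018-01-01", 16) yields 69 dates, which matches the
# result_len=69 used by the interpolation test cases later in this diff.
assert len(resample_days("2015-01-01", "2018-01-01", 16)) == 69
```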
@@ -411,8 +353,7 @@ def _get_eopatch_time_series( if not eopatch.timestamps: return np.zeros(0, dtype=np.int64) - if ref_date is None: - ref_date = eopatch.timestamps[0] + ref_date = ref_date or eopatch.timestamps[0] return np.asarray( [round((timestamp - ref_date).total_seconds() / scale_time) for timestamp in eopatch.timestamps], @@ -425,23 +366,19 @@ def execute(self, eopatch: EOPatch) -> EOPatch: feature_type, feature_name, new_feature_name = self.renamed_feature # Make a copy not to change original numpy array - feature_data = eopatch[feature_type][feature_name].copy() + feature_data = eopatch[feature_type, feature_name].copy() time_num, height, width, band_num = feature_data.shape if time_num <= 1: raise ValueError( - f"Feature {(feature_type, feature_name)} has time dimension of size {time_num}, " - "required at least size 2" + f"Feature {(feature_type, feature_name)} has temporal dimension {time_num}, required at least size 2" ) # Apply a mask on data if self.mask_feature_parser is not None: for mask_type, mask_name in self.mask_feature_parser.get_features(eopatch): - negated_mask = ~eopatch[mask_type][mask_name].astype(bool) + negated_mask = ~eopatch[mask_type, mask_name].astype(bool) feature_data = self._mask_feature_data(feature_data, negated_mask, mask_type) - # Flatten array - feature_data = np.reshape(feature_data, (time_num, height * width * band_num)) - # If resampling create new EOPatch new_eopatch = EOPatch(bbox=eopatch.bbox) if self.resample_range else eopatch @@ -453,28 +390,21 @@ def execute(self, eopatch: EOPatch) -> EOPatch: self._get_eopatch_time_series(new_eopatch, scale_time=self.scale_time) + total_diff // self.scale_time ) + # Flatten array + feature_data = np.reshape(feature_data, (time_num, height * width * band_num)) + # Replace duplicate acquisitions which have same values on the chosen timescale with their average feature_data, times = self._get_unique_times(feature_data, times) # Interpolate feature_data = self.interpolate_data(feature_data, times, resampled_times) - # Normalize + # Normalize and insert correct unknown value if self.result_interval: - min_val, max_val = self.result_interval - valid_mask = ~np.isnan(feature_data) - feature_data[valid_mask] = np.maximum(np.minimum(feature_data[valid_mask], max_val), min_val) + feature_data = np.clip(feature_data, *self.result_interval) + feature_data[np.isnan(feature_data)] = self.unknown_value - # Replace unknown value - if not np.isnan(self.unknown_value): - feature_data[np.isnan(feature_data)] = self.unknown_value - - # Reshape back - new_eopatch[feature_type][new_feature_name] = np.reshape( - feature_data, (feature_data.shape[0], height, width, band_num) - ) - - # append features from old patch + new_eopatch[feature_type, new_feature_name] = np.reshape(feature_data, (-1, height, width, band_num)) new_eopatch = self._copy_old_features(new_eopatch, eopatch) return new_eopatch @@ -483,8 +413,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: class LinearInterpolationTask(InterpolationTask): """Implements `eolearn.features.InterpolationTask` by using `numpy.interp` and `@numba.jit(nopython=True)` - :param parallel: interpolation is calculated in parallel using as many CPUs as detected - by the multiprocessing module. + :param parallel: interpolation is calculated in parallel by numba. 
    :param kwargs: parameters of InterpolationTask(EOTask)
     """

@@ -495,10 +424,9 @@ def __init__(self, feature: SingleFeatureSpec, parallel: bool = False, **kwargs:
     def interpolate_data(self, data: np.ndarray, times: np.ndarray, resampled_times: np.ndarray) -> np.ndarray:
         """Interpolates data feature

-        :param data: Array in a shape of t x nobs, where nobs = h x w x n
+        :param data: Array in a shape of t x num_obs, where num_obs = h x w x n
         :param times: Array of reference times in seconds relative to the first timestamp
-        :param resampled_times: Array of reference times in second relative to the first timestamp in initial timestamp
-            array.
+        :param resampled_times: Array of reference times relative to the first timestamp in initial timestamp array.
         :return: Array of interpolated values
         """
         if self.parallel:
@@ -581,7 +509,7 @@ def __init__(
         interpolation_object: Callable,
         resample_range: ResampleRangeType,
         *,
-        result_interval: Optional[Tuple[float, float]] = None,
+        result_interval: tuple[float, float] | None = None,
         unknown_value: float = np.nan,
         **interpolation_parameters: Any,
     ):
@@ -599,19 +527,18 @@
     def interpolate_data(self, data: np.ndarray, times: np.ndarray, resampled_times: np.ndarray) -> np.ndarray:
         """Interpolates data feature

-        :param data: Array in a shape of t x nobs, where nobs = h x w x n
+        :param data: Array in a shape of t x num_obs, where num_obs = h x w x n
         :param times: Array of reference times in seconds relative to the first timestamp
-        :param resampled_times: Array of reference times in second relative to the first timestamp in initial timestamp
-            array.
+        :param resampled_times: Array of reference times relative to the first timestamp in initial timestamp array.
         :return: Array of interpolated values
         """
-        if True in np.unique(np.isnan(data)):
+        if np.isnan(data).any():
             raise ValueError("Data must not contain any masked/invalid pixels or NaN values")

         interp_func = self.get_interpolation_function(times, data)

-        time_mask = (resampled_times >= np.min(times)) & (resampled_times <= np.max(times))
-        new_data = np.full((resampled_times.size,) + data.shape[1:], np.nan, dtype=data.dtype)
+        time_mask = (np.min(times) <= resampled_times) & (resampled_times <= np.max(times))
+        new_data = np.full((resampled_times.size, *data.shape[1:]), np.nan, dtype=data.dtype)
         new_data[time_mask] = interp_func(resampled_times[time_mask])
         return new_data

@@ -626,27 +553,21 @@ def get_interpolation_function(self, times: np.ndarray, series: np.ndarray) -> C


 class NearestResamplingTask(ResamplingTask):
-    """
-    Implements `eolearn.features.ResamplingTask` by using `scipy.interpolate.interp1d(kind='nearest')`
-    """
+    """Implements `eolearn.features.ResamplingTask` by using `scipy.interpolate.interp1d(kind='nearest')`"""

     def __init__(self, feature: SingleFeatureSpec, resample_range: ResampleRangeType, **kwargs: Any):
         super().__init__(feature, scipy.interpolate.interp1d, resample_range, kind="nearest", **kwargs)


 class LinearResamplingTask(ResamplingTask):
-    """
-    Implements `eolearn.features.ResamplingTask` by using `scipy.interpolate.interp1d(kind='linear')`
-    """
+    """Implements `eolearn.features.ResamplingTask` by using `scipy.interpolate.interp1d(kind='linear')`"""

     def __init__(self, feature: SingleFeatureSpec, resample_range: ResampleRangeType, **kwargs: Any):
         super().__init__(feature, scipy.interpolate.interp1d, resample_range, kind="linear", **kwargs)


 class CubicResamplingTask(ResamplingTask):
-    """
-    Implements `eolearn.features.ResamplingTask` by using
`scipy.interpolate.interp1d(kind='cubic')` - """ + """Implements `eolearn.features.ResamplingTask` by using `scipy.interpolate.interp1d(kind='cubic')`""" def __init__(self, feature: SingleFeatureSpec, resample_range: ResampleRangeType, **kwargs: Any): super().__init__(feature, scipy.interpolate.interp1d, resample_range, kind="cubic", **kwargs) diff --git a/features/eolearn/features/py.typed b/features/eolearn/features/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/features/eolearn/features/radiometric_normalization.py b/features/eolearn/features/radiometric_normalization.py index d98ddcdcb..54d0ed24f 100644 --- a/features/eolearn/features/radiometric_normalization.py +++ b/features/eolearn/features/radiometric_normalization.py @@ -9,7 +9,6 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import Optional import numpy as np @@ -37,7 +36,7 @@ def __init__( self, feature: SingleFeatureSpec, valid_fraction_feature: SingleFeatureSpec, - max_scene_number: Optional[int] = None, + max_scene_number: int | None = None, ): self.renamed_feature = parse_renamed_feature(feature) self.valid_fraction_feature = self.parse_feature(valid_fraction_feature) @@ -133,8 +132,7 @@ def _geoville_index_by_percentile(self, data: np.ndarray, percentile: int) -> np ind = f_arr.astype("int16") y_val, x_val = ind_tmp.shape[1], ind_tmp.shape[2] y_val, x_val = np.ogrid[0:y_val, 0:x_val] # type: ignore[assignment] - idx = np.where(valid_obs == 0, self.max_index, ind_tmp[ind, y_val, x_val]) - return idx + return np.where(valid_obs == 0, self.max_index, ind_tmp[ind, y_val, x_val]) @abstractmethod def _get_reference_band(self, data: np.ndarray) -> np.ndarray: @@ -150,8 +148,7 @@ def _get_indices(self, data: np.ndarray) -> np.ndarray: :param data: Input 3D array holding the reference band :return: 2D array holding the temporal index corresponding to percentile """ - indices = self._index_by_percentile(data, self.percentile) - return indices + return self._index_by_percentile(data, self.percentile) def execute(self, eopatch: EOPatch) -> EOPatch: """Compute composite array merging temporal frames according to the compositing method @@ -283,8 +280,7 @@ def _get_indices(self, data: np.ndarray) -> np.ndarray: median = np.nanmedian(data, axis=0) indices_min = self._index_by_percentile(data, self.percentiles[0]) indices_max = self._index_by_percentile(data, self.percentiles[1]) - indices = np.where(median < -0.05, indices_min, indices_max) - return indices + return np.where(median < -0.05, indices_min, indices_max) class MaxNDWICompositingTask(BaseCompositingTask): diff --git a/features/eolearn/features/temporal_features.py b/features/eolearn/features/temporal_features.py index e7d5b5c0e..1ebdc8fc1 100644 --- a/features/eolearn/features/temporal_features.py +++ b/features/eolearn/features/temporal_features.py @@ -154,10 +154,10 @@ def execute(self, eopatch: EOPatch) -> EOPatch: argmin_data = np.ma.MaskedArray.argmin(madata, axis=0) if argmax_data.ndim == 2: - argmax_data = argmax_data.reshape(argmax_data.shape + (1,)) + argmax_data = argmax_data.reshape((*argmax_data.shape, 1)) if argmin_data.ndim == 2: - argmin_data = argmin_data.reshape(argmin_data.shape + (1,)) + argmin_data = argmin_data.reshape((*argmin_data.shape, 1)) eopatch.data_timeless[self.amax_feature] = argmax_data eopatch.data_timeless[self.amin_feature] = argmin_data diff --git a/features/eolearn/tests/test_bands_extraction.py b/features/eolearn/tests/test_bands_extraction.py index 3565b0dbb..e8b0adfa0 
100644 --- a/features/eolearn/tests/test_bands_extraction.py +++ b/features/eolearn/tests/test_bands_extraction.py @@ -31,7 +31,7 @@ def test_euclidean_norm(): assert (eopatch.data["NORM"] == np.sqrt(len(bands))).all() -@pytest.mark.parametrize("bad_input", ([1, 2, 3], "test", 0.5)) +@pytest.mark.parametrize("bad_input", [(1, 2, 3), "test", 0.5]) def test_bad_input(bad_input): with pytest.raises(ValueError): NormalizedDifferenceIndexTask(INPUT_FEATURE, (FeatureType.DATA, "NDI"), bands=bad_input) diff --git a/features/eolearn/tests/test_blob.py b/features/eolearn/tests/test_blob.py index 71405c959..062e18c37 100644 --- a/features/eolearn/tests/test_blob.py +++ b/features/eolearn/tests/test_blob.py @@ -7,47 +7,36 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ import copy -import sys import pytest -from pytest import approx -from skimage.feature import blob_dog +from skimage.feature import blob_dog, blob_doh, blob_log from sentinelhub.testing_utils import assert_statistics_match from eolearn.core import FeatureType -from eolearn.features import BlobTask, DoGBlobTask, DoHBlobTask, LoGBlobTask +from eolearn.features import BlobTask FEATURE = (FeatureType.DATA, "NDVI", "blob") BLOB_FEATURE = (FeatureType.DATA, "blob") -def test_dog_blob_task(small_ndvi_eopatch): - eopatch = small_ndvi_eopatch - BlobTask(FEATURE, blob_dog, sigma_ratio=1.6, min_sigma=1, max_sigma=30, overlap=0.5, threshold=0)(eopatch) - DoGBlobTask((FeatureType.DATA, "NDVI", "blob_dog"), threshold=0)(eopatch) - assert eopatch[BLOB_FEATURE] == approx(eopatch.data["blob_dog"]) - - BLOB_TESTS = [ - (DoGBlobTask(FEATURE, threshold=0), {"exp_min": 0.0, "exp_max": 37.9625, "exp_mean": 0.08545, "exp_median": 0.0}), + ( + BlobTask(FEATURE, blob_dog, threshold=0, max_sigma=30), + {"exp_min": 0.0, "exp_max": 37.9625, "exp_mean": 0.08545, "exp_median": 0.0}, + ), + ( + BlobTask(FEATURE, blob_doh, num_sigma=5, threshold=0), + {"exp_min": 0.0, "exp_max": 21.9203, "exp_mean": 0.05807, "exp_median": 0.0}, + ), + ( + BlobTask(FEATURE, blob_log, log_scale=True, threshold=0, max_sigma=30), + {"exp_min": 0, "exp_max": 42.4264, "exp_mean": 0.09767, "exp_median": 0.0}, + ), ] -if sys.version_info >= (3, 8): # For Python 3.7 scikit-image returns less accurate result for this test - BLOB_TESTS.extend( - [ - ( - DoHBlobTask(FEATURE, num_sigma=5, threshold=0), - {"exp_min": 0.0, "exp_max": 21.9203, "exp_mean": 0.05807, "exp_median": 0.0}, - ), - ( - LoGBlobTask(FEATURE, log_scale=True, threshold=0), - {"exp_min": 0, "exp_max": 42.4264, "exp_mean": 0.09767, "exp_median": 0.0}, - ), - ] - ) - - -@pytest.mark.parametrize("task, expected_statistics", BLOB_TESTS) + + +@pytest.mark.parametrize(("task", "expected_statistics"), BLOB_TESTS) def test_blob_task(small_ndvi_eopatch, task, expected_statistics): eopatch = copy.deepcopy(small_ndvi_eopatch) task.execute(eopatch) diff --git a/features/eolearn/tests/test_clustering.py b/features/eolearn/tests/test_clustering.py index a48b8a780..c0bb804d3 100644 --- a/features/eolearn/tests/test_clustering.py +++ b/features/eolearn/tests/test_clustering.py @@ -8,7 +8,7 @@ import logging import numpy as np -from pytest import approx +import pytest from eolearn.core import FeatureType from eolearn.features import ClusteringTask @@ -19,7 +19,7 @@ def test_clustering(example_eopatch): test_features = {FeatureType.DATA_TIMELESS: ["DEM", "MAX_NDVI"]} mask = np.zeros_like(example_eopatch.mask_timeless["LULC"], dtype=np.uint8) - mask[:90, :90] = 1 + 
mask[:90, :95] = 1
     example_eopatch.mask_timeless["mask"] = mask

     ClusteringTask(
@@ -38,17 +38,18 @@ def test_clustering(example_eopatch):
         affinity="cosine",
         linkage="average",
         mask_name="mask",
+        remove_small=10,
     ).execute(example_eopatch)

     clusters = example_eopatch.data_timeless["clusters_small"].squeeze()

     assert len(np.unique(clusters)) == 22, "Wrong number of clusters."
     assert np.median(clusters) == 2
-    assert np.mean(clusters) == approx(2.19109)
+    assert np.mean(clusters) == pytest.approx(2.19109)

     clusters = example_eopatch.data_timeless["clusters_mask"].squeeze()

-    assert len(np.unique(clusters)) == 20, "Wrong number of clusters."
+    assert len(np.unique(clusters)) == 8, "Wrong number of clusters."
     assert np.median(clusters) == 0
-    assert np.mean(clusters) == approx(-0.0948515)
-    assert np.all(clusters[90:, 90:] == -1), "Wrong area"
+    assert np.mean(clusters) == pytest.approx(-0.0550495)
+    assert np.all(clusters[90:, 95:] == -1), "Wrong area"
diff --git a/features/eolearn/tests/test_doubly_logistic_approximation.py b/features/eolearn/tests/test_doubly_logistic_approximation.py
index 34f1b81cf..9f3f3bc4d 100644
--- a/features/eolearn/tests/test_doubly_logistic_approximation.py
+++ b/features/eolearn/tests/test_doubly_logistic_approximation.py
@@ -4,9 +4,8 @@
 This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree.
 """
-
 import numpy as np
-from pytest import approx
+import pytest

 from sentinelhub import CRS, BBox

@@ -37,4 +36,4 @@ def test_double_logistic_approximation(example_eopatch):
     delta = 0.1

     for name, value, expected_value in zip(names, values, expected_values):
-        assert value == approx(expected_value, abs=delta), f"Missmatch in value of {name}"
+        assert value == pytest.approx(expected_value, abs=delta), f"Mismatch in value of {name}"
diff --git a/features/eolearn/tests/test_feature_manipulation.py b/features/eolearn/tests/test_feature_manipulation.py
index cdca39206..674684ce2 100644
--- a/features/eolearn/tests/test_feature_manipulation.py
+++ b/features/eolearn/tests/test_feature_manipulation.py
@@ -20,6 +20,7 @@
 from eolearn.features.utils import ResizeParam

 DUMMY_BBOX = BBox((0, 0, 1, 1), CRS(3857))
+# ruff: noqa: NPY002


 @pytest.mark.parametrize(
@@ -133,7 +134,7 @@ def test_fill():


 @pytest.mark.parametrize(
-    "input_feature, operation",
+    ("input_feature", "operation"),
     [((FeatureType.DATA, "TEST"), "x"), ((FeatureType.DATA, "TEST"), 4), (None, "f"), (np.zeros((4, 5)), "x")],
 )
 def test_bad_input(input_feature, operation):
@@ -241,7 +242,7 @@ def test_linear_function_task():


 @pytest.mark.parametrize(
-    ["resize_type", "height_param", "width_param", "features_call", "features_check", "outputs"],
+    ("resize_type", "height_param", "width_param", "features_call", "features_check", "outputs"),
     [
         (ResizeParam.NEW_SIZE, 50, 70, ("data", "CLP"), ("data", "CLP"), (68, 50, 70, 1)),
         (ResizeParam.NEW_SIZE, 50, 70, ("data", "CLP"), ("mask", "CLM"), (68, 101, 100, 1)),
@@ -269,9 +270,6 @@ def test_spatial_resize_task(
     assert resize(example_eopatch)[features_check].shape == outputs


-def test_spatial_resize_task_exception(example_eopatch):
+def test_spatial_resize_task_exception():
     with pytest.raises(ValueError):
-        resize_wrong_param = SpatialResizeTask(
-            features=("mask", "CLM"), resize_type="blabla", height_param=20, width_param=20
-        )
-        resize_wrong_param(example_eopatch)
+        SpatialResizeTask(features=("mask", "CLM"), resize_type="blabla", height_param=20, width_param=20)
diff --git
a/features/eolearn/tests/test_features_utils.py b/features/eolearn/tests/test_features_utils.py index 39e01f7a6..851d7823c 100644 --- a/features/eolearn/tests/test_features_utils.py +++ b/features/eolearn/tests/test_features_utils.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize("method", ResizeMethod) @pytest.mark.parametrize("library", ResizeLib) -@pytest.mark.parametrize("dtype", (np.float32, np.int32, np.uint8, bool)) +@pytest.mark.parametrize("dtype", [np.float32, np.int32, np.uint8, bool]) @pytest.mark.parametrize("new_size", [(50, 50), (35, 39), (271, 271)]) def test_spatially_resize_image_new_size( method: ResizeMethod, library: ResizeLib, dtype: Union[np.dtype, type], new_size: Tuple[int, int] @@ -21,24 +21,28 @@ def test_spatially_resize_image_new_size( old_shape = (111, 111) data_2d = np.arange(np.prod(old_shape)).astype(dtype).reshape(old_shape) result = spatially_resize_image(data_2d, new_size, resize_method=method, resize_library=library) - assert result.shape == new_size and result.dtype == dtype + assert result.shape == new_size + assert result.dtype == dtype old_shape = (111, 111, 3) data_3d = np.arange(np.prod(old_shape)).astype(dtype).reshape(old_shape) result = spatially_resize_image(data_3d, new_size, resize_method=method, resize_library=library) - assert result.shape == (*new_size, 3) and result.dtype == dtype + assert result.shape == (*new_size, 3) + assert result.dtype == dtype old_shape = (5, 111, 111, 3) data_4d = np.arange(np.prod(old_shape)).astype(dtype).reshape(old_shape) result = spatially_resize_image(data_4d, new_size, resize_method=method, resize_library=library) - assert result.shape == (5, *new_size, 3) and result.dtype == dtype + assert result.shape == (5, *new_size, 3) + assert result.dtype == dtype old_shape = (2, 1, 111, 111, 3) data_5d = np.arange(np.prod(old_shape)).astype(dtype).reshape(old_shape) result = spatially_resize_image( data_5d, new_size, resize_method=method, spatial_axes=(2, 3), resize_library=library ) - assert result.shape == (2, 1, *new_size, 3) and result.dtype == dtype + assert result.shape == (2, 1, *new_size, 3) + assert result.dtype == dtype @pytest.mark.parametrize("method", ResizeMethod) @@ -56,10 +60,10 @@ def test_spatially_resize_image_scale_factors( assert result.shape == (height * scale_factors[0], width * scale_factors[1], 3) -@pytest.mark.parametrize("library", (ResizeLib.PIL,)) +@pytest.mark.parametrize("library", [ResizeLib.PIL]) @pytest.mark.parametrize( "dtype", - ( + [ bool, np.int8, np.uint8, @@ -74,7 +78,7 @@ def test_spatially_resize_image_scale_factors( np.float32, np.float64, float, - ), + ], ) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_spatially_resize_image_dtype(library: ResizeLib, dtype: Union[np.dtype, type]): diff --git a/features/eolearn/tests/test_haralick.py b/features/eolearn/tests/test_haralick.py index 3358d88ee..15e0bcec3 100644 --- a/features/eolearn/tests/test_haralick.py +++ b/features/eolearn/tests/test_haralick.py @@ -19,21 +19,25 @@ @pytest.mark.parametrize( - "task, expected_statistics", - ( - [ + ("task", "expected_statistics"), + [ + ( HaralickTask(FEATURE, texture_feature="contrast", angle=0, levels=255, window_size=3), {"exp_min": 3.5, "exp_max": 9079.0, "exp_mean": 965.8295, "exp_median": 628.5833}, - ], - [ + ), + ( HaralickTask(FEATURE, texture_feature="sum_of_square_variance", angle=np.pi / 2, levels=8, window_size=5), {"exp_min": 0.96899, "exp_max": 48.7815, "exp_mean": 23.0229, "exp_median": 23.8987}, - ], - [ + ), + ( HaralickTask(FEATURE, 
texture_feature="sum_entropy", angle=-np.pi / 2, levels=8, window_size=7), {"exp_min": 0, "exp_max": 1.7463, "exp_mean": 0.5657, "exp_median": 0.50558}, - ], - ), + ), + ( + HaralickTask(FEATURE, texture_feature="difference_variance", angle=-np.pi / 2, levels=8, window_size=7), + {"exp_min": 42, "exp_max": 110.6122, "exp_mean": 53.857082, "exp_median": 50}, + ), + ], ) def test_haralick(small_ndvi_eopatch, task, expected_statistics): eopatch = copy.deepcopy(small_ndvi_eopatch) diff --git a/features/eolearn/tests/test_interpolation.py b/features/eolearn/tests/test_interpolation.py index 3d930afd8..14fcb6482 100644 --- a/features/eolearn/tests/test_interpolation.py +++ b/features/eolearn/tests/test_interpolation.py @@ -65,6 +65,8 @@ def execute(self, eopatch): return result +# Some of these might be very randomly slow, but that is due to the JIT of numba +# It is hard to trigger it before the tests reliably. INTERPOLATION_TEST_CASES = [ InterpolationTestCase( "linear", @@ -230,46 +232,52 @@ def execute(self, eopatch): COPY_FEATURE_CASES = [ - InterpolationTestCase( - "cubic_copy_success", - CubicInterpolationTask( - (FeatureType.DATA, "NDVI"), - result_interval=(0.0, 1.0), - mask_feature=(FeatureType.MASK, "IS_VALID"), - resample_range=("2015-01-01", "2018-01-01", 16), - unknown_value=5, - bounds_error=False, - copy_features=[ - (FeatureType.MASK, "IS_VALID"), - (FeatureType.DATA, "NDVI", "NDVI_OLD"), - (FeatureType.MASK_TIMELESS, "LULC"), - ], + ( + InterpolationTestCase( + "cubic_copy_success", + CubicInterpolationTask( + (FeatureType.DATA, "NDVI"), + result_interval=(0.0, 1.0), + mask_feature=(FeatureType.MASK, "IS_VALID"), + resample_range=("2015-01-01", "2018-01-01", 16), + unknown_value=5, + bounds_error=False, + copy_features=[ + (FeatureType.MASK, "IS_VALID"), + (FeatureType.DATA, "NDVI", "NDVI_OLD"), + (FeatureType.MASK_TIMELESS, "LULC"), + ], + ), + result_len=69, + expected_statistics={}, ), - result_len=69, - expected_statistics=dict(exp_min=0.0, exp_max=5.0, exp_mean=1.3592644, exp_median=0.6174331), + True, ), - InterpolationTestCase( - "cubic_copy_fail", - CubicInterpolationTask( - (FeatureType.DATA, "NDVI"), - result_interval=(0.0, 1.0), - mask_feature=(FeatureType.MASK, "IS_VALID"), - resample_range=("2015-01-01", "2018-01-01", 16), - unknown_value=5, - bounds_error=False, - copy_features=[ - (FeatureType.MASK, "IS_VALID"), + ( + InterpolationTestCase( + "cubic_copy_fail", + CubicInterpolationTask( (FeatureType.DATA, "NDVI"), - (FeatureType.MASK_TIMELESS, "LULC"), - ], + result_interval=(0.0, 1.0), + mask_feature=(FeatureType.MASK, "IS_VALID"), + resample_range=("2015-01-01", "2018-01-01", 16), + unknown_value=5, + bounds_error=False, + copy_features=[ + (FeatureType.MASK, "IS_VALID"), + (FeatureType.DATA, "NDVI"), + (FeatureType.MASK_TIMELESS, "LULC"), + ], + ), + result_len=69, + expected_statistics={}, ), - result_len=69, - expected_statistics=dict(exp_min=0.0, exp_max=5.0, exp_mean=1.3592644, exp_median=0.6174331), + False, ), ] -@pytest.mark.parametrize("test_case", INTERPOLATION_TEST_CASES) +@pytest.mark.parametrize("test_case", INTERPOLATION_TEST_CASES, ids=lambda x: x.name) def test_interpolation(test_case: InterpolationTestCase, test_patch): eopatch = test_case.execute(test_patch) delta = 1e-4 if isinstance(test_case.task, KrigingInterpolationTask) else 1e-5 @@ -286,14 +294,11 @@ def test_interpolation(test_case: InterpolationTestCase, test_patch): assert_statistics_match(data, **test_case.expected_statistics, abs_delta=delta) -@pytest.mark.parametrize("test_case", 
COPY_FEATURE_CASES) -def test_copied_fields(test_case, test_patch): - try: +@pytest.mark.parametrize(("test_case", "passes"), COPY_FEATURE_CASES) +def test_copied_fields(test_case, passes, test_patch): + if passes: eopatch = test_case.execute(test_patch) - except ValueError: - eopatch = None - if eopatch is not None: copied_features = [ (FeatureType.MASK, "IS_VALID"), (FeatureType.DATA, "NDVI_OLD"), @@ -301,3 +306,7 @@ def test_copied_fields(test_case, test_patch): ] for feature in copied_features: assert feature in eopatch, f"Expected feature `{feature}` is not present in EOPatch" + else: + # Fails due to name duplication + with pytest.raises(ValueError): + test_case.execute(test_patch) diff --git a/features/eolearn/tests/test_local_binary_pattern.py b/features/eolearn/tests/test_local_binary_pattern.py index f98bd6831..876e48349 100644 --- a/features/eolearn/tests/test_local_binary_pattern.py +++ b/features/eolearn/tests/test_local_binary_pattern.py @@ -19,13 +19,13 @@ @pytest.mark.parametrize( - "task, expected_statistics", - ( - [ + ("task", "expected_statistics"), + [ + ( LocalBinaryPatternTask(LBP_FEATURE, nb_points=24, radius=3), {"exp_min": 0.0, "exp_max": 25.0, "exp_mean": 15.8313, "exp_median": 21.0}, - ], - ), + ), + ], ) def test_local_binary_pattern(small_ndvi_eopatch, task, expected_statistics): eopatch = copy.deepcopy(small_ndvi_eopatch) diff --git a/features/eolearn/tests/test_radiometric_normalization.py b/features/eolearn/tests/test_radiometric_normalization.py index 68ac42b96..1b7e03cf1 100644 --- a/features/eolearn/tests/test_radiometric_normalization.py +++ b/features/eolearn/tests/test_radiometric_normalization.py @@ -24,6 +24,8 @@ ) from eolearn.mask import MaskFeatureTask +# ruff: noqa: NPY002 + @pytest.fixture(name="eopatch") def eopatch_fixture(example_eopatch): @@ -44,9 +46,9 @@ def eopatch_fixture(example_eopatch): @pytest.mark.parametrize( - "task, test_feature, expected_statistics", - ( - [ + ("task", "test_feature", "expected_statistics"), + [ + ( MaskFeatureTask( (FeatureType.DATA, "BANDS-S2-L1C", "TEST"), (FeatureType.MASK, "SCL"), @@ -54,15 +56,15 @@ def eopatch_fixture(example_eopatch): ), DATA_TEST_FEATURE, {"exp_min": 0.0002, "exp_max": 1.4244, "exp_mean": 0.21167801, "exp_median": 0.1422}, - ], - [ + ), + ( ReferenceScenesTask( (FeatureType.DATA, "BANDS-S2-L1C", "TEST"), (FeatureType.SCALAR, "CLOUD_COVERAGE"), max_scene_number=5 ), DATA_TEST_FEATURE, {"exp_min": 0.0005, "exp_max": 0.5318, "exp_mean": 0.16823094, "exp_median": 0.1404}, - ], - [ + ), + ( BlueCompositingTask( (FeatureType.DATA, "REFERENCE_SCENES"), (FeatureType.DATA_TIMELESS, "TEST"), @@ -71,8 +73,8 @@ def eopatch_fixture(example_eopatch): ), DATA_TIMELESS_TEST_FEATURE, {"exp_min": 0.0005, "exp_max": 0.5075, "exp_mean": 0.11658352, "exp_median": 0.0833}, - ], - [ + ), + ( HOTCompositingTask( (FeatureType.DATA, "REFERENCE_SCENES"), (FeatureType.DATA_TIMELESS, "TEST"), @@ -82,8 +84,8 @@ def eopatch_fixture(example_eopatch): ), DATA_TIMELESS_TEST_FEATURE, {"exp_min": 0.0005, "exp_max": 0.5075, "exp_mean": 0.117758796, "exp_median": 0.0846}, - ], - [ + ), + ( MaxNDVICompositingTask( (FeatureType.DATA, "REFERENCE_SCENES"), (FeatureType.DATA_TIMELESS, "TEST"), @@ -93,8 +95,8 @@ def eopatch_fixture(example_eopatch): ), DATA_TIMELESS_TEST_FEATURE, {"exp_min": 0.0005, "exp_max": 0.5075, "exp_mean": 0.13430128, "exp_median": 0.0941}, - ], - [ + ), + ( MaxNDWICompositingTask( (FeatureType.DATA, "REFERENCE_SCENES"), (FeatureType.DATA_TIMELESS, "TEST"), @@ -104,8 +106,8 @@ def 
eopatch_fixture(example_eopatch): ), DATA_TIMELESS_TEST_FEATURE, {"exp_min": 0.0005, "exp_max": 0.5318, "exp_mean": 0.2580135, "exp_median": 0.2888}, - ], - [ + ), + ( MaxRatioCompositingTask( (FeatureType.DATA, "REFERENCE_SCENES"), (FeatureType.DATA_TIMELESS, "TEST"), @@ -116,15 +118,15 @@ def eopatch_fixture(example_eopatch): ), DATA_TIMELESS_TEST_FEATURE, {"exp_min": 0.0006, "exp_max": 0.5075, "exp_mean": 0.13513365, "exp_median": 0.0958}, - ], - [ + ), + ( HistogramMatchingTask( (FeatureType.DATA, "BANDS-S2-L1C", "TEST"), (FeatureType.DATA_TIMELESS, "REFERENCE_COMPOSITE") ), DATA_TEST_FEATURE, {"exp_min": -0.049050678, "exp_max": 0.68174845, "exp_mean": 0.1165936, "exp_median": 0.08370649}, - ], - ), + ), + ], ) def test_radiometric_normalization(eopatch, task, test_feature, expected_statistics): initial_patch = copy.deepcopy(eopatch) diff --git a/features/setup.py b/features/setup.py index 8e7b442d2..9c0117dac 100644 --- a/features/setup.py +++ b/features/setup.py @@ -7,9 +7,7 @@ def get_long_description(): this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - return long_description + return f.read() def parse_requirements(file): @@ -29,7 +27,7 @@ def get_version(): setup( name="eo-learn-features", - python_requires=">=3.7", + python_requires=">=3.8", version=get_version(), description="A collection of feature manipulation EOTasks and utilities", long_description=get_long_description(), @@ -59,7 +57,6 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/geometry/MANIFEST.in b/geometry/MANIFEST.in index 1b163e951..b7faef195 100644 --- a/geometry/MANIFEST.in +++ b/geometry/MANIFEST.in @@ -1,4 +1,5 @@ include requirements*.txt include LICENSE include README.md +include eolearn/geometry/py.typed exclude eolearn/tests/* diff --git a/geometry/eolearn/geometry/__init__.py b/geometry/eolearn/geometry/__init__.py index 0edcaedd8..5dfdad3c7 100644 --- a/geometry/eolearn/geometry/__init__.py +++ b/geometry/eolearn/geometry/__init__.py @@ -11,4 +11,4 @@ ) from .transformations import RasterToVectorTask, VectorToRasterTask -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/geometry/eolearn/geometry/morphology.py b/geometry/eolearn/geometry/morphology.py index aceafca75..bcdd6558a 100644 --- a/geometry/eolearn/geometry/morphology.py +++ b/geometry/eolearn/geometry/morphology.py @@ -10,7 +10,7 @@ import itertools as it from enum import Enum -from typing import Callable, List, Optional, Tuple, Union, cast +from typing import Callable, Tuple, cast import numpy as np import skimage.filters.rank @@ -35,7 +35,7 @@ def __init__( self, mask_feature: SingleFeatureSpec, disk_radius: int = 1, - erode_labels: Optional[List[int]] = None, + erode_labels: list[int] | None = None, no_data_label: int = 0, ): if not isinstance(disk_radius, int) or disk_radius is None or disk_radius < 1: @@ -139,10 +139,10 @@ class MorphologicalFilterTask(MapFeatureTask): def __init__( self, input_features: FeaturesSpecification, - output_features: Optional[FeaturesSpecification] = None, + output_features: FeaturesSpecification | None = None, *, - morph_operation: Union[MorphologicalOperations, Callable], - struct_elem: Optional[np.ndarray] = None, + 
morph_operation: MorphologicalOperations | Callable, + struct_elem: np.ndarray | None = None, ): """ :param input_features: Input features to be processed. diff --git a/geometry/eolearn/geometry/transformations.py b/geometry/eolearn/geometry/transformations.py index b2750f65a..4e2bd5135 100644 --- a/geometry/eolearn/geometry/transformations.py +++ b/geometry/eolearn/geometry/transformations.py @@ -9,14 +9,14 @@ from __future__ import annotations import datetime as dt -import functools +import itertools as it import logging import warnings -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union, cast +from functools import partial +from typing import Any, Callable, Iterator, Tuple, cast import numpy as np import pandas as pd -import pyproj import rasterio.features import rasterio.transform import shapely.geometry @@ -49,30 +49,28 @@ class VectorToRasterTask(EOTask): # A mapping between types that are not supported by rasterio into types that are. After rasterization the task # will cast results back into the original dtype. - _RASTERIO_BASIC_DTYPES_MAP = { + _RASTERIO_DTYPES_MAP = { bool: np.uint8, + np.dtype(bool): np.uint8, np.int8: np.int16, + np.dtype(np.int8): np.int16, float: np.float64, - } - _RASTERIO_DTYPES_MAP = { - dtype: rasterio_type - for basic_type, rasterio_type in _RASTERIO_BASIC_DTYPES_MAP.items() - for dtype in [basic_type, np.dtype(basic_type)] + np.dtype(float): np.float64, } def __init__( self, - vector_input: Union[GeoDataFrame, SingleFeatureSpec], + vector_input: GeoDataFrame | SingleFeatureSpec, raster_feature: SingleFeatureSpec, *, - values: Union[None, float, List[float]] = None, - values_column: Optional[str] = None, - raster_shape: Union[None, Tuple[int, int], SingleFeatureSpec] = None, - raster_resolution: Union[None, float, Tuple[float, float]] = None, - raster_dtype: Union[np.dtype, type] = np.uint8, + values: None | float | list[float] = None, + values_column: str | None = None, + raster_shape: None | tuple[int, int] | SingleFeatureSpec = None, + raster_resolution: None | float | tuple[float, float] = None, + raster_dtype: np.dtype | type = np.uint8, no_data_value: float = 0, write_to_existing: bool = False, - overlap_value: Optional[float] = None, + overlap_value: float | None = None, buffer: float = 0, **rasterio_params: Any, ): @@ -103,9 +101,10 @@ def __init__( available parameters are `all_touched` and `merge_alg` """ self.vector_input, self.raster_feature = self._parse_main_params(vector_input, raster_feature) + self._rasterize_per_timestamp = self.raster_feature[0].is_temporal() if _vector_is_timeless(self.vector_input) and not self.raster_feature[0].is_timeless(): - raise ValueError("Vector input has no time-dependence but a time-dependent raster feature was selected") + raise ValueError("Vector input has no time-dependence but a time-dependent output feature was selected") self.values = values self.values_column = values_column @@ -124,11 +123,9 @@ def __init__( self.overlap_value = overlap_value self.buffer = buffer - self._rasterize_per_timestamp = self.raster_feature[0].is_temporal() - def _parse_main_params( - self, vector_input: Union[GeoDataFrame, SingleFeatureSpec], raster_feature: SingleFeatureSpec - ) -> Tuple[Union[GeoDataFrame, Tuple[FeatureType, str]], Tuple[FeatureType, str]]: + self, vector_input: GeoDataFrame | SingleFeatureSpec, raster_feature: SingleFeatureSpec + ) -> tuple[GeoDataFrame | tuple[FeatureType, str], tuple[FeatureType, str]]: """Parsing first 2 parameters - what vector data will be used and in which 
raster feature it will be saved""" if not _is_geopandas_object(vector_input): vector_input = self.parse_feature(vector_input, allowed_feature_types=lambda fty: fty.is_vector()) @@ -138,63 +135,33 @@ def _parse_main_params( def _get_vector_data_iterator( self, eopatch: EOPatch, join_per_value: bool - ) -> Iterator[Tuple[Optional[dt.datetime], Optional[ShapeIterator]]]: + ) -> Iterator[tuple[dt.datetime | None, ShapeIterator]]: """Collects and prepares vector shapes for rasterization. It works as an iterator that returns pairs of `(timestamp or None, )` :param eopatch: An EOPatch from where geometries will be obtained :param join_per_value: If `True` it will join geometries with the same value using a cascaded union """ - vector_data = self._get_vector_data_from_eopatch(eopatch) + vector_data = self.vector_input if _is_geopandas_object(self.vector_input) else eopatch[self.vector_input] # EOPatch has a bbox, verified in execute vector_data = self._preprocess_vector_data(vector_data, cast(BBox, eopatch.bbox), eopatch.timestamps) if self._rasterize_per_timestamp: - for timestamp, vector_data_per_timestamp in vector_data.groupby(TIMESTAMP_COLUMN): - yield timestamp.to_pydatetime(), self._vector_data_to_shape_iterator( - vector_data_per_timestamp, join_per_value - ) - else: + for timestamp, data_for_time in vector_data.groupby(TIMESTAMP_COLUMN): + if not data_for_time.empty: + yield timestamp.to_pydatetime(), self._vector_data_to_shape_iterator(data_for_time, join_per_value) + elif not vector_data.empty: yield None, self._vector_data_to_shape_iterator(vector_data, join_per_value) - def _get_vector_data_from_eopatch(self, eopatch: EOPatch) -> GeoDataFrame: - """Provides a vector dataframe either from the attribute or from given EOPatch feature""" - if _is_geopandas_object(self.vector_input): - return self.vector_input - - return eopatch[self.vector_input] - def _preprocess_vector_data( - self, vector_data: GeoDataFrame, bbox: BBox, timestamps: List[dt.datetime] + self, vector_data: GeoDataFrame, bbox: BBox, timestamps: list[dt.datetime] ) -> GeoDataFrame: """Applies preprocessing steps on a dataframe with geometries and potential values and timestamps""" - columns_to_keep = ["geometry"] - if self._rasterize_per_timestamp: - columns_to_keep.append(TIMESTAMP_COLUMN) - if self.values_column is not None: - columns_to_keep.append(self.values_column) - vector_data = vector_data[columns_to_keep] - - if self._rasterize_per_timestamp: - vector_data[TIMESTAMP_COLUMN] = vector_data[TIMESTAMP_COLUMN].apply(parse_time) - vector_data = vector_data[vector_data[TIMESTAMP_COLUMN].isin(timestamps)] - - if self.values_column is not None and self.values is not None: - values = [self.values] if isinstance(self.values, (int, float)) else self.values - vector_data = vector_data[vector_data[self.values_column].isin(values)] + vector_data = self._reduce_vector_data(vector_data, timestamps) - gpd_crs = vector_data.crs - # This special case has to be handled because of WGS84 and lat-lon order: - if isinstance(gpd_crs, pyproj.CRS): - gpd_crs = gpd_crs.to_epsg() - vector_data_crs = CRS(gpd_crs) - - if bbox.crs is not vector_data_crs: + if bbox.crs is not CRS(vector_data.crs.to_epsg()): warnings.warn( - ( - "Vector data is not in the same CRS as EOPatch, this task will re-project vector data for " - "each execution" - ), + "Vector data is not in the same CRS as EOPatch, the task will re-project vectors for each execution", EORuntimeWarning, ) vector_data = vector_data.to_crs(bbox.crs.pyproj_crs()) @@ -207,53 +174,58 @@ def 
_preprocess_vector_data( vector_data = vector_data[~vector_data.is_empty] if not vector_data.geometry.is_valid.all(): - warnings.warn("Given vector polygons contain some invalid geometries, they will be fixed", EORuntimeWarning) + warnings.warn("Given vector polygons contain some invalid geometries, attempting to fix", EORuntimeWarning) vector_data.geometry = vector_data.geometry.buffer(0) if vector_data.geometry.has_z.any(): - warnings.warn( - "Given vector polygons contain some 3D geometries, they will be projected to 2D", EORuntimeWarning - ) - vector_data.geometry = vector_data.geometry.map( - functools.partial(shapely.ops.transform, lambda *args: args[:2]) - ) + warnings.warn("Polygons contain 3D geometries, they will be projected to 2D", EORuntimeWarning) + vector_data.geometry = vector_data.geometry.map(partial(shapely.ops.transform, lambda *args: args[:2])) return vector_data - def _vector_data_to_shape_iterator( - self, vector_data: GeoDataFrame, join_per_value: bool - ) -> Optional[ShapeIterator]: - """Returns an iterator of pairs `(shape, value)` or `None` if given dataframe is empty""" - if vector_data.empty: - return None + def _reduce_vector_data(self, vector_data: GeoDataFrame, timestamps: list[dt.datetime]) -> GeoDataFrame: + """Removes all redundant columns and rows.""" + columns_to_keep = ["geometry"] + if self._rasterize_per_timestamp: + columns_to_keep.append(TIMESTAMP_COLUMN) + if self.values_column is not None: + columns_to_keep.append(self.values_column) + vector_data = vector_data[columns_to_keep] + if self._rasterize_per_timestamp: + vector_data[TIMESTAMP_COLUMN] = vector_data[TIMESTAMP_COLUMN].apply(parse_time) + vector_data = vector_data[vector_data[TIMESTAMP_COLUMN].isin(timestamps)] + + if self.values_column is not None and self.values is not None: + values = [self.values] if isinstance(self.values, (int, float)) else self.values + vector_data = vector_data[vector_data[self.values_column].isin(values)] + return vector_data + + def _vector_data_to_shape_iterator(self, vector_data: GeoDataFrame, join_per_value: bool) -> ShapeIterator: if self.values_column is None: value = cast(float, self.values) # cast is checked at init - return zip(vector_data.geometry, [value] * len(vector_data.index)) + return zip(vector_data.geometry, it.repeat(value)) + values = vector_data[self.values_column] if join_per_value: - classes = np.unique(vector_data[self.values_column]) - grouped = (vector_data.geometry[vector_data[self.values_column] == cl] for cl in classes) + groups = {val: vector_data.geometry[values == val] for val in np.unique(values)} join_function = shapely.ops.unary_union if shapely.__version__ >= "1.8.0" else shapely.ops.cascaded_union - grouped = (join_function(group) for group in grouped) - return zip(grouped, classes) + return ((join_function(group), val) for val, group in groups.items()) - return zip(vector_data.geometry, vector_data[self.values_column]) + return zip(vector_data.geometry, values) - def _get_raster_shape(self, eopatch: EOPatch) -> Tuple[int, int]: + def _get_raster_shape(self, eopatch: EOPatch) -> tuple[int, int]: """Determines the shape of new raster feature, returns a pair (height, width)""" if isinstance(self.raster_shape, (tuple, list)) and len(self.raster_shape) == 2: if isinstance(self.raster_shape[0], int) and isinstance(self.raster_shape[1], int): return self.raster_shape - feature_type, feature_name = self.parse_feature( - self.raster_shape, allowed_feature_types=lambda fty: fty.is_array() - ) - return 
eopatch.get_spatial_dimension(feature_type, cast(str, feature_name)) # cast verified in parser + ftype, fname = self.parse_feature(self.raster_shape, allowed_feature_types=lambda fty: fty.is_array()) + return eopatch.get_spatial_dimension(ftype, cast(str, fname)) # cast verified in parser if self.raster_resolution: # parsing from strings is not denoted in types, so an explicit upcast is required - raw_resolution: Union[str, float, Tuple[float, float]] = self.raster_resolution + raw_resolution: str | float | tuple[float, float] = self.raster_resolution resolution = float(raw_resolution.strip("m")) if isinstance(raw_resolution, str) else raw_resolution width, height = bbox_to_dimensions(cast(BBox, eopatch.bbox), resolution) # cast verified in execute @@ -263,22 +235,18 @@ def _get_raster_shape(self, eopatch: EOPatch) -> Tuple[int, int]: def _get_raster(self, eopatch: EOPatch, height: int, width: int) -> np.ndarray: """Provides raster into which data will be written""" - feature_type, feature_name = self.raster_feature raster_shape = (len(eopatch.timestamps), height, width) if self._rasterize_per_timestamp else (height, width) - if self.write_to_existing and feature_name in eopatch[feature_type]: + if self.write_to_existing and self.raster_feature in eopatch: raster = eopatch[self.raster_feature] - expected_full_shape = raster_shape + (1,) + expected_full_shape = (*raster_shape, 1) if raster.shape != expected_full_shape: - warnings.warn( - ( - f"The existing raster feature {self.raster_feature} has a shape {raster.shape} but " - f"the expected shape is {expected_full_shape}. This might cause errors or unexpected " - "results." - ), - EORuntimeWarning, + msg = ( + f"The existing raster feature {self.raster_feature} has a shape {raster.shape} but the expected" + f" shape is {expected_full_shape}. This might cause errors or unexpected results." ) + warnings.warn(msg, EORuntimeWarning) return raster.squeeze(axis=-1) @@ -291,7 +259,7 @@ def _get_rasterization_function(self, bbox: BBox, height: int, width: int) -> Ca base_rasterize_func = rasterio.features.rasterize if self.overlap_value is None else self.rasterize_overlapped - return functools.partial(base_rasterize_func, **rasterize_params) + return partial(base_rasterize_func, **rasterize_params) def rasterize_overlapped(self, shapes: ShapeIterator, out: np.ndarray, **rasterize_args: Any) -> None: """Rasterize overlapped classes. 
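For context on the shape iterators consumed by the rasterization step: with `join_per_value` enabled, `_vector_data_to_shape_iterator` above dissolves all geometries sharing a raster value into a single shape before burning. A rough standalone sketch of that grouping, assuming a GeoDataFrame `gdf` with a hypothetical `"value"` column (not names from this diff):

```python
import numpy as np
import shapely.ops
from geopandas import GeoDataFrame


def shapes_per_value(gdf: GeoDataFrame, values_column: str = "value"):
    """Yield one (merged geometry, value) pair per distinct raster value."""
    values = gdf[values_column]
    # group geometries by the value they will be burned with
    groups = {val: gdf.geometry[values == val] for val in np.unique(values)}
    # unary_union dissolves each group into a single (multi)geometry
    return ((shapely.ops.unary_union(list(group)), val) for val, group in groups.items())
```

Dissolving per value keeps the `(shape, value)` interface identical to the non-joined `zip(vector_data.geometry, values)` path, so the downstream rasterize call does not need to know which mode produced the iterator.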
@@ -335,9 +303,6 @@ def execute(self, eopatch: EOPatch) -> EOPatch: timestamp_to_index = {timestamp: index for index, timestamp in enumerate(eopatch.timestamps)} for timestamp, shape_iterator in vector_data_iterator: - if shape_iterator is None: - continue - if timestamp is None: rasterize_func(shape_iterator, out=raster) else: @@ -368,9 +333,9 @@ def __init__( self, features: FeaturesSpecification, *, - values: Optional[List[int]] = None, + values: list[int] | None = None, values_column: str = "VALUE", - raster_dtype: Union[None, np.dtype, type] = None, + raster_dtype: None | np.dtype | type = None, **rasterio_params: Any, ): """ @@ -398,7 +363,7 @@ def __init__( self.rasterio_params = rasterio_params def _vectorize_single_raster( - self, raster: np.ndarray, affine_transform: Affine, crs: CRS, timestamps: Optional[dt.datetime] = None + self, raster: np.ndarray, affine_transform: Affine, crs: CRS, timestamps: dt.datetime | None = None ) -> GeoDataFrame: """Vectorizes a data slice of a single time component @@ -408,20 +373,14 @@ def _vectorize_single_raster( :param timestamp: Time of the data slice :return: Vectorized data """ - mask = None - if self.values: - mask = np.zeros(raster.shape, dtype=bool) - for value in self.values: - mask[raster == value] = True + mask = np.isin(raster, self.values) if self.values is not None else None geo_list = [] value_list = [] for idx in range(raster.shape[-1]): + idx_mask = None if mask is None else mask[..., idx] for geojson, value in rasterio.features.shapes( - raster[..., idx], - mask=None if mask is None else mask[..., idx], - transform=affine_transform, - **self.rasterio_params, + raster[..., idx], mask=idx_mask, transform=affine_transform, **self.rasterio_params ): geo_list.append(shapely.geometry.shape(geojson)) value_list.append(value) @@ -445,31 +404,27 @@ def execute(self, eopatch: EOPatch) -> EOPatch: """ if eopatch.bbox is None: raise ValueError("EOPatch has to have a bounding box") + crs = eopatch.bbox.crs - for raster_ft, raster_fn, vector_fn in self.feature_parser.get_renamed_features(eopatch): - vector_ft = FeatureType.VECTOR_TIMELESS if raster_ft.is_timeless() else FeatureType.VECTOR - - raster = eopatch[raster_ft][raster_fn] - height, width = raster.shape[:2] if raster_ft.is_timeless() else raster.shape[1:3] + for raster_type, raster_name, vector_name in self.feature_parser.get_renamed_features(eopatch): + vector_type = FeatureType.VECTOR_TIMELESS if raster_type.is_timeless() else FeatureType.VECTOR + raster = eopatch[raster_type, raster_name] + height, width = raster.shape[:2] if raster_type.is_timeless() else raster.shape[1:3] if self.raster_dtype: raster = raster.astype(self.raster_dtype) affine_transform = rasterio.transform.from_bounds(*eopatch.bbox, width=width, height=height) - crs = eopatch.bbox.crs - - if raster_ft.is_timeless(): - eopatch[vector_ft][vector_fn] = self._vectorize_single_raster(raster, affine_transform, crs) + if raster_type.is_timeless(): + eopatch[vector_type, vector_name] = self._vectorize_single_raster(raster, affine_transform, crs) else: gpd_list = [ - self._vectorize_single_raster( - raster[time_idx, ...], affine_transform, crs, timestamps=eopatch.timestamps[time_idx] - ) - for time_idx in range(raster.shape[0]) + self._vectorize_single_raster(raster[idx, ...], affine_transform, crs, eopatch.timestamps[idx]) + for idx in range(raster.shape[0]) ] - eopatch[vector_ft][vector_fn] = GeoDataFrame( + eopatch[vector_type, vector_name] = GeoDataFrame( pd.concat(gpd_list, ignore_index=True), crs=gpd_list[0].crs ) @@ 
-481,7 +436,7 @@ def _is_geopandas_object(data: object) -> bool:
     return isinstance(data, (GeoDataFrame, GeoSeries))


-def _vector_is_timeless(vector_input: Union[GeoDataFrame, Tuple[FeatureType, Any]]) -> bool:
+def _vector_is_timeless(vector_input: GeoDataFrame | tuple[FeatureType, Any]) -> bool:
     """Used to check if the vector input (either a geopandas object or an EOPatch feature) is time independent"""
     if _is_geopandas_object(vector_input):
         return TIMESTAMP_COLUMN not in vector_input
diff --git a/geometry/eolearn/tests/test_morphology.py b/geometry/eolearn/tests/test_morphology.py
index bbd0352fa..df0dc25f1 100644
--- a/geometry/eolearn/tests/test_morphology.py
+++ b/geometry/eolearn/tests/test_morphology.py
@@ -16,6 +16,7 @@
 CLASSES = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
 MASK_FEATURE = FeatureType.MASK, "mask"
 MASK_TIMELESS_FEATURE = FeatureType.MASK_TIMELESS, "timeless_mask"
+# ruff: noqa: NPY002


 @pytest.mark.parametrize("invalid_input", [None, 0, "a"])
diff --git a/geometry/eolearn/tests/test_superpixel.py b/geometry/eolearn/tests/test_superpixel.py
index 641d38407..5a699ff62 100644
--- a/geometry/eolearn/tests/test_superpixel.py
+++ b/geometry/eolearn/tests/test_superpixel.py
@@ -16,25 +16,25 @@


 @pytest.mark.parametrize(
-    "task, expected_statistics",
-    (
-        [
+    ("task", "expected_statistics"),
+    [
+        (
             SuperpixelSegmentationTask(
                 (FeatureType.DATA, "BANDS-S2-L1C"), SUPERPIXEL_FEATURE, scale=100, sigma=0.5, min_size=100
             ),
             {"exp_dtype": np.int64, "exp_min": 0, "exp_max": 25, "exp_mean": 10.6809, "exp_median": 11},
-        ],
-        [
+        ),
+        (
             FelzenszwalbSegmentationTask(
                 (FeatureType.DATA_TIMELESS, "MAX_NDVI"), SUPERPIXEL_FEATURE, scale=21, sigma=1.0, min_size=52
             ),
             {"exp_dtype": np.int64, "exp_min": 0, "exp_max": 22, "exp_mean": 8.5302, "exp_median": 7},
-        ],
-        [
+        ),
+        (
             FelzenszwalbSegmentationTask((FeatureType.MASK, "CLM"), SUPERPIXEL_FEATURE, scale=1, sigma=0, min_size=15),
             {"exp_dtype": np.int64, "exp_min": 0, "exp_max": 171, "exp_mean": 86.46267, "exp_median": 90},
-        ],
-        [
+        ),
+        (
             SlicSegmentationTask(
                 (FeatureType.DATA, "CLP"),
                 SUPERPIXEL_FEATURE,
@@ -44,8 +44,8 @@
                 sigma=0.8,
             ),
             {"exp_dtype": np.int64, "exp_min": 0, "exp_max": 48, "exp_mean": 24.6072, "exp_median": 25},
-        ],
-        [
+        ),
+        (
             SlicSegmentationTask(
                 (FeatureType.MASK_TIMELESS, "RANDOM_UINT8"),
                 SUPERPIXEL_FEATURE,
@@ -55,8 +55,8 @@
                 sigma=0.2,
             ),
             {"exp_dtype": np.int64, "exp_min": 0, "exp_max": 195, "exp_mean": 100.1844, "exp_median": 101},
-        ],
-    ),
+        ),
+    ],
 )
 def test_superpixel(test_eopatch, task, expected_statistics):
     task.execute(test_eopatch)
diff --git a/geometry/eolearn/tests/test_transformations.py b/geometry/eolearn/tests/test_transformations.py
index a248c0cd9..6d7d2fa8a 100644
--- a/geometry/eolearn/tests/test_transformations.py
+++ b/geometry/eolearn/tests/test_transformations.py
@@ -35,6 +35,9 @@
 CUSTOM_DATAFRAME_3D.geometry = CUSTOM_DATAFRAME_3D.geometry.map(partial(shapely.ops.transform, lambda x, y: (x, y, 0)))

+# ruff: noqa: PD008
+
+
 @dataclasses.dataclass(frozen=True)
 class VectorToRasterTestCase:
     name: str
diff --git a/geometry/setup.py b/geometry/setup.py
index 6aa324636..d136e3551 100644
--- a/geometry/setup.py
+++ b/geometry/setup.py
@@ -7,9 +7,7 @@ def get_long_description():
     this_directory = os.path.abspath(os.path.dirname(__file__))
     with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
-        long_description = f.read()
-
-    return long_description
+        return f.read()


 def parse_requirements(file):
@@ -29,7 +27,7 @@ def get_version():
 setup(
     name="eo-learn-geometry",
-    python_requires=">=3.7",
+
python_requires=">=3.8", version=get_version(), description="A collection of geometry EOTasks and utilities", long_description=get_long_description(), @@ -59,10 +57,10 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing", diff --git a/install_all.py b/install_all.py index 8bba27aa0..b4d761f10 100644 --- a/install_all.py +++ b/install_all.py @@ -24,7 +24,7 @@ def pip_command(name, args): args = [arg for arg in args if not arg.startswith(".")] - subprocess.check_call([sys.executable, "-m", "pip", "install"] + args + [f"./{name}"]) + subprocess.check_call([sys.executable, "-m", "pip", "install", *args, f"./{name}"]) if __name__ == "__main__": diff --git a/io/MANIFEST.in b/io/MANIFEST.in index 53aa1aeeb..7802b675f 100644 --- a/io/MANIFEST.in +++ b/io/MANIFEST.in @@ -1,6 +1,7 @@ include requirements*.txt include LICENSE include README.md +include eolearn/io/py.typed exclude eolearn/tests/* exclude eolearn/tests/test_extra/* exclude eolearn/tests/TestInputs/* diff --git a/io/eolearn/io/__init__.py b/io/eolearn/io/__init__.py index be9df7f56..6f18b58c3 100644 --- a/io/eolearn/io/__init__.py +++ b/io/eolearn/io/__init__.py @@ -13,4 +13,4 @@ get_available_timestamps, ) -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/io/eolearn/io/extra/geodb.py b/io/eolearn/io/extra/geodb.py index c73447007..b3cac3a07 100644 --- a/io/eolearn/io/extra/geodb.py +++ b/io/eolearn/io/extra/geodb.py @@ -8,8 +8,9 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ +from __future__ import annotations -from typing import Any, Optional +from typing import Any from sentinelhub import CRS, BBox @@ -47,7 +48,7 @@ def __init__( self.geodb_db = geodb_db self.geodb_collection = geodb_collection self.geodb_kwargs = kwargs - self._dataset_crs: Optional[CRS] = None + self._dataset_crs: CRS | None = None super().__init__(feature=feature, reproject=reproject, clip=clip) @@ -63,7 +64,7 @@ def dataset_crs(self) -> CRS: return self._dataset_crs - def _load_vector_data(self, bbox: Optional[BBox]) -> Any: + def _load_vector_data(self, bbox: BBox | None) -> Any: """Loads vector data from geoDB table""" prepared_bbox = bbox.transform_bounds(self.dataset_crs).geometry.bounds if bbox else None diff --git a/io/eolearn/io/extra/meteoblue.py b/io/eolearn/io/extra/meteoblue.py index 66a8e9041..48c857d40 100644 --- a/io/eolearn/io/extra/meteoblue.py +++ b/io/eolearn/io/extra/meteoblue.py @@ -12,9 +12,11 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. 
""" +from __future__ import annotations + import datetime as dt from abc import ABCMeta, abstractmethod -from typing import Any, List, Optional, Tuple +from typing import Any import dateutil.parser import geopandas as gpd @@ -41,12 +43,12 @@ class BaseMeteoblueTask(EOTask, metaclass=ABCMeta): def __init__( self, - feature: Tuple[FeatureType, str], + feature: tuple[FeatureType, str], apikey: str, - query: Optional[dict] = None, - units: Optional[dict] = None, - time_difference: dt.timedelta = dt.timedelta(minutes=30), # noqa: B008 - cache_folder: Optional[str] = None, + query: dict | None = None, + units: dict | None = None, + time_difference: dt.timedelta = dt.timedelta(minutes=30), # noqa: B008, RUF100 + cache_folder: str | None = None, cache_max_age: int = 604800, ): """ @@ -72,7 +74,7 @@ def __init__( self.time_difference = time_difference @staticmethod - def _get_modified_eopatch(eopatch: Optional[EOPatch], bbox: Optional[BBox]) -> Tuple[EOPatch, BBox]: + def _get_modified_eopatch(eopatch: EOPatch | None, bbox: BBox | None) -> tuple[EOPatch, BBox]: if bbox is not None: if eopatch is None: eopatch = EOPatch(bbox=bbox) @@ -86,7 +88,7 @@ def _get_modified_eopatch(eopatch: Optional[EOPatch], bbox: Optional[BBox]) -> T raise ValueError("Bounding box is not provided") return eopatch, eopatch.bbox - def _prepare_time_intervals(self, eopatch: EOPatch, time_interval: Optional[RawTimeIntervalType]) -> List[str]: + def _prepare_time_intervals(self, eopatch: EOPatch, time_interval: RawTimeIntervalType | None) -> list[str]: """Prepare a list of time intervals for which data will be collected from meteoblue services""" if not eopatch.timestamps and not time_interval: raise ValueError( @@ -98,7 +100,7 @@ def _prepare_time_intervals(self, eopatch: EOPatch, time_interval: Optional[RawT return [f"{serialized_start_time}/{serialized_end_time}"] timestamps = eopatch.timestamps - time_intervals: List[str] = [] + time_intervals: list[str] = [] for timestamp in timestamps: start_time = timestamp - self.time_difference end_time = timestamp + self.time_difference @@ -110,16 +112,16 @@ def _prepare_time_intervals(self, eopatch: EOPatch, time_interval: Optional[RawT return time_intervals @abstractmethod - def _get_data(self, query: dict) -> Tuple[Any, List[dt.datetime]]: + def _get_data(self, query: dict) -> tuple[Any, list[dt.datetime]]: """It should return an output feature object and a list of timestamps""" def execute( self, - eopatch: Optional[EOPatch] = None, + eopatch: EOPatch | None = None, *, - query: Optional[dict] = None, - bbox: Optional[BBox] = None, - time_interval: Optional[RawTimeIntervalType] = None, + query: dict | None = None, + bbox: BBox | None = None, + time_interval: RawTimeIntervalType | None = None, ) -> EOPatch: """Execute method that adds new meteoblue data into an EOPatch @@ -167,7 +169,7 @@ class MeteoblueVectorTask(BaseMeteoblueTask): A meteoblue API key is required to retrieve data. """ - def _get_data(self, query: dict) -> Tuple[gpd.GeoDataFrame, List[dt.datetime]]: + def _get_data(self, query: dict) -> tuple[gpd.GeoDataFrame, list[dt.datetime]]: """Provides a GeoDataFrame with information about weather control points and an empty list of timestamps""" result = self.client.querySync(query) dataframe = meteoblue_to_dataframe(result) @@ -187,7 +189,7 @@ class MeteoblueRasterTask(BaseMeteoblueTask): A meteoblue API key is required to retrieve data. 
""" - def _get_data(self, query: dict) -> Tuple[np.ndarray, List[dt.datetime]]: + def _get_data(self, query: dict) -> tuple[np.ndarray, list[dt.datetime]]: """Return a 4-dimensional numpy array of shape (time, height, width, weather variables) and a list of timestamps """ @@ -208,7 +210,7 @@ def meteoblue_to_dataframe(result: Any) -> pd.DataFrame: code_names = [f"{code.code}_{code.level}_{code.aggregation}" for code in geometry.codes] if not geometry.timeIntervals: - return pd.DataFrame(columns=[TIMESTAMP_COLUMN, "Longitude", "Latitude"] + code_names) + return pd.DataFrame(columns=[TIMESTAMP_COLUMN, "Longitude", "Latitude", *code_names]) dataframes = [] for index, time_interval in enumerate(geometry.timeIntervals): @@ -271,12 +273,12 @@ def map_code(code: Any) -> np.ndarray: return data.transpose((1, 2, 3, 0)) -def _meteoblue_timestamps_from_geometry(geometry_pb: Any) -> List[dt.datetime]: +def _meteoblue_timestamps_from_geometry(geometry_pb: Any) -> list[dt.datetime]: """Transforms a protobuf geometry object into a list of datetime objects""" return list(pd.core.common.flatten(map(_meteoblue_timestamps_from_time_interval, geometry_pb.timeIntervals))) -def _meteoblue_timestamps_from_time_interval(timestamp_pb: Any) -> List[dt.datetime]: +def _meteoblue_timestamps_from_time_interval(timestamp_pb: Any) -> list[dt.datetime]: """Transforms a protobuf timestamp object into a list of datetime objects""" if timestamp_pb.timestrings: # Time intervals like weekly data, return an `array of strings` as timestamps diff --git a/io/eolearn/io/geometry_io.py b/io/eolearn/io/geometry_io.py index 3dfc8f2d7..19368aec3 100644 --- a/io/eolearn/io/geometry_io.py +++ b/io/eolearn/io/geometry_io.py @@ -7,9 +7,12 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ +from __future__ import annotations + import abc import logging -from typing import Any, Optional, Union +from contextlib import nullcontext +from typing import Any import boto3 import fiona @@ -31,7 +34,7 @@ class _BaseVectorImportTask(EOTask, metaclass=abc.ABCMeta): """Base Vector Import Task, implementing common methods""" def __init__( - self, feature: FeatureSpec, reproject: bool = True, clip: bool = False, config: Optional[SHConfig] = None + self, feature: FeatureSpec, reproject: bool = True, clip: bool = False, config: SHConfig | None = None ): """ :param feature: A vector feature into which to import data @@ -45,10 +48,10 @@ def __init__( self.clip = clip @abc.abstractmethod - def _load_vector_data(self, bbox: Optional[BBox]) -> gpd.GeoDataFrame: + def _load_vector_data(self, bbox: BBox | None) -> gpd.GeoDataFrame: """Loads vector data given a bounding box""" - def _reproject_and_clip(self, vectors: gpd.GeoDataFrame, bbox: Optional[BBox]) -> gpd.GeoDataFrame: + def _reproject_and_clip(self, vectors: gpd.GeoDataFrame, bbox: BBox | None) -> gpd.GeoDataFrame: """Method to reproject and clip vectors to the EOPatch crs and bbox""" if self.reproject: @@ -70,7 +73,7 @@ def _reproject_and_clip(self, vectors: gpd.GeoDataFrame, bbox: Optional[BBox]) - return vectors - def execute(self, eopatch: Optional[EOPatch] = None, *, bbox: Optional[BBox] = None) -> EOPatch: + def execute(self, eopatch: EOPatch | None = None, *, bbox: BBox | None = None) -> EOPatch: """ :param eopatch: An existing EOPatch. If none is provided it will create a new one. :param bbox: A bounding box for which to load data. 
By default, if none is provided, it will take a bounding box @@ -102,8 +105,8 @@ def __init__( path: str, reproject: bool = True, clip: bool = False, - filesystem: Optional[FS] = None, - config: Optional[SHConfig] = None, + filesystem: FS | None = None, + config: SHConfig | None = None, **kwargs: Any, ): """ @@ -125,7 +128,7 @@ def __init__( self.fiona_kwargs = kwargs self._aws_session = None - self._dataset_crs: Optional[CRS] = None + self._dataset_crs: CRS | None = None super().__init__(feature=feature, reproject=reproject, clip=clip, config=config) @@ -152,28 +155,19 @@ def aws_session(self) -> AWSSession: return self._aws_session @property - def dataset_crs(self) -> Optional[CRS]: - """Provides a CRS of dataset, it loads it lazily (i.e. the first time it is needed) - - :return: Dataset's CRS - """ + def dataset_crs(self) -> CRS: + """Provides a CRS of dataset, it loads it lazily (i.e. the first time it is needed)""" if self._dataset_crs is None: - if self.full_path.startswith("s3://"): - with fiona.Env(session=self.aws_session): - self._read_crs() - else: - self._read_crs() + is_on_s3 = self.full_path.startswith("s3://") + with fiona.Env(session=self.aws_session) if is_on_s3 else nullcontext(): + with fiona.open(self.full_path, **self.fiona_kwargs) as features: + self._dataset_crs = CRS(features.crs) return self._dataset_crs - def _read_crs(self) -> None: - """Reads information about CRS from a dataset""" - with fiona.open(self.full_path, **self.fiona_kwargs) as features: - self._dataset_crs = CRS(features.crs) - - def _load_vector_data(self, bbox: Optional[BBox]) -> gpd.GeoDataFrame: + def _load_vector_data(self, bbox: BBox | None) -> gpd.GeoDataFrame: """Loads vector data either from S3 or local path""" - bbox_bounds = bbox.transform_bounds(self.dataset_crs).geometry.bounds if bbox and self.dataset_crs else None + bbox_bounds = bbox.transform_bounds(self.dataset_crs).geometry.bounds if bbox else None if self.full_path.startswith("s3://"): with fiona.Env(session=self.aws_session), fiona.open(self.full_path, **self.fiona_kwargs) as features: @@ -181,8 +175,8 @@ def _load_vector_data(self, bbox: Optional[BBox]) -> gpd.GeoDataFrame: return gpd.GeoDataFrame.from_features( feature_iter, - columns=list(features.schema["properties"]) + ["geometry"], - crs=self.dataset_crs.pyproj_crs() if self.dataset_crs else None, + columns=list(features.schema["properties"]) + ["geometry"], # noqa: RUF005 + crs=self.dataset_crs.pyproj_crs(), ) return gpd.read_file(self.full_path, bbox=bbox_bounds, **self.fiona_kwargs) @@ -194,7 +188,7 @@ class GeopediaVectorImportTask(_BaseVectorImportTask): def __init__( self, feature: FeatureSpec, - geopedia_table: Union[str, int], + geopedia_table: str | int, reproject: bool = True, clip: bool = False, **kwargs: Any, @@ -208,10 +202,10 @@ def __init__( """ self.geopedia_table = geopedia_table self.geopedia_kwargs = kwargs - self.dataset_crs: Optional[CRS] = None + self.dataset_crs: CRS | None = None super().__init__(feature=feature, reproject=reproject, clip=clip) - def _load_vector_data(self, bbox: Optional[BBox]) -> gpd.GeoDataFrame: + def _load_vector_data(self, bbox: BBox | None) -> gpd.GeoDataFrame: """Loads vector data from geopedia table""" prepared_bbox = bbox.transform_bounds(CRS.POP_WEB) if bbox else None diff --git a/io/eolearn/io/py.typed b/io/eolearn/io/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/io/eolearn/io/raster_io.py b/io/eolearn/io/raster_io.py index 7c47bd93f..d2a446408 100644 --- a/io/eolearn/io/raster_io.py +++ 
b/io/eolearn/io/raster_io.py @@ -6,12 +6,14 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ +from __future__ import annotations + import datetime as dt import functools import logging import warnings from abc import ABCMeta -from typing import Any, BinaryIO, List, Optional, Tuple, Union +from typing import Any, BinaryIO import fs import numpy as np @@ -37,7 +39,7 @@ LOGGER = logging.getLogger(__name__) -class BaseRasterIoTask(IOTask, metaclass=ABCMeta): # noqa: B024 +class BaseRasterIoTask(IOTask, metaclass=ABCMeta): """Base abstract class for raster IO tasks""" def __init__( @@ -45,19 +47,18 @@ def __init__( feature: SingleFeatureSpec, folder: str, *, - filesystem: Optional[FS] = None, - image_dtype: Optional[Union[np.dtype, type]] = None, - no_data_value: Optional[float] = None, + filesystem: FS | None = None, + image_dtype: np.dtype | type | None = None, + no_data_value: float | None = None, create: bool = False, - config: Optional[SHConfig] = None, + config: SHConfig | None = None, ): """ :param feature: Feature which will be exported or imported :param folder: A path to a main folder containing all image, potentially in its subfolders. If `filesystem` parameter is defined, then `folder` should be a path relative to filesystem object. Otherwise, it should be an absolute path. - :param filesystem: An existing filesystem object. If not given it will be initialized according to `folder` - parameter. + :param filesystem: A filesystem object. If not given it will be initialized according to `folder` parameter. :param image_dtype: A data type of data in exported images or data imported from images. :param no_data_value: When exporting this is the NoData value of pixels in exported images. When importing this value is assigned to the pixels with NoData. @@ -75,63 +76,59 @@ def __init__( if filesystem is None: filesystem, folder = get_base_filesystem_and_path(folder, create=create, config=config) + # the super-class takes care of filesystem pickling super().__init__(folder, filesystem=filesystem, create=create, config=config) - def _get_filename_paths(self, filename_template: Union[str, List[str]], timestamps: List[dt.datetime]) -> List[str]: + def _get_filename_paths(self, filename_template: str | list[str], timestamps: list[dt.datetime]) -> list[str]: """From a filename "template" and base path on the filesystem it generates full paths to tiff files. The paths are still relative to the filesystem object. + + If the file extension is not provided, it will default to `.tif`. If a "*" wildcard or a datetime format + substring (e.g. "%Y%m%dT%H%M%S") is provided in the template, it returns multiple GeoTIFF paths where each one + will correspond to a single timestamp. Alternatively, a list of paths can be provided, one for each timestamp. 
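The template behaviour described in this docstring is easiest to see in miniature. The sketch below condenses the `_generate_paths` logic that follows; it is a simplification, not the exact implementation:

```python
from __future__ import annotations

import datetime as dt


def expand_template(template: str, timestamps: list[dt.datetime]) -> list[str]:
    if not template.lower().endswith((".tif", ".tiff")):
        template = f"{template}.tif"  # default extension
    if not timestamps:
        return [template]
    template = template.replace("*", "%Y%m%dT%H%M%S")
    if timestamps[0].strftime(template) == template:  # template has no datetime codes
        return [template]
    return [timestamp.strftime(template) for timestamp in timestamps]


stamps = [dt.datetime(2020, 1, 1), dt.datetime(2020, 1, 2)]
print(expand_template("scene_*", stamps))  # ['scene_20200101T000000.tif', 'scene_20200102T000000.tif']
print(expand_template("mosaic", stamps))   # ['mosaic.tif']
```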
""" if isinstance(filename_template, str): filename_path = fs.path.join(self.filesystem_path, filename_template) filename_paths = self._generate_paths(filename_path, timestamps) elif isinstance(filename_template, list): - filename_paths = [] - for timestamp_index, path in enumerate(filename_template): - filename_path = fs.path.join(self.filesystem_path, path) - if len(filename_template) == len(timestamps): - filename_paths.extend(self._generate_paths(filename_path, [timestamps[timestamp_index]])) - elif not timestamps: - filename_paths.extend(self._generate_paths(filename_path, timestamps)) - else: - raise ValueError( - "The number of provided timestamps does not match the number of provided filenames." - ) + if timestamps and len(filename_template) != len(timestamps): + raise ValueError("The number of provided timestamps does not match the number of provided filenames.") + + filenames = [] + for idx, path in enumerate(filename_template): + timestamps = [] if not timestamps else [timestamps[idx]] + filenames.extend(self._generate_paths(path, timestamps)) + + filename_paths = [fs.path.join(self.filesystem_path, path) for path in filenames] + else: - raise TypeError( - f"The 'filename' parameter must either be a list or a string, but {filename_template} found" - ) + raise TypeError(f"The `filename` parameter must be a list or a string, but {filename_template} found") if self._create_path: - paths_to_create = {fs.path.dirname(filename_path) for filename_path in filename_paths} - for filename_path in paths_to_create: - self.filesystem.makedirs(filename_path, recreate=True) + unique_folder_paths = {fs.path.dirname(filename_path) for filename_path in filename_paths} + for folder_path in unique_folder_paths: + self.filesystem.makedirs(folder_path, recreate=True) return filename_paths @classmethod - def _generate_paths(cls, path_template: str, timestamps: List[dt.datetime]) -> List[str]: + def _generate_paths(cls, path_template: str, timestamps: list[dt.datetime]) -> list[str]: """Uses a filename path template to create a list of actual filename paths.""" - if not cls._has_tiff_file_extension(path_template): + has_tiff_file_extensions = path_template.lower().endswith(".tif") or path_template.lower().endswith(".tiff") + if not has_tiff_file_extensions: path_template = f"{path_template}.tif" if not timestamps: return [path_template] - if "*" in path_template: - path_template = path_template.replace("*", "%Y%m%dT%H%M%S") + path_template = path_template.replace("*", "%Y%m%dT%H%M%S") - if timestamps[0].strftime(path_template) == path_template: + if timestamps[0].strftime(path_template) == path_template: # unaffected by timestamps return [path_template] return [timestamp.strftime(path_template) for timestamp in timestamps] - @staticmethod - def _has_tiff_file_extension(path: str) -> bool: - """Checks if path ends with a tiff file extension.""" - path = path.lower() - return path.endswith(".tif") or path.endswith(".tiff") - class ExportToTiffTask(BaseRasterIoTask): """Task exports specified feature to GeoTIFF. 
@@ -151,11 +148,11 @@ def __init__( feature: SingleFeatureSpec, folder: str, *, - date_indices: Union[List[int], Tuple[int, int], Tuple[dt.datetime, dt.datetime], Tuple[str, str], None] = None, - band_indices: Union[List[int], Tuple[int, int], None] = None, - crs: Union[CRS, int, str, None] = None, + date_indices: list[int] | tuple[int, int] | tuple[dt.datetime, dt.datetime] | tuple[str, str] | None = None, + band_indices: list[int] | tuple[int, int] | None = None, + crs: CRS | int | str | None = None, fail_on_missing: bool = True, - compress: Optional[str] = None, + compress: str | None = None, **kwargs: Any, ): """ @@ -185,7 +182,7 @@ def __init__( self.compress = compress def _prepare_image_array( - self, data_array: np.ndarray, timestamps: List[dt.datetime], feature: Tuple[FeatureType, str] + self, data_array: np.ndarray, timestamps: list[dt.datetime], feature: tuple[FeatureType, str] ) -> np.ndarray: """Collects a feature from EOPatch and prepares the array of an image which will be rasterized. The resulting array has shape (channels, height, width) and is of correct dtype. @@ -196,7 +193,6 @@ def _prepare_image_array( if feature_type.is_temporal(): data_array = self._reduce_by_time(data_array, timestamps) else: - # add temporal dimension data_array = np.expand_dims(data_array, axis=0) if not feature_type.is_spatial(): @@ -228,15 +224,15 @@ def _reduce_by_bands(self, array: np.ndarray) -> np.ndarray: raise ValueError(f"Invalid format in {self.band_indices}, expected tuple or list") - def _reduce_by_time(self, array: np.ndarray, timestamps: List[dt.datetime]) -> np.ndarray: + def _reduce_by_time(self, data_array: np.ndarray, timestamps: list[dt.datetime]) -> np.ndarray: """Reduce array by selecting a subset of times.""" if self.date_indices is None: - return array + return data_array if isinstance(self.date_indices, list): if [date for date in self.date_indices if not isinstance(date, int)]: raise ValueError(f"Invalid format in {self.date_indices} list, expected integers") - return array[np.array(self.date_indices), ...] + return data_array[np.array(self.date_indices), ...] if isinstance(self.date_indices, tuple): dates = np.array(timestamps) @@ -249,49 +245,38 @@ def _reduce_by_time(self, array: np.ndarray, timestamps: List[dt.datetime]) -> n start_date, end_date = start_idx, end_idx else: raise ValueError(f"Invalid format in {self.date_indices} tuple, expected ints, strings, or datetimes") - return array[np.nonzero(np.where((dates >= start_date) & (dates <= end_date), dates, 0))[0]] + return data_array[np.nonzero(np.where((dates >= start_date) & (dates <= end_date), dates, 0))[0]] raise ValueError(f"Invalid format in {self.date_indices}, expected tuple or list") - def _set_export_dtype(self, data_array: np.ndarray, feature: Tuple[FeatureType, str]) -> np.ndarray: + def _set_export_dtype(self, data_array: np.ndarray, feature: tuple[FeatureType, str]) -> np.ndarray: """To a given array it sets a dtype in which data will be exported""" image_dtype = data_array.dtype if self.image_dtype is None else self.image_dtype if image_dtype == np.int64: image_dtype = np.int32 warnings.warn( - ( - f"Data from feature {feature} cannot be exported to tiff with dtype numpy.int64. Will export " - "as numpy.int32 instead" - ), - EORuntimeWarning, + f"Cannot export {feature} with dtype numpy.int64. 
Will export as numpy.int32 instead", EORuntimeWarning ) - if image_dtype == data_array.dtype: - return data_array - return data_array.astype(image_dtype) + return data_array.astype(image_dtype) if image_dtype != data_array.dtype else data_array def _get_source_and_destination_params( self, data_array: np.ndarray, bbox: BBox - ) -> Tuple[Tuple[str, Affine], Tuple[str, Affine], Tuple[int, int]]: - """ - Calculates source and destination CRS and transforms. Additionally, it returns destination height and width - """ + ) -> tuple[tuple[str, Affine], tuple[str, Affine], tuple[int, int]]: + """Calculates source and destination CRS and transforms. Also returns destination height and width.""" _, height, width = data_array.shape src_crs = bbox.crs.ogc_string() src_transform = rasterio.transform.from_bounds(*bbox, width=width, height=height) - if self.crs: - dst_crs = self.crs.ogc_string() - dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform( - src_crs, dst_crs, width, height, *bbox - ) - else: - dst_crs = src_crs - dst_transform = src_transform - dst_width, dst_height = width, height + if self.crs is None: + return (src_crs, src_transform), (src_crs, src_transform), (height, width) + dst_crs = self.crs.ogc_string() + dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform( + src_crs, dst_crs, width, height, *bbox + ) return (src_crs, src_transform), (dst_crs, dst_transform), (dst_height, dst_width) def _export_tiff( @@ -300,6 +285,7 @@ def _export_tiff( filesystem: FS, path: str, channel_count: int, + *, dst_crs: str, dst_transform: Affine, dst_height: int, @@ -308,7 +294,7 @@ def _export_tiff( src_transform: Affine, ) -> None: """Export an EOPatch feature to tiff based on input channel range.""" - with rasterio.Env(), filesystem.openbin(path, "w") as file_handle: # noqa: SIM117 + with rasterio.Env(), filesystem.openbin(path, "w") as file_handle: with rasterio.open( file_handle, "w", @@ -336,7 +322,7 @@ def _export_tiff( resampling=rasterio.warp.Resampling.nearest, ) - def execute(self, eopatch: EOPatch, *, filename: Union[str, List[str], None] = "") -> EOPatch: + def execute(self, eopatch: EOPatch, *, filename: str | list[str] | None = "") -> EOPatch: """Execute method :param eopatch: An input EOPatch @@ -352,27 +338,23 @@ def execute(self, eopatch: EOPatch, *, filename: Union[str, List[str], None] = " return eopatch if self.feature not in eopatch: - error_msg = f"Feature {self.feature[1]} of type {self.feature[0]} was not found in EOPatch" + error_msg = f"Feature {self.feature} was not found in EOPatch" LOGGER.warning(error_msg) if self.fail_on_missing: raise ValueError(error_msg) return eopatch if eopatch.bbox is None: - raise ValueError( - "Given EOPatch is missing a bounding box and therefore no feature can be exported to GeoTIFF" - ) + raise ValueError("EOPatch without a bounding box encountered, cannot export to GeoTIFF") image_array = self._prepare_image_array(eopatch[self.feature], eopatch.timestamps, self.feature) - ( - (src_crs, src_transform), - (dst_crs, dst_transform), - (dst_height, dst_width), - ) = self._get_source_and_destination_params(image_array, eopatch.bbox) + src_info, dst_info, (dst_height, dst_width) = self._get_source_and_destination_params(image_array, eopatch.bbox) + src_crs, src_transform = src_info + dst_crs, dst_transform = dst_info filename_paths = self._get_filename_paths(filename, eopatch.timestamps) - with self.filesystem as filesystem: + with self.filesystem as filesystem: # no worries about `close`, 
filesystem is freshly unpickled by the property export_function = functools.partial( self._export_tiff, filesystem=filesystem, @@ -415,7 +397,7 @@ def __init__( folder: str, *, use_vsi: bool = False, - timestamp_size: Optional[int] = None, + timestamp_size: int | None = None, **kwargs: Any, ): """ @@ -452,7 +434,7 @@ def _get_session(self, filesystem: FS) -> AWSSession: endpoint_url=filesystem.endpoint_url, ) - def _load_from_image(self, path: str, filesystem: FS, bbox: Optional[BBox]) -> Tuple[np.ndarray, Optional[BBox]]: + def _load_from_image(self, path: str, filesystem: FS, bbox: BBox | None) -> tuple[np.ndarray, BBox | None]: """The method decides in what way data will be loaded from the image. The method always uses `rasterio.Env` to suppress any low-level warnings. In case of a local filesystem @@ -475,7 +457,7 @@ def _load_from_image(self, path: str, filesystem: FS, bbox: Optional[BBox]) -> T with rasterio.Env(), filesystem.openbin(path, "r") as file_handle: return self._read_image(file_handle, bbox) - def _read_image(self, file_object: Union[str, BinaryIO], bbox: Optional[BBox]) -> Tuple[np.ndarray, Optional[BBox]]: + def _read_image(self, file_object: str | BinaryIO, bbox: BBox | None) -> tuple[np.ndarray, BBox | None]: """Reads data from the image.""" src: DatasetReader with rasterio.open(file_object) as src: @@ -484,9 +466,7 @@ def _read_image(self, file_object: Union[str, BinaryIO], bbox: Optional[BBox]) - return src.read(window=read_window, boundless=boundless_reading, fill_value=self.no_data_value), read_bbox @staticmethod - def _get_reading_window_and_bbox( - reader: DatasetReader, bbox: Optional[BBox] - ) -> Tuple[Optional[Window], Optional[BBox]]: + def _get_reading_window_and_bbox(reader: DatasetReader, bbox: BBox | None) -> tuple[Window | None, BBox | None]: """Provides a reading window for which data will be read from image. If it returns `None` this means that the whole image should be read. Those cases are when bbox is not defined, image is not geo-referenced, or bbox coordinates exactly match image coordinates. Additionally, it provides a bounding box of reading window @@ -511,10 +491,10 @@ def _get_reading_window_and_bbox( return from_bounds(*iter(bbox), transform=image_transform), original_bbox - def _load_data(self, filename_paths: List[str], initial_bbox: Optional[BBox]) -> Tuple[np.ndarray, Optional[BBox]]: + def _load_data(self, filename_paths: list[str], initial_bbox: BBox | None) -> tuple[np.ndarray, BBox | None]: """Load data from images, join them, and provide their bounding box.""" - data_per_path: List[np.ndarray] = [] - final_bbox: Optional[BBox] = None + data_per_path = [] + final_bbox: BBox | None = None with self.filesystem as filesystem: for path in filename_paths: @@ -532,7 +512,7 @@ def _load_data(self, filename_paths: List[str], initial_bbox: Optional[BBox]) -> return np.concatenate(data_per_path, axis=0), final_bbox - def execute(self, eopatch: Optional[EOPatch] = None, *, filename: Optional[str] = "") -> EOPatch: + def execute(self, eopatch: EOPatch | None = None, *, filename: str | None = "") -> EOPatch: """Execute method which adds a new feature to the EOPatch. 
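The windowed-read machinery above (`_read_image` together with `_get_reading_window_and_bbox`) reduces to a standard rasterio pattern; a stripped-down, standalone sketch with placeholder inputs:

```python
import numpy as np
import rasterio
from rasterio.windows import from_bounds


def read_window(path: str, bounds: tuple, no_data_value: float) -> np.ndarray:
    """Read only the pixels intersecting `bounds` (left, bottom, right, top)."""
    with rasterio.open(path) as src:
        window = from_bounds(*bounds, transform=src.transform)
        # boundless=True lets the window extend past the image edge;
        # out-of-image pixels are filled with `fill_value` instead of erroring
        return src.read(window=window, boundless=True, fill_value=no_data_value)
```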
:param eopatch: input EOPatch or None if a new EOPatch should be created diff --git a/io/eolearn/io/sentinelhub_process.py b/io/eolearn/io/sentinelhub_process.py index 5dcda2bfc..c02352403 100644 --- a/io/eolearn/io/sentinelhub_process.py +++ b/io/eolearn/io/sentinelhub_process.py @@ -11,10 +11,9 @@ import datetime as dt import logging from abc import ABCMeta, abstractmethod -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Iterable, List, Literal, Tuple, cast import numpy as np -from typing_extensions import Literal from sentinelhub import ( BBox, @@ -23,7 +22,6 @@ MimeType, MosaickingOrder, ResamplingType, - SentinelHubCatalog, SentinelHubDownloadClient, SentinelHubRequest, SentinelHubSession, @@ -32,12 +30,13 @@ filter_times, parse_time_interval, ) -from sentinelhub.data_collections_bands import Band +from sentinelhub.api.catalog import get_available_timestamps +from sentinelhub.evalscript import generate_evalscript, parse_data_collection_bands from sentinelhub.types import JsonDict, RawTimeIntervalType from eolearn.core import EOPatch, EOTask, FeatureType from eolearn.core.types import FeatureRenameSpec, FeatureSpec, FeaturesSpecification -from eolearn.core.utils.parsing import parse_renamed_features +from eolearn.core.utils.parsing import parse_renamed_feature, parse_renamed_features LOGGER = logging.getLogger(__name__) @@ -48,14 +47,14 @@ class SentinelHubInputBaseTask(EOTask, metaclass=ABCMeta): def __init__( self, data_collection: DataCollection, - size: Optional[Tuple[int, int]] = None, - resolution: Optional[Union[float, Tuple[float, float]]] = None, - cache_folder: Optional[str] = None, - config: Optional[SHConfig] = None, - max_threads: Optional[int] = None, - upsampling: Optional[ResamplingType] = None, - downsampling: Optional[ResamplingType] = None, - session_loader: Optional[Callable[[], SentinelHubSession]] = None, + size: tuple[int, int] | None = None, + resolution: float | tuple[float, float] | None = None, + cache_folder: str | None = None, + config: SHConfig | None = None, + max_threads: int | None = None, + upsampling: ResamplingType | None = None, + downsampling: ResamplingType | None = None, + session_loader: Callable[[], SentinelHubSession] | None = None, ): """ :param data_collection: A collection of requested satellite data. @@ -85,10 +84,10 @@ def __init__( def execute( self, - eopatch: Optional[EOPatch] = None, - bbox: Optional[BBox] = None, - time_interval: Optional[RawTimeIntervalType] = None, # should be kept at this to prevent code-breaks - geometry: Optional[Geometry] = None, + eopatch: EOPatch | None = None, + bbox: BBox | None = None, + time_interval: RawTimeIntervalType | None = None, # should be kept at this to prevent code-breaks + geometry: Geometry | None = None, ) -> EOPatch: """Main execute method for the Process API tasks. 
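One of the parameters typed above, `session_loader: Callable[[], SentinelHubSession] | None`, expects a zero-argument factory rather than a session instance, so that e.g. freshly unpickled workers can create their own sessions on demand. A minimal sketch of such a loader; how credentials are configured is an assumption, not prescribed by this diff:

```python
from sentinelhub import SHConfig, SentinelHubSession


def make_session_loader(config: SHConfig):
    # Returns a zero-argument callable, matching Callable[[], SentinelHubSession]
    def loader() -> SentinelHubSession:
        return SentinelHubSession(config=config)

    return loader
```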
The `geometry` is used only in conjunction with the `bbox` and does not act as a replacement.""" @@ -130,7 +129,7 @@ def execute( return eopatch - def _get_size(self, bbox: BBox) -> Tuple[int, int]: + def _get_size(self, bbox: BBox) -> tuple[int, int]: """Get the size (width, height) for the request either from inputs, or from the (existing) eopatch""" if self.size is not None: return self.size @@ -141,7 +140,7 @@ def _get_size(self, bbox: BBox) -> Tuple[int, int]: raise ValueError("Size or resolution for the requests should be provided!") @staticmethod - def _consolidate_bbox(bbox: Optional[BBox], eopatch_bbox: Optional[BBox]) -> BBox: + def _consolidate_bbox(bbox: BBox | None, eopatch_bbox: BBox | None) -> BBox: if eopatch_bbox is None: if bbox is None: raise ValueError("Either the eopatch or the task must provide valid bbox.") @@ -152,23 +151,23 @@ def _consolidate_bbox(bbox: Optional[BBox], eopatch_bbox: Optional[BBox]) -> BBo raise ValueError("Either the eopatch or the task must provide bbox, or they must be the same.") @abstractmethod - def _extract_data(self, eopatch: EOPatch, responses: List[Any], shape: Tuple[int, ...]) -> EOPatch: + def _extract_data(self, eopatch: EOPatch, responses: list[Any], shape: tuple[int, ...]) -> EOPatch: """Extract data from the received images and assign them to eopatch features""" @abstractmethod def _build_requests( self, - bbox: Optional[BBox], + bbox: BBox | None, size_x: int, size_y: int, - timestamps: Optional[List[dt.datetime]], - time_interval: Optional[RawTimeIntervalType], - geometry: Optional[Geometry], - ) -> List[SentinelHubRequest]: + timestamps: list[dt.datetime] | None, + time_interval: RawTimeIntervalType | None, + geometry: Geometry | None, + ) -> list[SentinelHubRequest]: """Build requests""" @abstractmethod - def _get_timestamps(self, time_interval: Optional[RawTimeIntervalType], bbox: BBox) -> List[dt.datetime]: + def _get_timestamps(self, time_interval: RawTimeIntervalType | None, bbox: BBox) -> list[dt.datetime]: """Get the timestamp array needed as a parameter for downloading the images""" @@ -181,19 +180,19 @@ def __init__( features: FeaturesSpecification, evalscript: str, data_collection: DataCollection, - size: Optional[Tuple[int, int]] = None, - resolution: Optional[Union[float, Tuple[float, float]]] = None, - maxcc: Optional[float] = None, - time_difference: Optional[dt.timedelta] = None, - mosaicking_order: Optional[Union[str, MosaickingOrder]] = None, - cache_folder: Optional[str] = None, - config: Optional[SHConfig] = None, - max_threads: Optional[int] = None, - upsampling: Optional[ResamplingType] = None, - downsampling: Optional[ResamplingType] = None, - aux_request_args: Optional[dict] = None, - session_loader: Optional[Callable[[], SentinelHubSession]] = None, - timestamp_filter: Callable[[List[dt.datetime], dt.timedelta], List[dt.datetime]] = filter_times, + size: tuple[int, int] | None = None, + resolution: float | tuple[float, float] | None = None, + maxcc: float | None = None, + time_difference: dt.timedelta | None = None, + mosaicking_order: str | MosaickingOrder | None = None, + cache_folder: str | None = None, + config: SHConfig | None = None, + max_threads: int | None = None, + upsampling: ResamplingType | None = None, + downsampling: ResamplingType | None = None, + aux_request_args: dict | None = None, + session_loader: Callable[[], SentinelHubSession] | None = None, + timestamp_filter: Callable[[list[dt.datetime], dt.timedelta], list[dt.datetime]] = filter_times, ): """ :param features: Features to 
construct from the evalscript. @@ -243,7 +242,7 @@ def __init__( self.mosaicking_order = None if mosaicking_order is None else MosaickingOrder(mosaicking_order) self.aux_request_args = aux_request_args - def _parse_and_validate_features(self, features: FeaturesSpecification) -> List[FeatureRenameSpec]: + def _parse_and_validate_features(self, features: FeaturesSpecification) -> list[FeatureRenameSpec]: _features = parse_renamed_features( features, allowed_feature_types=lambda fty: fty.is_array() or fty == FeatureType.META_INFO ) @@ -254,7 +253,7 @@ def _parse_and_validate_features(self, features: FeaturesSpecification) -> List[ raise ValueError("Cannot mix time dependent and timeless requests!") - def _create_response_objects(self) -> List[JsonDict]: + def _create_response_objects(self) -> list[JsonDict]: """Construct SentinelHubRequest output_responses from features""" responses = [] for feat_type, feat_name, _ in self.features: @@ -269,33 +268,33 @@ def _create_response_objects(self) -> List[JsonDict]: return responses - def _get_timestamps(self, time_interval: Optional[RawTimeIntervalType], bbox: BBox) -> List[dt.datetime]: + def _get_timestamps(self, time_interval: RawTimeIntervalType | None, bbox: BBox) -> list[dt.datetime]: """Get the timestamp array needed as a parameter for downloading the images""" if any(feat_type.is_timeless() for feat_type, _, _ in self.features if feat_type.is_array()): return [] - return get_available_timestamps( + timestamps = get_available_timestamps( bbox=bbox, time_interval=time_interval, - timestamp_filter=self.timestamp_filter, data_collection=self.data_collection, maxcc=self.maxcc, - time_difference=self.time_difference, config=self.config, ) + return self.timestamp_filter(timestamps, self.time_difference) + def _build_requests( self, - bbox: Optional[BBox], + bbox: BBox | None, size_x: int, size_y: int, - timestamps: Optional[List[dt.datetime]], - time_interval: Optional[RawTimeIntervalType], - geometry: Optional[Geometry], - ) -> List[SentinelHubRequest]: + timestamps: list[dt.datetime] | None, + time_interval: RawTimeIntervalType | None, + geometry: Geometry | None, + ) -> list[SentinelHubRequest]: """Defines request timestamps and builds requests. 
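The `_get_timestamps` rewrite above reflects a split of responsibilities: the catalog helper from `sentinelhub-py` returns all candidate acquisition times, and the task then applies its own `timestamp_filter`. Sketched standalone with placeholder area and interval values (actually running it requires Sentinel Hub credentials):

```python
import datetime as dt

from sentinelhub import CRS, BBox, DataCollection, filter_times
from sentinelhub.api.catalog import get_available_timestamps

bbox = BBox((13.35, 45.75, 13.40, 45.80), crs=CRS.WGS84)
timestamps = get_available_timestamps(
    bbox=bbox,
    time_interval=("2021-01-01", "2021-02-01"),
    data_collection=DataCollection.SENTINEL2_L1C,
    maxcc=0.8,  # cloud-coverage filter, interval [0, 1]
)
# keep only the earliest timestamp of each cluster spaced closer than 2 hours
unique_timestamps = filter_times(timestamps, dt.timedelta(hours=2))
```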
In case `timestamps` is either `None` or an empty list it still has to create at least one request in order to obtain back number of bands of responses.""" - dates: List[Optional[Tuple[Optional[dt.datetime], Optional[dt.datetime]]]] + dates: list[tuple[dt.datetime | None, dt.datetime | None] | None] if timestamps: dates = [(date - self.time_difference, date + self.time_difference) for date in timestamps] elif timestamps is None: @@ -307,11 +306,11 @@ def _build_requests( def _create_sh_request( self, - time_interval: Optional[RawTimeIntervalType], - bbox: Optional[BBox], + time_interval: RawTimeIntervalType | None, + bbox: BBox | None, size_x: int, size_y: int, - geometry: Optional[Geometry], + geometry: Geometry | None, ) -> SentinelHubRequest: """Create an instance of SentinelHubRequest""" return SentinelHubRequest( @@ -335,7 +334,7 @@ def _create_sh_request( config=self.config, ) - def _extract_data(self, eopatch: EOPatch, responses: List[Any], shape: Tuple[int, ...]) -> EOPatch: + def _extract_data(self, eopatch: EOPatch, responses: list[Any], shape: tuple[int, ...]) -> EOPatch: """Extract data from the received images and assign them to eopatch features""" # pylint: disable=arguments-renamed if len(self.features) == 1: @@ -363,36 +362,29 @@ class SentinelHubInputTask(SentinelHubInputBaseTask): """Process API input task that loads 16bit integer data and converts it to a 32bit float feature.""" # pylint: disable=too-many-arguments - DTYPE_TO_SAMPLE_TYPE: Dict[type, str] = { - bool: "SampleType.UINT8", - np.uint8: "SampleType.UINT8", - np.uint16: "SampleType.UINT16", - np.float32: "SampleType.FLOAT32", - } - # pylint: disable=too-many-locals def __init__( self, data_collection: DataCollection, - size: Optional[Tuple[int, int]] = None, - resolution: Optional[Union[float, Tuple[float, float]]] = None, - bands_feature: Optional[Tuple[FeatureType, str]] = None, - bands: Optional[List[str]] = None, - additional_data: Optional[List[Tuple[FeatureType, str]]] = None, - evalscript: Optional[str] = None, - maxcc: Optional[float] = None, - time_difference: Optional[dt.timedelta] = None, - cache_folder: Optional[str] = None, - config: Optional[SHConfig] = None, - max_threads: Optional[int] = None, - bands_dtype: Union[None, np.dtype, type] = None, + size: tuple[int, int] | None = None, + resolution: float | tuple[float, float] | None = None, + bands_feature: tuple[FeatureType, str] | None = None, + bands: list[str] | None = None, + additional_data: list[tuple[FeatureType, str]] | None = None, + evalscript: str | None = None, + maxcc: float | None = None, + time_difference: dt.timedelta | None = None, + cache_folder: str | None = None, + config: SHConfig | None = None, + max_threads: int | None = None, + bands_dtype: None | np.dtype | type = None, single_scene: bool = False, - mosaicking_order: Optional[Union[str, MosaickingOrder]] = None, - upsampling: Optional[ResamplingType] = None, - downsampling: Optional[ResamplingType] = None, - aux_request_args: Optional[dict] = None, - session_loader: Optional[Callable[[], SentinelHubSession]] = None, - timestamp_filter: Callable[[List[dt.datetime], dt.timedelta], List[dt.datetime]] = filter_times, + mosaicking_order: str | MosaickingOrder | None = None, + upsampling: ResamplingType | None = None, + downsampling: ResamplingType | None = None, + aux_request_args: dict | None = None, + session_loader: Callable[[], SentinelHubSession] | None = None, + timestamp_filter: Callable[[list[dt.datetime], dt.timedelta], list[dt.datetime]] = filter_times, ): """ :param 
data_collection: Source of requested satellite data. @@ -446,108 +438,43 @@ def __init__( self.requested_bands = [] if bands_feature: self.bands_feature = self.parse_feature(bands_feature, allowed_feature_types=[FeatureType.DATA]) - if bands is not None: - self.requested_bands = self._parse_requested_bands(bands, self.data_collection.bands) - else: - self.requested_bands = list(self.data_collection.bands) + bands = bands if bands is not None else [band.name for band in data_collection.bands] + self.requested_bands = parse_data_collection_bands(data_collection, bands) self.requested_additional_bands = [] - self.additional_data: Optional[List[FeatureRenameSpec]] = None + self.additional_data: list[FeatureRenameSpec] | None = None if additional_data is not None: - parsed_additional_data = parse_renamed_features(additional_data) # parser gives too general type - additional_bands = cast(List[str], [band for _, band, _ in parsed_additional_data]) - parsed_bands = self._parse_requested_bands(additional_bands, self.data_collection.metabands) - self.requested_additional_bands = parsed_bands - self.additional_data = parsed_additional_data - - def _parse_requested_bands(self, bands: List[str], available_bands: Tuple[Band, ...]) -> List[Band]: - """Checks that all requested bands are available and returns the band information for further processing""" - requested_bands = [] - band_info_dict = {band_info.name: band_info for band_info in available_bands} - for band_name in bands: - if band_name in band_info_dict: - requested_bands.append(band_info_dict[band_name]) - else: - raise ValueError( - f"Data collection {self.data_collection} does not have specifications for {band_name}." - f"Available bands are {[band.name for band in self.data_collection.bands]} and meta-bands" - f"{[band.name for band in self.data_collection.metabands]}" - ) - return requested_bands - - def generate_evalscript(self) -> str: - """Generate the evalscript to be passed with the request, based on chosen bands""" - evalscript = """ - //VERSION=3 - - function setup() {{ - return {{ - input: [{{ - bands: [{bands}], - units: [{units}] - }}], - output: [ - {outputs} - ] - }} - }} - - function evaluatePixel(sample) {{ - return {{ {samples} }} - }} - """ - - bands, units, outputs, samples = [], [], [], [] - for band in self.requested_bands + self.requested_additional_bands: - unit_choice = 0 # use default units - if band in self.requested_bands and self.bands_dtype is not None: - if self.bands_dtype not in band.output_types: - raise ValueError( - f"Band {band.name} only supports output types {band.output_types} but `bands_dtype` is set to " - f"{self.bands_dtype}. To use default types set `bands_dtype` to None." 
- ) - unit_choice = band.output_types.index(self.bands_dtype) - - sample_type = SentinelHubInputTask.DTYPE_TO_SAMPLE_TYPE[band.output_types[unit_choice]] - - bands.append(f'"{band.name}"') - units.append(f'"{band.units[unit_choice].value}"') - samples.append(f"{band.name}: [sample.{band.name}]") - outputs.append(f'{{ id: "{band.name}", bands: 1, sampleType: {sample_type} }}') - - evalscript = evalscript.format( - bands=", ".join(bands), units=", ".join(units), outputs=", ".join(outputs), samples=", ".join(samples) - ) + self.additional_data = parse_renamed_features(additional_data) # parser gives too general type + additional_bands = cast(List[str], [band for _, band, _ in self.additional_data]) + self.requested_additional_bands = parse_data_collection_bands(data_collection, additional_bands) - return evalscript - - def _get_timestamps(self, time_interval: Optional[RawTimeIntervalType], bbox: BBox) -> List[dt.datetime]: + def _get_timestamps(self, time_interval: RawTimeIntervalType | None, bbox: BBox) -> list[dt.datetime]: """Get the timestamp array needed as a parameter for downloading the images""" if self.single_scene: return [time_interval[0]] # type: ignore[index, list-item] - return get_available_timestamps( + timestamps = get_available_timestamps( bbox=bbox, time_interval=time_interval, - timestamp_filter=self.timestamp_filter, data_collection=self.data_collection, maxcc=self.maxcc, - time_difference=self.time_difference, config=self.config, ) + return self.timestamp_filter(timestamps, self.time_difference) + def _build_requests( self, - bbox: Optional[BBox], + bbox: BBox | None, size_x: int, size_y: int, - timestamps: Optional[List[dt.datetime]], - time_interval: Optional[RawTimeIntervalType], - geometry: Optional[Geometry], - ) -> List[SentinelHubRequest]: + timestamps: list[dt.datetime] | None, + time_interval: RawTimeIntervalType | None, + geometry: Geometry | None, + ) -> list[SentinelHubRequest]: """Build requests""" if timestamps is None: - intervals: List[Optional[RawTimeIntervalType]] = [None] + intervals: list[RawTimeIntervalType | None] = [None] elif self.single_scene: intervals = [parse_time_interval(time_interval)] else: @@ -557,20 +484,26 @@ def _build_requests( def _create_sh_request( self, - time_interval: Optional[RawTimeIntervalType], - bbox: Optional[BBox], + time_interval: RawTimeIntervalType | None, + bbox: BBox | None, size_x: int, size_y: int, - geometry: Optional[Geometry], + geometry: Geometry | None, ) -> SentinelHubRequest: """Create an instance of SentinelHubRequest""" responses = [ SentinelHubRequest.output_response(band.name, MimeType.TIFF) for band in self.requested_bands + self.requested_additional_bands ] + evalscript = generate_evalscript( + data_collection=self.data_collection, + bands=[band.name for band in self.requested_bands], + meta_bands=[band.name for band in self.requested_additional_bands], + prioritize_dn=not np.issubdtype(self.bands_dtype, np.floating), + ) return SentinelHubRequest( - evalscript=self.evalscript or self.generate_evalscript(), + evalscript=self.evalscript or evalscript, input_data=[ SentinelHubRequest.input_data( data_collection=self.data_collection, @@ -590,7 +523,7 @@ def _create_sh_request( config=self.config, ) - def _extract_data(self, eopatch: EOPatch, responses: List[Any], shape: Tuple[int, ...]) -> EOPatch: + def _extract_data(self, eopatch: EOPatch, responses: list[Any], shape: tuple[int, ...]) -> EOPatch: """Extract data from the received images and assign them to eopatch features""" if len(self.requested_bands) 
+ len(self.requested_additional_bands) == 1: # if only one band is requested the response is not a tar so we reshape it @@ -606,7 +539,7 @@ def _extract_data(self, eopatch: EOPatch, responses: List[Any], shape: Tuple[int return eopatch def _extract_additional_features( - self, eopatch: EOPatch, images: Iterable[np.ndarray], shape: Tuple[int, ...] + self, eopatch: EOPatch, images: Iterable[np.ndarray], shape: tuple[int, ...] ) -> None: """Extracts additional features from response into an EOPatch""" additional_data = cast(List[FeatureRenameSpec], self.additional_data) # verified by `if` in _extract_data @@ -614,7 +547,7 @@ def _extract_additional_features( tiffs = [tar[band_info.name + ".tif"] for tar in images] eopatch[ftype, new_name] = self._extract_array(tiffs, 0, shape, band_info.output_types[0]) - def _extract_bands_feature(self, eopatch: EOPatch, images: Iterable[np.ndarray], shape: Tuple[int, ...]) -> None: + def _extract_bands_feature(self, eopatch: EOPatch, images: Iterable[np.ndarray], shape: tuple[int, ...]) -> None: """Extract the bands feature arrays and concatenate them along the last axis""" processed_bands = [] for band_info in self.requested_bands: @@ -626,9 +559,7 @@ def _extract_bands_feature(self, eopatch: EOPatch, images: Iterable[np.ndarray], eopatch[bands_feature] = np.concatenate(processed_bands, axis=-1) @staticmethod - def _extract_array( - tiffs: List[np.ndarray], idx: int, shape: Tuple[int, ...], dtype: Union[type, np.dtype] - ) -> np.ndarray: + def _extract_array(tiffs: list[np.ndarray], idx: int, shape: tuple[int, ...], dtype: type | np.dtype) -> np.ndarray: """Extract a numpy array from the received tiffs""" feature_arrays = (np.atleast_3d(img)[..., idx] for img in tiffs) return np.asarray(list(feature_arrays), dtype=dtype).reshape(*shape, 1) @@ -642,44 +573,23 @@ class SentinelHubDemTask(SentinelHubEvalscriptTask): def __init__( self, - feature: Union[None, str, FeatureSpec] = None, + feature: None | str | FeatureSpec = None, data_collection: DataCollection = DataCollection.DEM, **kwargs: Any, ): + dem_band = data_collection.bands[0].name + renamed_feature: tuple[FeatureType, str, str] + if feature is None: - feature = (FeatureType.DATA_TIMELESS, "dem") + renamed_feature = (FeatureType.DATA_TIMELESS, dem_band, dem_band) elif isinstance(feature, str): - feature = (FeatureType.DATA_TIMELESS, feature) - - feature_type, feature_name = feature - if feature_type.is_temporal(): - raise ValueError("DEM feature should be timeless!") - - band = data_collection.bands[0] - - evalscript = f""" - //VERSION=3 - - function setup() {{ - return {{ - input: [{{ - bands: ["{band.name}"], - units: ["{band.units[0].value}"] - }}], - output: {{ - id: "{feature_name}", - bands: 1, - sampleType: SampleType.UINT16 - }} - }} - }} - - function evaluatePixel(sample) {{ - return {{ {feature_name}: [sample.{band.name}] }} - }} - """ + renamed_feature = (FeatureType.DATA_TIMELESS, dem_band, feature) + else: + ftype, _, fname = parse_renamed_feature(feature, allowed_feature_types=lambda ftype: ftype.is_timeless()) + renamed_feature = (ftype, dem_band, fname or dem_band) - super().__init__(evalscript=evalscript, features=[feature], data_collection=data_collection, **kwargs) + evalscript = generate_evalscript(data_collection=data_collection, bands=[dem_band]) + super().__init__(evalscript=evalscript, features=[renamed_feature], data_collection=data_collection, **kwargs) class SentinelHubSen2corTask(SentinelHubInputTask): @@ -704,7 +614,7 @@ class SentinelHubSen2corTask(SentinelHubInputTask): 
def __init__( self, - sen2cor_classification: Union[Literal["SCL", "CLD", "SNW"], List[Literal["SCL", "CLD", "SNW"]]], + sen2cor_classification: Literal["SCL", "CLD", "SNW"] | list[Literal["SCL", "CLD", "SNW"]], data_collection: DataCollection = DataCollection.SENTINEL2_L2A, **kwargs: Any, ): @@ -728,50 +638,5 @@ def __init__( if data_collection != DataCollection.SENTINEL2_L2A: raise ValueError("Sen2Cor classification layers are only available on Sentinel-2 L2A data.") - features: List[Tuple[FeatureType, str]] = [(classification_types[s2c], s2c) for s2c in sen2cor_classification] + features: list[tuple[FeatureType, str]] = [(classification_types[s2c], s2c) for s2c in sen2cor_classification] super().__init__(additional_data=features, data_collection=data_collection, **kwargs) - - -def get_available_timestamps( - bbox: BBox, - data_collection: DataCollection, - *, - time_interval: Optional[RawTimeIntervalType] = None, - time_difference: dt.timedelta = dt.timedelta(seconds=-1), # noqa: B008 - timestamp_filter: Callable[[List[dt.datetime], dt.timedelta], List[dt.datetime]] = filter_times, - maxcc: Optional[float] = None, - config: Optional[SHConfig] = None, -) -> List[dt.datetime]: - """Helper function to search for all available timestamps, based on query parameters. - - :param bbox: A bounding box of the search area. - :param data_collection: A data collection for which to find available timestamps. - :param time_interval: A time interval from which to provide the timestamps. - :param time_difference: Minimum allowed time difference, used when filtering dates. - :param timestamp_filter: A function that performs the final filtering of timestamps, usually to remove multiple - occurrences within the time_difference window. The filtration is performed after all suitable timestamps for - the given region are obtained (with maxcc filtering already done by SH). By default only keeps the oldest - timestamp when multiple occur within `time_difference`. - :param maxcc: Maximum cloud coverage filter from interval [0, 1], default is None. - :param config: A configuration object. - :return: A list of timestamps of available observations. 
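Usage of the Sen2Cor task is unchanged by this diff; for orientation, a hedged sketch (parameter values are illustrative, and per the check above only Sentinel-2 L2A is accepted):

```python
from sentinelhub import DataCollection

from eolearn.io import SentinelHubSen2corTask

scl_task = SentinelHubSen2corTask(
    sen2cor_classification=["SCL", "CLD"],         # scene classification + cloud probability
    data_collection=DataCollection.SENTINEL2_L2A,  # the only supported collection
    resolution=20,
)
# eopatch = scl_task.execute(bbox=..., time_interval=...)  # needs Sentinel Hub credentials
```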
- """ - query_filter = None - if maxcc is not None and data_collection.has_cloud_coverage: - if isinstance(maxcc, (int, float)) and (maxcc < 0 or maxcc > 1): - raise ValueError('Maximum cloud coverage "maxcc" parameter should be a float on an interval [0, 1]') - query_filter = f"eo:cloud_cover < {int(maxcc * 100)}" - - fields = {"include": ["properties.datetime"], "exclude": []} - - if data_collection.service_url: - config = config.copy() if config else SHConfig() - config.sh_base_url = data_collection.service_url - - catalog = SentinelHubCatalog(config=config) - search_iterator = catalog.search( - collection=data_collection, bbox=bbox, time=time_interval, filter=query_filter, fields=fields - ) - - all_timestamps = search_iterator.get_timestamps() - return timestamp_filter(all_timestamps, time_difference) diff --git a/io/eolearn/tests/conftest.py b/io/eolearn/tests/conftest.py index b9c456984..2fed8401c 100644 --- a/io/eolearn/tests/conftest.py +++ b/io/eolearn/tests/conftest.py @@ -12,9 +12,11 @@ import pytest from botocore.exceptions import ClientError, NoCredentialsError -from sentinelhub import SHConfig +pytest.register_assert_rewrite("sentinelhub.testing_utils") # makes asserts in helper functions work with pytest -from eolearn.core import EOPatch +from sentinelhub import SHConfig # noqa[E402] + +from eolearn.core import EOPatch # noqa[E402] EXAMPLE_DATA_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "..", "example_data") TEST_EOPATCH_PATH = os.path.join(EXAMPLE_DATA_PATH, "TestEOPatch") @@ -30,26 +32,16 @@ def example_data_path_fixture(): return EXAMPLE_DATA_PATH -@pytest.fixture(name="config") -def config_fixture(): - config = SHConfig() - # for param in config.get_params(): - # env_variable = param.upper() - # if os.environ.get(env_variable): - # setattr(config, param, os.environ.get(env_variable)) - return config - - @pytest.fixture(name="gpkg_file") def local_gpkg_example_file_fixture(): """A pytest fixture to retrieve a gpkg example file""" - path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../example_data/import-gpkg-test.gpkg") - return path + return os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../example_data/import-gpkg-test.gpkg") @pytest.fixture(name="s3_gpkg_file") -def s3_gpkg_example_file_fixture(config): +def s3_gpkg_example_file_fixture(): """A pytest fixture to retrieve a gpkg example file""" + config = SHConfig() aws_config = { "region_name": "eu-central-1", } diff --git a/io/eolearn/tests/test_geometry_io.py b/io/eolearn/tests/test_geometry_io.py index dd439c609..295de5aaf 100644 --- a/io/eolearn/tests/test_geometry_io.py +++ b/io/eolearn/tests/test_geometry_io.py @@ -51,7 +51,7 @@ def _test_import(bbox, clip, crs, gpkg_example, n_features, reproject): def test_clipping_wrong_crs(gpkg_file): """Test for trying to clip using different CRS than the data is in""" + feature = FeatureType.VECTOR_TIMELESS, "lpis_iacs" + import_task = VectorImportTask(feature=feature, path=gpkg_file, reproject=False, clip=True) with pytest.raises(ValueError): - feature = FeatureType.VECTOR_TIMELESS, "lpis_iacs" - import_task = VectorImportTask(feature=feature, path=gpkg_file, reproject=False, clip=True) import_task.execute(bbox=BBox([657690, 5071637, 660493, 5074440], CRS.UTM_31N)) diff --git a/io/eolearn/tests/test_raster_io.py b/io/eolearn/tests/test_raster_io.py index 403eb4062..0971b840a 100644 --- a/io/eolearn/tests/test_raster_io.py +++ b/io/eolearn/tests/test_raster_io.py @@ -37,7 +37,7 @@ 
@pytest.fixture(autouse=True) -def create_s3_bucket_fixture(): +def _create_s3_bucket_fixture(): with mock_s3(): s3resource = boto3.resource("s3", region_name="eu-central-1") s3resource.create_bucket(Bucket=BUCKET_NAME, CreateBucketConfiguration={"LocationConstraint": "eu-central-1"}) @@ -127,7 +127,7 @@ def get_expected_timestamp_size(self): ] -@pytest.mark.parametrize("test_case", TIFF_TEST_CASES, ids=[test_case.name for test_case in TIFF_TEST_CASES]) +@pytest.mark.parametrize("test_case", TIFF_TEST_CASES, ids=lambda x: x.name) def test_export_import(test_case, test_eopatch): test_eopatch[test_case.feature_type][test_case.name] = test_case.data @@ -192,17 +192,18 @@ def _execute_with_warning_control( @pytest.mark.parametrize( - "bands, times", [([2, "string", 1, 0], [1, 7, 0, 2, 3]), ([2, 3, 1, 0], [1, 7, "string", 2, 3])] + ("bands", "times"), [([2, "string", 1, 0], [1, 7, 0, 2, 3]), ([2, 3, 1, 0], [1, 7, "string", 2, 3])] ) def test_export2tiff_wrong_format(bands, times, test_eopatch): test_eopatch.data["data"] = np.arange(10 * 3 * 2 * 6, dtype=float).reshape(10, 3, 2, 6) - with tempfile.TemporaryDirectory() as tmp_dir_name, pytest.raises(ValueError): + with tempfile.TemporaryDirectory() as tmp_dir_name: tmp_file_name = "temp_file.tiff" task = ExportToTiffTask( (FeatureType.DATA, "data"), folder=tmp_dir_name, band_indices=bands, date_indices=times, image_dtype=float ) - task.execute(test_eopatch, filename=tmp_file_name) + with pytest.raises(ValueError): + task.execute(test_eopatch, filename=tmp_file_name) def test_export2tiff_wrong_feature(mocker, test_eopatch): @@ -218,10 +219,13 @@ def test_export2tiff_wrong_feature(mocker, test_eopatch): assert logging.Logger.warning.call_count == 1 (val_err,), _ = logging.Logger.warning.call_args - assert str(val_err) == "Feature feature-not-present of type FeatureType.MASK_TIMELESS was not found in EOPatch" + assert ( + str(val_err) + == "Feature (, 'feature-not-present') was not found in EOPatch" + ) + failing_export_task = ExportToTiffTask(feature, folder=tmp_dir_name, fail_on_missing=True) with pytest.raises(ValueError): - failing_export_task = ExportToTiffTask(feature, folder=tmp_dir_name, fail_on_missing=True) failing_export_task(test_eopatch, filename=tmp_file_name) @@ -328,37 +332,39 @@ def test_time_dependent_feature(test_eopatch): f'relative-path/{timestamp.strftime("%Y%m%dT%H%M%S")}.tiff' for timestamp in test_eopatch.timestamps ] - export_task = ExportToTiffTask(feature, folder=PATH_ON_BUCKET) - import_task = ImportFromTiffTask(feature, folder=PATH_ON_BUCKET, timestamp_size=68) + with tempfile.TemporaryDirectory() as tmp_dir_name: + export_task = ExportToTiffTask(feature, folder=tmp_dir_name) + import_task = ImportFromTiffTask(feature, folder=tmp_dir_name, timestamp_size=68) - export_task(test_eopatch, filename=filename_export) - new_eopatch = import_task(filename=filename_import) + export_task(test_eopatch, filename=filename_export) + new_eopatch = import_task(filename=filename_import) - assert_array_equal(new_eopatch[feature], test_eopatch[feature]) + assert_array_equal(new_eopatch[feature], test_eopatch[feature]) - test_eopatch.timestamps[-1] = datetime.datetime(2020, 10, 10) - filename_import = [ - f'relative-path/{timestamp.strftime("%Y%m%dT%H%M%S")}.tiff' for timestamp in test_eopatch.timestamps - ] + test_eopatch.timestamps[-1] = datetime.datetime(2020, 10, 10) + filename_import = [ + f'relative-path/{timestamp.strftime("%Y%m%dT%H%M%S")}.tiff' for timestamp in test_eopatch.timestamps + ] - with pytest.raises(ResourceNotFound): - 
import_task(filename=filename_import) + with pytest.raises((ResourceNotFound, rasterio.errors.RasterioIOError)): + import_task(filename=filename_import) def test_time_dependent_feature_with_timestamps(test_eopatch): feature = FeatureType.DATA, "NDVI" filename = "relative-path/%Y%m%dT%H%M%S.tiff" - export_task = ExportToTiffTask(feature, folder=PATH_ON_BUCKET) - import_task = ImportFromTiffTask(feature, folder=PATH_ON_BUCKET) + with tempfile.TemporaryDirectory() as tmp_dir_name: + export_task = ExportToTiffTask(feature, folder=tmp_dir_name) + import_task = ImportFromTiffTask(feature, folder=tmp_dir_name) - export_task.execute(test_eopatch, filename=filename) - new_eopatch = import_task(test_eopatch, filename=filename) + export_task.execute(test_eopatch, filename=filename) + new_eopatch = import_task(test_eopatch, filename=filename) - assert_array_equal(new_eopatch[feature], test_eopatch[feature]) + assert_array_equal(new_eopatch[feature], test_eopatch[feature]) -@pytest.mark.parametrize("no_data_value, data_type", [(np.nan, float), (0, int), (None, float), (1, np.byte)]) +@pytest.mark.parametrize(("no_data_value", "data_type"), [(np.nan, float), (0, int), (None, float), (1, np.byte)]) def test_export_import_sequence(no_data_value, data_type): """Tests import and export tiff tasks on generated array with different values of no_data_value.""" eopatch = EOPatch(bbox=BBox((0, 0, 1, 1), crs=CRS.WGS84)) diff --git a/io/eolearn/tests/test_sentinelhub_process.py b/io/eolearn/tests/test_sentinelhub_process.py index fac0cd170..d073fce42 100644 --- a/io/eolearn/tests/test_sentinelhub_process.py +++ b/io/eolearn/tests/test_sentinelhub_process.py @@ -13,18 +13,12 @@ import numpy as np import pytest -from pytest import approx -from sentinelhub import CRS, Band, BBox, DataCollection, Geometry, MosaickingOrder, ResamplingType, SHConfig, Unit +from sentinelhub import CRS, Band, BBox, DataCollection, Geometry, MosaickingOrder, ResamplingType, Unit +from sentinelhub.testing_utils import assert_statistics_match from eolearn.core import EOPatch, EOTask, FeatureType -from eolearn.io import ( - SentinelHubDemTask, - SentinelHubEvalscriptTask, - SentinelHubInputTask, - SentinelHubSen2corTask, - get_available_timestamps, -) +from eolearn.io import SentinelHubDemTask, SentinelHubEvalscriptTask, SentinelHubInputTask, SentinelHubSen2corTask @pytest.fixture(name="cache_folder") @@ -61,11 +55,11 @@ def calculate_stats(array): array[: max(int(time / 2), 1), -1, -1, :], array[:, int(height / 2), int(width / 2), :], ] - values = [(np.nanmean(slice) if not np.isnan(slice).all() else np.nan) for slice in slices] + values = [(np.nanmean(_slice) if not np.isnan(_slice).all() else np.nan) for _slice in slices] return np.round(np.array(values), 4) -@pytest.mark.sh_integration +@pytest.mark.sh_integration() class TestProcessingIO: """Test cases for SentinelHubInputTask""" @@ -98,7 +92,7 @@ def test_s2l1c_float32_uint16(self, cache_folder): bands = eopatch[(FeatureType.DATA, "BANDS")] is_data = eopatch[(FeatureType.MASK, "dataMask")] - assert calculate_stats(bands) == approx([x / 10000 for x in expected_int_stats], abs=1e-4) + assert calculate_stats(bands) == pytest.approx([x / 10000 for x in expected_int_stats], abs=1e-4) width, height = self.size assert bands.shape == (4, height, width, 3) @@ -114,7 +108,7 @@ def test_s2l1c_float32_uint16(self, cache_folder): eopatch = task.execute(bbox=self.bbox, time_interval=self.time_interval) bands = eopatch[(FeatureType.DATA, "BANDS")] - assert calculate_stats(bands) == 
approx(expected_int_stats) + assert calculate_stats(bands) == pytest.approx(expected_int_stats) assert bands.dtype == np.uint16 @@ -133,13 +127,13 @@ def test_specific_bands(self): eopatch = task.execute(bbox=self.bbox, time_interval=self.time_interval) bands = eopatch[(FeatureType.DATA, "BANDS")] - assert calculate_stats(bands) == approx([0.0648, 0.1193, 0.063]) + assert calculate_stats(bands) == pytest.approx([0.0648, 0.1193, 0.063]) width, height = self.size assert bands.shape == (4, height, width, 3) @pytest.mark.parametrize( - ["resampling_type", "stats"], + ("resampling_type", "stats"), [ (ResamplingType.NEAREST, [0.0836, 0.1547, 0.0794]), (ResamplingType.BICUBIC, [0.0836, 0.1548, 0.0792]), @@ -162,13 +156,13 @@ def test_upsampling_downsampling(self, resampling_type: ResamplingType, stats: L eopatch = task.execute(bbox=self.bbox, time_interval=self.time_interval) bands = eopatch[(FeatureType.DATA, "BANDS")] - assert calculate_stats(bands) == approx(stats) + assert calculate_stats(bands) == pytest.approx(stats) width, height = self.size assert bands.shape == (4, height, width, 1) @pytest.mark.parametrize( - ["geometry", "stats"], + ("geometry", "stats"), [ ( Geometry( @@ -203,13 +197,13 @@ def test_geometry_argument(self, geometry: Geometry, stats: List[float]): eopatch = task.execute(bbox=self.bbox, time_interval=self.time_interval, geometry=geometry) bands = eopatch[(FeatureType.DATA, "BANDS")] - assert calculate_stats(bands) == approx(stats) + assert calculate_stats(bands) == pytest.approx(stats) width, height = self.size assert bands.shape == (4, height, width, 1) @pytest.mark.parametrize( - ["geometry", "stats"], + ("geometry", "stats"), [ ( Geometry( @@ -273,7 +267,7 @@ def test_geometry_argument_evalscript(self, geometry: Geometry, stats: List[floa assert eop.data["bands"].shape == (4, height, width, 1) bands = eop[(FeatureType.DATA, "bands")] - assert calculate_stats(bands) == approx(stats) + assert calculate_stats(bands) == pytest.approx(stats) def test_scl_only(self): """Download just SCL, without any other bands""" @@ -351,7 +345,7 @@ def test_additional_data(self): sun_azimuth_angles = eopatch[(FeatureType.DATA, "sunAzimuthAngles")] sun_zenith_angles = eopatch[(FeatureType.DATA, "sunZenithAngles")] - assert calculate_stats(bands) == approx([0.027, 0.0243, 0.0162]) + assert calculate_stats(bands) == pytest.approx([0.027, 0.0243, 0.0162]) width, height = self.size assert bands.shape == (4, height, width, 3) @@ -384,16 +378,22 @@ def test_aux_request_args(self): bands = eopatch[(FeatureType.DATA, "BANDS")] assert bands.shape == (4, 4, 4, 13) - assert calculate_stats(bands) == approx([0.0, 0.0493, 0.0277]) + assert calculate_stats(bands) == pytest.approx([0.0, 0.0493, 0.0277]) def test_dem(self): task = SentinelHubDemTask(resolution=10, feature=(FeatureType.DATA_TIMELESS, "DEM"), max_threads=3) - eopatch = task.execute(bbox=self.bbox) - dem = eopatch.data_timeless["DEM"] - width, height = self.size - assert dem.shape == (height, width, 1) + + assert_statistics_match( + eopatch.data_timeless["DEM"], + exp_shape=(height, width, 1), + exp_dtype=np.float32, + exp_max=3.4277425, + exp_min=-0.96642065, + exp_mean=0.2557371, + exp_median=0, + ) def test_dem_cop(self): task = SentinelHubDemTask( @@ -507,42 +507,6 @@ def test_multi_processing(self): width, height = self.size assert array.shape == (13, height, width, 2) - def test_get_available_timestamps_with_missing_data_collection_service_url(self): - collection = DataCollection.SENTINEL2_L1C.define_from("COLLECTION_WITHOUT_URL", 
service_url=None) - timestamps = get_available_timestamps( - bbox=self.bbox, - config=SHConfig(), - data_collection=collection, - time_difference=self.time_difference, - time_interval=self.time_interval, - maxcc=self.maxcc, - ) - - assert len(timestamps) == 4 - assert all(timestamp.tzinfo is not None for timestamp in timestamps) - - def test_get_available_timestamps_custom_filtration(self): - """Checks that the custom filtration works as intended.""" - timestamps1 = get_available_timestamps( - bbox=self.bbox, - config=SHConfig(), - data_collection=DataCollection.SENTINEL2_L1C, - time_interval=self.time_interval, - timestamp_filter=lambda stamps, diff: stamps[:3], - ) - - assert len(timestamps1) == 3 - - timestamps2 = get_available_timestamps( - bbox=self.bbox, - config=SHConfig(), - data_collection=DataCollection.SENTINEL2_L1C, - time_interval=self.time_interval, - timestamp_filter=lambda stamps, diff: stamps[:5], - ) - - assert len(timestamps2) == 5 - def test_no_data_input_task_request(self): task = SentinelHubInputTask( bands_feature=(FeatureType.DATA, "BANDS"), @@ -606,7 +570,7 @@ def test_no_data_evalscript_task_request(self): assert masks.shape == (0, 101, 99, 1) -@pytest.mark.sh_integration +@pytest.mark.sh_integration() class TestSentinelHubInputTaskDataCollections: """Integration tests for all supported data collections""" @@ -823,4 +787,4 @@ def test_data_collections(self, test_case): assert len(timestamps) == test_case.timestamp_length stats = calculate_stats(data) - assert stats == approx(test_case.stats, nan_ok=True), f"Expected stats {test_case.stats}, got {stats}" + assert stats == pytest.approx(test_case.stats, nan_ok=True), f"Expected stats {test_case.stats}, got {stats}" diff --git a/io/requirements.txt b/io/requirements.txt index 64d9625be..1a2f77db7 100644 --- a/io/requirements.txt +++ b/io/requirements.txt @@ -4,6 +4,5 @@ eo-learn-core fiona>=1.8.18 geopandas>=0.8.1 rasterio>=1.2.7 -rtree # might be redundant after 3.7 is dropped due to new geopandas version sentinelhub>=3.8.0 typing-extensions diff --git a/io/setup.py b/io/setup.py index ecaed890b..7c2af05ca 100644 --- a/io/setup.py +++ b/io/setup.py @@ -7,9 +7,7 @@ def get_long_description(): this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - return long_description + return f.read() def parse_requirements(file): @@ -29,7 +27,7 @@ def get_version(): setup( name="eo-learn-io", - python_requires=">=3.7", + python_requires=">=3.8", version=get_version(), description="A collection of input/output EOTasks and utilities", long_description=get_long_description(), @@ -60,10 +58,10 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing", diff --git a/mask/MANIFEST.in b/mask/MANIFEST.in index 5478e06b1..dad260ce5 100644 --- a/mask/MANIFEST.in +++ b/mask/MANIFEST.in @@ -2,4 +2,5 @@ include requirements*.txt include eolearn/mask/models/* include LICENSE include README.md +include eolearn/mask/py.typed exclude eolearn/tests/* diff --git a/mask/eolearn/mask/__init__.py b/mask/eolearn/mask/__init__.py index 
db067f249..77a3a15a0 100644 --- a/mask/eolearn/mask/__init__.py +++ b/mask/eolearn/mask/__init__.py @@ -8,4 +8,4 @@ from .snow_mask import SnowMaskTask, TheiaSnowMaskTask from .utils import resize_images -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/mask/eolearn/mask/cloud_mask.py b/mask/eolearn/mask/cloud_mask.py index 2d2022e55..1d91f3636 100644 --- a/mask/eolearn/mask/cloud_mask.py +++ b/mask/eolearn/mask/cloud_mask.py @@ -6,22 +6,24 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. """ +from __future__ import annotations + import logging import os from functools import partial -from typing import Callable, Optional, Tuple, Union, cast +from typing import Protocol, cast import cv2 import numpy as np from lightgbm import Booster from skimage.morphology import disk -from typing_extensions import Protocol from sentinelhub import BBox, bbox_to_resolution from eolearn.core import EOPatch, EOTask, FeatureType, execute_with_mp_lock +from eolearn.core.utils.common import _apply_to_spatial_axes -from .utils import map_over_axis, resize_images +from .utils import resize_images LOGGER = logging.getLogger(__name__) @@ -29,10 +31,12 @@ class ClassifierType(Protocol): """Defines the necessary classifier interface.""" - def predict(self, X: np.ndarray) -> np.ndarray: # pylint: disable=missing-function-docstring,invalid-name + # pylint: disable-next=missing-function-docstring,invalid-name + def predict(self, X: np.ndarray) -> np.ndarray: # noqa: N803 ... - def predict_proba(self, X: np.ndarray) -> np.ndarray: # pylint: disable=missing-function-docstring,invalid-name + # pylint: disable-next=missing-function-docstring,invalid-name + def predict_proba(self, X: np.ndarray) -> np.ndarray: # noqa: N803 ... @@ -79,20 +83,20 @@ class CloudMaskTask(EOTask): def __init__( self, - data_feature: Tuple[FeatureType, str] = (FeatureType.DATA, "BANDS-S2-L1C"), - is_data_feature: Tuple[FeatureType, str] = (FeatureType.MASK, "IS_DATA"), + data_feature: tuple[FeatureType, str] = (FeatureType.DATA, "BANDS-S2-L1C"), + is_data_feature: tuple[FeatureType, str] = (FeatureType.MASK, "IS_DATA"), all_bands: bool = True, - processing_resolution: Union[None, float, Tuple[float, float]] = None, + processing_resolution: None | float | tuple[float, float] = None, max_proc_frames: int = 11, - mono_features: Optional[Tuple[Optional[str], Optional[str]]] = None, - multi_features: Optional[Tuple[Optional[str], Optional[str]]] = None, - mask_feature: Optional[Tuple[FeatureType, str]] = (FeatureType.MASK, "CLM_INTERSSIM"), + mono_features: tuple[str | None, str | None] | None = None, + multi_features: tuple[str | None, str | None] | None = None, + mask_feature: tuple[FeatureType, str] | None = (FeatureType.MASK, "CLM_INTERSSIM"), mono_threshold: float = 0.4, multi_threshold: float = 0.5, - average_over: Optional[int] = 4, - dilation_size: Optional[int] = 2, - mono_classifier: Optional[ClassifierType] = None, - multi_classifier: Optional[ClassifierType] = None, + average_over: int | None = 4, + dilation_size: int | None = 2, + mono_classifier: ClassifierType | None = None, + multi_classifier: ClassifierType | None = None, ): """ :param data_feature: A data feature which stores raw Sentinel-2 reflectance bands.
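The `ClassifierType` protocol in the hunk above relies on structural subtyping: `typing.Protocol` is part of the standard library from Python 3.8 onward, which is why the `typing_extensions` import can be dropped, and the `# noqa: N803` comments suppress the pep8-naming rule against upper-case argument names for the scikit-learn-style `X`. A minimal sketch of how the protocol is matched (the `ThresholdClassifier` and `run_prediction` names are illustrative, not part of eo-learn):

```python
from typing import Protocol

import numpy as np


class ClassifierType(Protocol):
    """Structural interface: anything with these two methods matches."""

    def predict(self, X: np.ndarray) -> np.ndarray:  # noqa: N803
        ...

    def predict_proba(self, X: np.ndarray) -> np.ndarray:  # noqa: N803
        ...


class ThresholdClassifier:
    """Illustrative stand-in; note that it never subclasses ClassifierType."""

    def predict(self, X: np.ndarray) -> np.ndarray:
        return (X.mean(axis=1) > 0.5).astype(np.uint8)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        positive = X.mean(axis=1, keepdims=True)
        return np.hstack([1.0 - positive, positive])


def run_prediction(classifier: ClassifierType, features: np.ndarray) -> np.ndarray:
    # mypy accepts ThresholdClassifier here because its methods match the protocol
    return classifier.predict_proba(features)[..., 1]


print(run_prediction(ThresholdClassifier(), np.random.rand(4, 3)))
```

No inheritance is involved: the stand-in satisfies `ClassifierType` purely by providing the two methods with compatible signatures, which is what allows `CloudMaskTask` to accept LightGBM boosters and scikit-learn style models alike as `mono_classifier`/`multi_classifier`.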
@@ -174,7 +178,7 @@ def __init__( self.dil_kernel = None @staticmethod - def _parse_resolution_arg(resolution: Union[None, float, Tuple[float, float]]) -> Optional[Tuple[float, float]]: + def _parse_resolution_arg(resolution: None | float | tuple[float, float]) -> tuple[float, float] | None: """Parses initialization resolution argument""" if isinstance(resolution, (int, float)): resolution = resolution, resolution @@ -209,7 +213,7 @@ def _run_prediction(classifier: ClassifierType, features: np.ndarray) -> np.ndar return prediction if is_booster else prediction[..., 1] - def _scale_factors(self, reference_shape: Tuple[int, int], bbox: BBox) -> Tuple[Tuple[float, float], float]: + def _scale_factors(self, reference_shape: tuple[int, int], bbox: BBox) -> tuple[tuple[float, float], float]: """Compute the resampling factors for height and width of the input array and sigma :param reference_shape: Tuple specifying height and width in pixels of high-resolution array @@ -251,19 +255,14 @@ def _red_ssim( mu2_2 = mu2 * mu2 mu1_mu2 = mu1 * mu2 - sigma12 = cv2.GaussianBlur((data_x * data_y).astype(np.float64), (0, 0), sigma, borderType=cv2.BORDER_REFLECT) + sigma12 = cv2.GaussianBlur(data_x * data_y, (0, 0), sigma, borderType=cv2.BORDER_REFLECT) sigma12 -= mu1_mu2 # Formula - tmp1 = 2.0 * mu1_mu2 + const1 - tmp2 = 2.0 * sigma12 + const2 - num = tmp1 * tmp2 - - tmp1 = mu1_2 + mu2_2 + const1 - tmp2 = sigma1_2 + sigma2_2 + const2 - den = tmp1 * tmp2 + numerator = (2.0 * mu1_mu2 + const1) * (2.0 * sigma12 + const2) + denominator = (mu1_2 + mu2_2 + const1) * (sigma1_2 + sigma2_2 + const2) - return np.divide(num, den) + return np.divide(numerator, denominator) def _win_avg(self, data: np.ndarray, sigma: float) -> np.ndarray: """Spatial window average""" @@ -279,37 +278,17 @@ def _average(self, data: np.ndarray) -> np.ndarray: def _dilate(self, data: np.ndarray) -> np.ndarray: return (cv2.dilate(data.astype(np.uint8), self.dil_kernel) > 0).astype(np.uint8) - @staticmethod - def _map_sequence(data: np.ndarray, func2d: Callable[[np.ndarray], np.ndarray]) -> np.ndarray: - """Iterate over time and band dimensions and apply a function to each slice. - Returns a new array with the combined results. - - :param data: input array - :param func2d: Mapping function that is applied on each 2d image slice. All outputs must have the same shape. 
- """ - - # Map over channel dimension on 3d tensor - def func3d(data_slice: np.ndarray) -> np.ndarray: - return map_over_axis(data_slice, func2d, axis=2) - - # Map over time dimension on 4d tensor - def func4d(data_slice: np.ndarray) -> np.ndarray: - return map_over_axis(data_slice, func3d, axis=0) - - output = func4d(data) - return output - def _average_all(self, data: np.ndarray) -> np.ndarray: """Average over each spatial slice of data""" if self.avg_kernel is not None: - return self._map_sequence(data, self._average) + return _apply_to_spatial_axes(self._average, data, (1, 2)) return data def _dilate_all(self, data: np.ndarray) -> np.ndarray: """Dilate over each spatial slice of data""" if self.dil_kernel is not None: - return self._map_sequence(data, self._dilate) + return _apply_to_spatial_axes(self._dilate, data, (1, 2)) return data @@ -321,7 +300,7 @@ def _ssim_stats( local_var: np.ndarray, rel_tdx: int, sigma: float, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Calculate SSIM stats""" ssim_max = np.empty((1, *bands.shape[1:]), dtype=np.float32) ssim_mean = np.empty_like(ssim_max) @@ -331,8 +310,7 @@ def _ssim_stats( win_avg_r = np.delete(local_avg, rel_tdx, axis=0) var_r = np.delete(local_var, rel_tdx, axis=0) - n_frames = bands_r.shape[0] - n_bands = bands_r.shape[-1] + n_frames, _, _, n_bands = bands_r.shape valid_mask = np.delete(is_data, rel_tdx, axis=0) & is_data[rel_tdx, ..., 0].reshape(1, *is_data.shape[1:-1], 1) @@ -363,19 +341,16 @@ def _ssim_stats( def _do_single_temporal_cloud_detection(self, bands: np.ndarray) -> np.ndarray: """Performs a cloud detection process on each scene separately""" - mono_proba = np.empty(np.prod(bands.shape[:-1])) - img_size = np.prod(bands.shape[1:-1]).astype(int) - - n_times = bands.shape[0] + n_times, height, width, n_bands = bands.shape + img_size = height * width + mono_proba = np.empty(n_times * img_size) for t_i in range(0, n_times, self.max_proc_frames): # Extract mono features nt_min = t_i nt_max = min(t_i + self.max_proc_frames, n_times) - bands_t = bands[nt_min:nt_max] - - mono_features = bands_t.reshape(np.prod(bands_t.shape[:-1]), bands_t.shape[-1]) + mono_features = bands[nt_min:nt_max].reshape(-1, n_bands) mono_proba[nt_min * img_size : nt_max * img_size] = self._run_prediction( self.mono_classifier, mono_features @@ -385,15 +360,14 @@ def _do_single_temporal_cloud_detection(self, bands: np.ndarray) -> np.ndarray: def _do_multi_temporal_cloud_detection(self, bands: np.ndarray, is_data: np.ndarray, sigma: float) -> np.ndarray: """Performs a cloud detection process on multiple scenes at once""" - multi_proba = np.empty(np.prod(bands.shape[:-1])) - img_size = int(np.prod(bands.shape[1:-1])) - - n_times = bands.shape[0] + n_times, height, width, n_bands = bands.shape + img_size = height * width + multi_proba = np.empty(n_times * img_size) - local_avg: Optional[np.ndarray] = None - local_var: Optional[np.ndarray] = None - prev_left: Optional[int] = None - prev_right: Optional[int] = None + local_avg: np.ndarray | None = None + local_var: np.ndarray | None = None + prev_left: int | None = None + prev_right: int | None = None for t_idx in range(n_times): # Extract temporal window indices @@ -401,7 +375,7 @@ def _do_multi_temporal_cloud_detection(self, bands: np.ndarray, is_data: np.ndar bands_slice = bands[left:right] is_data_slice = is_data[left:right] - masked_bands = np.ma.array(bands_slice, mask=~is_data_slice.repeat(bands_slice.shape[-1], axis=-1)) + masked_bands = 
np.ma.array(bands_slice, mask=~is_data_slice.repeat(n_bands, axis=-1)) # Calculate the averages/variances for the local (windowed) streaming data if local_avg is None or (left, right) != (prev_left, prev_right): @@ -418,41 +392,36 @@ def _do_multi_temporal_cloud_detection(self, bands: np.ndarray, is_data: np.ndar self.multi_classifier, multi_features ) - prev_left = left - prev_right = right + prev_left, prev_right = left, right return multi_proba[..., None] def _update_batches( self, - local_avg: Optional[np.ndarray], - local_var: Optional[np.ndarray], + local_avg: np.ndarray | None, + local_var: np.ndarray | None, bands: np.ndarray, is_data: np.ndarray, sigma: float, - ) -> Tuple[np.ndarray, np.ndarray]: - """ - Calculates or updates the window average and variance. - The calculation is done per 2D image along the temporal and band axes. - """ + ) -> tuple[np.ndarray, np.ndarray]: + """Calculates or updates the window average and variance. The calculation is done per 2D image along the + temporal and band axes.""" local_avg_func = partial(self._win_avg, sigma=sigma) local_var_func = partial(self._win_prevar, sigma=sigma) # take full batch if avg/var don't exist, otherwise take only last index slice - data = bands if local_avg is None else bands[-1][None, ...] - data_mask = is_data if local_avg is None else is_data[-1][None, ...] + data = bands if local_avg is None else bands[-1][np.newaxis, ...] + data_mask = is_data if local_avg is None else is_data[-1][np.newaxis, ...] - avg_data = self._map_sequence(data, local_avg_func) - avg_data_mask = self._map_sequence(data_mask, local_avg_func) + avg_data = _apply_to_spatial_axes(local_avg_func, data, (1, 2)) + avg_data_mask = _apply_to_spatial_axes(local_avg_func, data_mask, (1, 2)) avg_data_mask[avg_data_mask == 0.0] = 1.0 - var_data = self._map_sequence(data, local_var_func) + var_data = _apply_to_spatial_axes(local_var_func, data, (1, 2)) if local_avg is None or local_var is None: - local_avg = avg_data / avg_data_mask - local_avg = cast(np.ndarray, local_avg) - local_var = var_data - local_avg**2 - local_var = cast(np.ndarray, local_var) + local_avg = cast(np.ndarray, avg_data / avg_data_mask) + local_var = cast(np.ndarray, var_data - local_avg**2) return local_avg, local_var # shift back, drop first element @@ -465,7 +434,7 @@ def _update_batches( return local_avg, local_var - def _extract_multi_features( + def _extract_multi_features( # pylint: disable=too-many-locals self, bands: np.ndarray, is_data: np.ndarray, @@ -480,16 +449,16 @@ def _extract_multi_features( ssim_max, ssim_mean, ssim_std = self._ssim_stats(bands, is_data, local_avg, local_var, local_t_idx, sigma) # Compute temporal stats - temp_min = np.ma.min(masked_bands, axis=0).data[None, ...] - temp_mean = np.ma.mean(masked_bands, axis=0).data[None, ...] + temp_min = np.ma.min(masked_bands, axis=0).data[np.newaxis, ...] + temp_mean = np.ma.mean(masked_bands, axis=0).data[np.newaxis, ...] # Compute difference stats t_all = len(bands) - diff_max = (masked_bands[local_t_idx][None, ...] - temp_min).data - diff_mean = ( - masked_bands[local_t_idx][None, ...] * (1.0 + 1.0 / (t_all - 1)) - t_all * temp_mean / (t_all - 1) - ).data + diff_max = (masked_bands[local_t_idx][np.newaxis, ...] - temp_min).data + coef1 = 1.0 + 1.0 / (t_all - 1) + coef2 = t_all * temp_mean / (t_all - 1) + diff_mean = (masked_bands[local_t_idx][np.newaxis, ...] 
* coef1 - coef2).data # Interweave ssim_interweaved = np.empty((*ssim_max.shape[:-1], 3 * ssim_max.shape[-1])) @@ -508,8 +477,8 @@ def _extract_multi_features( # Put it all together multi_features = np.concatenate( ( - bands[local_t_idx][None, ...], - local_avg[local_t_idx][None, ...], + bands[local_t_idx][np.newaxis, ...], + local_avg[local_t_idx][np.newaxis, ...], ssim_interweaved, temp_interweaved, diff_interweaved, @@ -517,11 +486,9 @@ def _extract_multi_features( axis=3, ) - multi_features = multi_features.reshape(np.prod(multi_features.shape[:-1]), multi_features.shape[-1]) - - return multi_features + return multi_features.reshape(-1, multi_features.shape[-1]) - def execute(self, eopatch: EOPatch) -> EOPatch: + def execute(self, eopatch: EOPatch) -> EOPatch: # noqa: C901 """Add selected features (cloud probabilities and masks) to an EOPatch instance. :param eopatch: Input `EOPatch` instance @@ -531,11 +498,11 @@ def execute(self, eopatch: EOPatch) -> EOPatch: is_data = eopatch[self.is_data_feature].astype(bool) - original_shape = bands.shape[1:-1] + image_size = bands.shape[1:-1] patch_bbox = eopatch.bbox if patch_bbox is None: raise ValueError("Cannot run cloud masking on an EOPatch without a BBox.") - scale_factors, sigma = self._scale_factors(original_shape, patch_bbox) + scale_factors, sigma = self._scale_factors(image_size, patch_bbox) is_data_sm = is_data # Downscale if specified @@ -553,7 +520,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: # Upscale if necessary if scale_factors is not None: - mono_proba = resize_images(mono_proba, new_size=original_shape) + mono_proba = resize_images(mono_proba, new_size=image_size) # Average over and threshold mono_mask = self._average_all(mono_proba) >= self.mono_threshold @@ -565,7 +532,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: # Upscale if necessary if scale_factors is not None: - multi_proba = resize_images(multi_proba, new_size=original_shape) + multi_proba = resize_images(multi_proba, new_size=image_size) # Average over and threshold multi_mask = self._average_all(multi_proba) >= self.multi_threshold @@ -593,7 +560,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: return eopatch -def _get_window_indices(num_of_elements: int, middle_idx: int, window_size: int) -> Tuple[int, int]: +def _get_window_indices(num_of_elements: int, middle_idx: int, window_size: int) -> tuple[int, int]: """ Returns the minimum and maximum indices to be used for indexing, lower inclusive and upper exclusive. 
The window has the following properties: diff --git a/mask/eolearn/mask/mask_counting.py b/mask/eolearn/mask/mask_counting.py index 52f48dc37..0f0440490 100644 --- a/mask/eolearn/mask/mask_counting.py +++ b/mask/eolearn/mask/mask_counting.py @@ -8,7 +8,7 @@ """ from __future__ import annotations -from typing import Iterator, List, Union +from typing import Iterator import numpy as np @@ -23,7 +23,7 @@ def __init__( self, input_feature: FeaturesSpecification, output_feature: FeaturesSpecification, - classes: List[int], + classes: list[int], no_data_value: int = 0, ): """ @@ -44,10 +44,8 @@ def __init__( def map_method(self, feature: np.ndarray) -> np.ndarray: """Map method being applied to the feature that calculates the frequencies.""" - count_valid: Union[int, np.ndarray] = np.count_nonzero(feature != self.no_data_value, axis=0) - class_counts: Iterator[Union[int, np.ndarray]] = ( - np.count_nonzero(feature == scl, axis=0) for scl in self.classes - ) + count_valid: int | np.ndarray = np.count_nonzero(feature != self.no_data_value, axis=0) + class_counts: Iterator[int | np.ndarray] = (np.count_nonzero(feature == scl, axis=0) for scl in self.classes) with np.errstate(invalid="ignore"): class_frequencies = [np.divide(count, count_valid, dtype=np.float32) for count in class_counts] diff --git a/mask/eolearn/mask/masking.py b/mask/eolearn/mask/masking.py index ae11a38db..954cfb5b0 100644 --- a/mask/eolearn/mask/masking.py +++ b/mask/eolearn/mask/masking.py @@ -8,10 +8,9 @@ """ from __future__ import annotations -from typing import Callable, Dict, Iterable, Union +from typing import Callable, Iterable, Literal import numpy as np -from typing_extensions import Literal from eolearn.core import EOPatch, EOTask, FeatureType, ZipFeatureTask from eolearn.core.types import FeaturesSpecification, SingleFeatureSpec @@ -25,7 +24,7 @@ def __init__( self, input_features: FeaturesSpecification, output_feature: SingleFeatureSpec, - join_operation: Union[Literal["and", "or", "xor"], Callable] = "and", + join_operation: Literal["and", "or", "xor"] | Callable = "and", ): """ :param input_features: Mask features to be joined together. @@ -34,7 +33,7 @@ def __init__( """ self.join_method: Callable[[np.ndarray, np.ndarray], np.ndarray] if isinstance(join_operation, str): - methods: Dict[str, Callable[[np.ndarray, np.ndarray], np.ndarray]] = { + methods: dict[str, Callable[[np.ndarray, np.ndarray], np.ndarray]] = { "and": np.logical_and, "or": np.logical_or, "xor": np.logical_xor, diff --git a/mask/eolearn/mask/py.typed b/mask/eolearn/mask/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/mask/eolearn/mask/snow_mask.py b/mask/eolearn/mask/snow_mask.py index f9b9c1a65..57aed8520 100644 --- a/mask/eolearn/mask/snow_mask.py +++ b/mask/eolearn/mask/snow_mask.py @@ -10,7 +10,8 @@ import itertools import logging -from typing import Any, List, Optional, Tuple +from abc import ABCMeta +from typing import Any import numpy as np from skimage.morphology import binary_dilation, disk @@ -23,13 +24,13 @@ LOGGER = logging.getLogger(__name__) -class BaseSnowMaskTask(EOTask): +class BaseSnowMaskTask(EOTask, metaclass=ABCMeta): """Base class for snow detection and masking""" def __init__( self, data_feature: FeatureSpec, - band_indices: List[int], + band_indices: list[int], dilation_size: int = 0, undefined_value: int = 0, mask_name: str = "SNOW_MASK", @@ -40,7 +41,7 @@ def __init__( :param dilation_size: Size of the disk in pixels for performing dilation. Value 0 means do not perform this post-processing step. 
""" - self.bands_feature = self.parse_feature(data_feature) + self.bands_feature = self.parse_feature(data_feature, allowed_feature_types={FeatureType.DATA}) self.band_indices = band_indices self.dilation_size = dilation_size self.undefined_value = undefined_value @@ -52,9 +53,6 @@ def _apply_dilation(self, snow_masks: np.ndarray) -> np.ndarray: snow_masks = np.array([binary_dilation(mask, disk(self.dilation_size)) for mask in snow_masks]) return snow_masks - def execute(self, eopatch: EOPatch) -> Any: - raise NotImplementedError - class SnowMaskTask(BaseSnowMaskTask): """The task calculates the snow mask using the given thresholds. @@ -68,7 +66,7 @@ class SnowMaskTask(BaseSnowMaskTask): def __init__( self, data_feature: FeatureSpec, - band_indices: List[int], + band_indices: list[int], ndsi_threshold: float = 0.4, brightness_threshold: float = 0.3, **kwargs: Any, @@ -99,20 +97,13 @@ def execute(self, eopatch: EOPatch) -> EOPatch: ndsi[ndsi_invalid] = self.undefined_value ndvi[ndvi_invalid] = self.undefined_value - snow_mask = np.where( - np.logical_and( - np.logical_or( - ndsi >= self.ndsi_threshold, np.abs(ndvi - self.NDVI_THRESHOLD) < self.NDVI_THRESHOLD / 2 - ), - bands[..., 0] >= self.brightness_threshold, - ), - 1, - 0, - ) + ndi_criterion = (ndsi >= self.ndsi_threshold) | (np.abs(ndvi - self.NDVI_THRESHOLD) < self.NDVI_THRESHOLD / 2) + brightnes_criterion = bands[..., 0] >= self.brightness_threshold + snow_mask = np.where(ndi_criterion & brightnes_criterion, 1, 0) snow_mask = self._apply_dilation(snow_mask) - snow_mask[np.logical_or(ndsi_invalid, ndvi_invalid)] = self.undefined_value + snow_mask[ndsi_invalid | ndvi_invalid] = self.undefined_value eopatch[self.mask_feature] = snow_mask[..., np.newaxis].astype(bool) return eopatch @@ -137,13 +128,13 @@ class TheiaSnowMaskTask(BaseSnowMaskTask): def __init__( self, data_feature: FeatureSpec, - band_indices: List[int], + band_indices: list[int], cloud_mask_feature: FeatureSpec, dem_feature: FeatureSpec, - dem_params: Tuple[float, float] = (100, 0.1), - red_params: Tuple[float, float, float, float, float] = (12, 0.3, 0.1, 0.2, 0.040), - ndsi_params: Tuple[float, float, float] = (0.4, 0.15, 0.001), - b10_index: Optional[int] = None, + dem_params: tuple[float, float] = (100, 0.1), + red_params: tuple[float, float, float, float, float] = (12, 0.3, 0.1, 0.2, 0.040), + ndsi_params: tuple[float, float, float] = (0.4, 0.15, 0.001), + b10_index: int | None = None, **kwargs: Any, ): """ @@ -181,80 +172,54 @@ def __init__( self.red_params = red_params self.ndsi_params = ndsi_params self.b10_index = b10_index - self._validate_params() - - def _validate_params(self) -> None: - """Check length of parameters defining threshold values""" - for params, n_params in [(self.dem_params, 2), (self.red_params, 5), (self.ndsi_params, 3)]: - if not isinstance(params, (tuple, list)) or len(params) != n_params: - raise ValueError( - f"Incorrect format or number of parameters for {params}. 
Has to be a tuple of length {n_params}" - ) def _resample_red(self, input_array: np.ndarray) -> np.ndarray: """Method to resample the values of the red band The input array is first down-scaled using bicubic interpolation and up-scaled back using nearest neighbour interpolation - - :param input_array: input values - :return: resampled values """ - height, width = input_array.shape[1:] + _, height, width = input_array.shape size = (height // self.red_params[0], width // self.red_params[0]) - return resize_images( - resize_images(input_array[..., np.newaxis], new_size=size), - new_size=(height, width), - ).squeeze() + downscaled = resize_images(input_array[..., np.newaxis], new_size=size) + return resize_images(downscaled, new_size=(height, width)).squeeze() def _adjust_cloud_mask( - self, bands: np.ndarray, cloud_mask: np.ndarray, dem: np.ndarray, b10: np.ndarray + self, bands: np.ndarray, cloud_mask: np.ndarray, dem: np.ndarray, b10: np.ndarray | None ) -> np.ndarray: """Adjust existing cloud mask using cirrus band if L1C data and resampled red band Add to the existing cloud mask pixels found thresholding down-sampled red band and cirrus band/DEM """ - clm_b10 = ( - np.where(b10 > self.B10_THR + self.DEM_FACTOR * dem, 1, 0) - if b10 is not None - else np.ones(shape=cloud_mask.shape, dtype=np.uint8) - ) - return np.logical_or( - np.where(np.logical_and(cloud_mask == 1, self._resample_red(bands[..., 1]) > self.red_params[1]), 1, 0), - clm_b10, - ).astype(np.uint8) + if b10 is not None: + clm_b10 = b10 > self.B10_THR + self.DEM_FACTOR * dem + else: + clm_b10 = np.full_like(cloud_mask, True) + + criterion = (cloud_mask == 1) & (self._resample_red(bands[..., 1]) > self.red_params[1]) + return criterion | clm_b10 def _apply_first_pass( self, bands: np.ndarray, ndsi: np.ndarray, clm: np.ndarray, dem: np.ndarray, clm_temp: np.ndarray - ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: """Apply first pass of snow detection""" - snow_mask_pass1 = np.where( - np.logical_and( - np.logical_not(clm_temp), np.logical_and(ndsi > self.ndsi_params[0], bands[..., 1] > self.red_params[3]) - ), - 1, - 0, - ) - - clm_pass1 = np.where( - np.logical_or(clm_temp, (bands[..., 1] > self.red_params[2]) & np.logical_not(snow_mask_pass1) & clm), 1, 0 - ) - - dem_edges = np.linspace( - np.min(dem), np.max(dem), int(np.ceil((np.max(dem) - np.min(dem)) / self.dem_params[0])) - ) + snow_mask_pass1 = ~clm_temp & (ndsi > self.ndsi_params[0]) & (bands[..., 1] > self.red_params[3]) + + clm_pass1 = clm_temp | ((bands[..., 1] > self.red_params[2]) & ~snow_mask_pass1 & clm.astype(bool)) + + min_dem, max_dem = np.min(dem), np.max(dem) + dem_edges = np.linspace(min_dem, max_dem, int(np.ceil((max_dem - min_dem) / self.dem_params[0]))) nbins = len(dem_edges) - 1 + dem_hist_clear_pixels, snow_frac = None, None if nbins > 0: snow_frac = np.zeros(shape=(bands.shape[0], nbins)) - dem_hist_clear_pixels = np.array( - [np.histogram(dem[np.logical_not(mask)], bins=dem_edges)[0] for mask in clm_pass1] - ) + dem_hist_clear_pixels = np.array([np.histogram(dem[~mask], bins=dem_edges)[0] for mask in clm_pass1]) for date, nbin in itertools.product(range(bands.shape[0]), range(nbins)): if dem_hist_clear_pixels[date, nbin] > 0: - dem_mask = np.logical_and(dem_edges[nbin] <= dem, dem < dem_edges[nbin + 1]) - in_dem_range_clear = np.where(np.logical_and(dem_mask, np.logical_not(clm_pass1[date]))) + dem_mask = (dem_edges[nbin] <= dem) & (dem < dem_edges[nbin + 1]) + in_dem_range_clear = 
np.where(dem_mask & ~clm_pass1[date]) snow_frac[date, nbin] = ( np.sum(snow_mask_pass1[date][in_dem_range_clear]) / dem_hist_clear_pixels[date, nbin] ) @@ -267,30 +232,28 @@ def _apply_second_pass( dem: np.ndarray, clm_temp: np.ndarray, snow_mask_pass1: np.ndarray, - snow_frac: Optional[np.ndarray], + snow_frac: np.ndarray | None, dem_edges: np.ndarray, ) -> np.ndarray: """Second pass of snow detection""" _, height, width, _ = bands.shape total_snow_frac = np.sum(snow_mask_pass1, axis=(1, 2)) / (height * width) - snow_mask_pass2 = np.zeros(snow_mask_pass1.shape) + snow_mask_pass2 = np.full_like(snow_mask_pass1, False) for date in range(bands.shape[0]): - if (total_snow_frac[date] > self.ndsi_params[2]) and ( - snow_frac is not None and np.any(snow_frac[date] > self.dem_params[1]) + if ( + (total_snow_frac[date] > self.ndsi_params[2]) + and snow_frac is not None + and np.any(snow_frac[date] > self.dem_params[1]) ): z_s = dem_edges[np.max(np.argmax(snow_frac[date] > self.dem_params[1]) - 2, 0)] - snow_mask_pass2[date, :, :] = np.where( - np.logical_and( - dem > z_s, - np.logical_and( - np.logical_not(clm_temp[date]), - np.logical_and(ndsi[date] > self.ndsi_params[1], bands[date, ..., 1] > self.red_params[-1]), - ), - ), - 1, - 0, + snow_mask_pass2[date, :, :] = ( + (dem > z_s) + & ~clm_temp[date] + & (ndsi[date] > self.ndsi_params[1]) + & (bands[date, ..., 1] > self.red_params[-1]) ) + return snow_mask_pass2 def execute(self, eopatch: EOPatch) -> EOPatch: @@ -313,7 +276,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: snow_mask_pass2 = self._apply_second_pass(bands, ndsi, dem, clm_temp, snow_mask_pass1, snow_frac, dem_edges) - snow_mask = self._apply_dilation(np.logical_or(snow_mask_pass1, snow_mask_pass2)) + snow_mask = self._apply_dilation(snow_mask_pass1 | snow_mask_pass2) eopatch[self.mask_feature] = snow_mask[..., np.newaxis].astype(bool) diff --git a/mask/eolearn/mask/utils.py b/mask/eolearn/mask/utils.py index ee6dc2d84..fc200225d 100644 --- a/mask/eolearn/mask/utils.py +++ b/mask/eolearn/mask/utils.py @@ -95,12 +95,12 @@ def _resize2d(image: np.ndarray) -> np.ndarray: image = cv2.GaussianBlur(image, (0, 0), sigmaX=sigma_x, sigmaY=sigma_y, borderType=cv2.BORDER_REFLECT) height, width = new_size - resized = cv2.resize(image, (width, height), interpolation=interpolation_method) + return cv2.resize(image, (width, height), interpolation=interpolation_method) - return resized - - _resize3d = lambda x: map_over_axis(x, _resize2d, axis=2) # pylint: disable=unnecessary-lambda-assignment # noqa - _resize4d = lambda x: map_over_axis(x, _resize3d, axis=0) # pylint: disable=unnecessary-lambda-assignment # noqa + # pylint: disable-next=unnecessary-lambda-assignment + _resize3d = lambda x: map_over_axis(x, _resize2d, axis=2) # noqa: E731 + # pylint: disable-next=unnecessary-lambda-assignment + _resize4d = lambda x: map_over_axis(x, _resize3d, axis=0) # noqa: E731 # Choose a resize method based on number of dimensions resize_methods = {2: _resize2d, 3: _resize3d, 4: _resize4d} diff --git a/mask/eolearn/tests/test_cloud_mask.py b/mask/eolearn/tests/test_cloud_mask.py index db5a46a32..51e8d290f 100644 --- a/mask/eolearn/tests/test_cloud_mask.py +++ b/mask/eolearn/tests/test_cloud_mask.py @@ -8,7 +8,6 @@ import numpy as np import pytest from numpy.testing import assert_array_equal -from pytest import approx from eolearn.core import FeatureType from eolearn.mask import CloudMaskTask @@ -16,21 +15,21 @@ @pytest.mark.parametrize( - "num_of_elements, middle_idx, window_size, expected_indices", - ( - 
[100, 0, 10, (0, 10)], - [100, 1, 10, (0, 10)], - [100, 50, 10, (45, 55)], - [271, 270, 10, (261, 271)], - [314, 314, 10, (304, 314)], - [100, 0, 11, (0, 11)], - [100, 1, 11, (0, 11)], - [100, 50, 11, (45, 56)], - [271, 270, 11, (260, 271)], - [314, 314, 11, (303, 314)], - [11, 2, 11, (0, 11)], - [11, 2, 33, (0, 11)], - ), + ("num_of_elements", "middle_idx", "window_size", "expected_indices"), + [ + (100, 0, 10, (0, 10)), + (100, 1, 10, (0, 10)), + (100, 50, 10, (45, 55)), + (271, 270, 10, (261, 271)), + (314, 314, 10, (304, 314)), + (100, 0, 11, (0, 11)), + (100, 1, 11, (0, 11)), + (100, 50, 11, (45, 56)), + (271, 270, 11, (260, 271)), + (314, 314, 11, (303, 314)), + (11, 2, 11, (0, 11)), + (11, 2, 33, (0, 11)), + ], ids=str, ) def test_window_indices_function(num_of_elements, middle_idx, window_size, expected_indices): @@ -79,8 +78,8 @@ def test_multi_temporal_cloud_detection_downscaled(test_eopatch): assert eop_clm.data["CLP_TEST"].dtype == np.float32 # Compare mean cloud coverage with provided reference - assert np.mean(eop_clm.mask["CLM_TEST"]) == approx(np.mean(eop_clm.mask["CLM_S2C"]), abs=0.01) - assert np.mean(eop_clm.data["CLP_TEST"]) == approx(np.mean(eop_clm.data["CLP_S2C"]), abs=0.01) + assert np.mean(eop_clm.mask["CLM_TEST"]) == pytest.approx(np.mean(eop_clm.mask["CLM_S2C"]), abs=0.01) + assert np.mean(eop_clm.data["CLP_TEST"]) == pytest.approx(np.mean(eop_clm.data["CLP_S2C"]), abs=0.01) # Check if most of the same times are flagged as cloudless cloudless = np.mean(eop_clm.mask["CLM_TEST"], axis=(1, 2, 3)) == 0 diff --git a/mask/eolearn/tests/test_mask_counting.py b/mask/eolearn/tests/test_mask_counting.py index 9ebff2f71..310debf86 100644 --- a/mask/eolearn/tests/test_mask_counting.py +++ b/mask/eolearn/tests/test_mask_counting.py @@ -15,9 +15,10 @@ IN_FEATURE = (FeatureType.MASK, "TEST") OUT_FEATURE = (FeatureType.DATA_TIMELESS, "FREQ") +# ruff: noqa: NPY002 -@pytest.mark.parametrize("classes, no_data_value", ((["a", "b"], 0), (4, 0), (None, 0), ([1, 2, 3], 2))) +@pytest.mark.parametrize(("classes", "no_data_value"), [(["a", "b"], 0), (4, 0), (None, 0), ([1, 2, 3], 2)]) def test_value_error(classes, no_data_value): with pytest.raises(ValueError): ClassFrequencyTask(IN_FEATURE, OUT_FEATURE, classes=classes, no_data_value=no_data_value) diff --git a/mask/eolearn/tests/test_mask_utils.py b/mask/eolearn/tests/test_mask_utils.py index 9e6c5e9c6..6f74decbe 100644 --- a/mask/eolearn/tests/test_mask_utils.py +++ b/mask/eolearn/tests/test_mask_utils.py @@ -5,9 +5,9 @@ def test_map_over_axis(): data = np.ones((5, 10, 10)) - result = map_over_axis(data, lambda x: np.zeros((7, 20)), axis=0) + result = map_over_axis(data, lambda _: np.zeros((7, 20)), axis=0) assert result.shape == (5, 7, 20) - result = map_over_axis(data, lambda x: np.zeros((7, 20)), axis=1) + result = map_over_axis(data, lambda _: np.zeros((7, 20)), axis=1) assert result.shape == (7, 10, 20) - result = map_over_axis(data, lambda x: np.zeros((5, 10)), axis=1) + result = map_over_axis(data, lambda _: np.zeros((5, 10)), axis=1) assert result.shape == (5, 10, 10) diff --git a/mask/eolearn/tests/test_snow_mask.py b/mask/eolearn/tests/test_snow_mask.py index 5c14b7117..30c6a5149 100644 --- a/mask/eolearn/tests/test_snow_mask.py +++ b/mask/eolearn/tests/test_snow_mask.py @@ -11,21 +11,8 @@ from eolearn.mask import SnowMaskTask, TheiaSnowMaskTask -@pytest.mark.parametrize("params", [{"dem_params": (100, 100, 100)}, {"red_params": 45}, {"ndsi_params": (0.2, 3)}]) -def test_raises_errors(params, test_eopatch): - with 
pytest.raises(ValueError): - theia_mask = TheiaSnowMaskTask( - (FeatureType.DATA, "BANDS-S2-L1C"), - [2, 3, 11], - (FeatureType.MASK, "CLM"), - (FeatureType.DATA_TIMELESS, "DEM"), - **params - ) - theia_mask(test_eopatch) - - @pytest.mark.parametrize( - "task, result", + ("task", "result"), [ (SnowMaskTask((FeatureType.DATA, "BANDS-S2-L1C"), [2, 3, 7, 11], mask_name="TEST_SNOW_MASK"), (50468, 1405)), ( diff --git a/mask/setup.py b/mask/setup.py index 115f06380..d26a7a6c2 100644 --- a/mask/setup.py +++ b/mask/setup.py @@ -7,9 +7,7 @@ def get_long_description(): this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - return long_description + return f.read() def parse_requirements(file): @@ -29,7 +27,7 @@ def get_version(): setup( name="eo-learn-mask", - python_requires=">=3.7", + python_requires=">=3.8", version=get_version(), description="A collection of masking EOTasks and utilities", long_description=get_long_description(), @@ -65,10 +63,10 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing", diff --git a/ml_tools/MANIFEST.in b/ml_tools/MANIFEST.in index 1b163e951..970f86032 100644 --- a/ml_tools/MANIFEST.in +++ b/ml_tools/MANIFEST.in @@ -1,4 +1,5 @@ include requirements*.txt include LICENSE include README.md +include eolearn/ml_tools/py.typed exclude eolearn/tests/* diff --git a/ml_tools/eolearn/ml_tools/__init__.py b/ml_tools/eolearn/ml_tools/__init__.py index 0d6889d2a..da151dd17 100644 --- a/ml_tools/eolearn/ml_tools/__init__.py +++ b/ml_tools/eolearn/ml_tools/__init__.py @@ -5,4 +5,4 @@ from .sampling import BlockSamplingTask, FractionSamplingTask, GridSamplingTask, sample_by_values from .train_test_split import TrainTestSplitTask -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/ml_tools/eolearn/ml_tools/sampling.py b/ml_tools/eolearn/ml_tools/sampling.py index b81277153..d195369bc 100644 --- a/ml_tools/eolearn/ml_tools/sampling.py +++ b/ml_tools/eolearn/ml_tools/sampling.py @@ -10,7 +10,7 @@ from abc import ABCMeta from math import sqrt -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Union, cast import numpy as np from shapely.geometry import Point, Polygon @@ -21,7 +21,7 @@ _FractionType = Union[float, Dict[int, float]] -def random_point_in_triangle(triangle: Polygon, rng: Optional[np.random.Generator] = None) -> Point: +def random_point_in_triangle(triangle: Polygon, rng: np.random.Generator | None = None) -> Point: """Selects a random point from an interior of a triangle. :param triangle: A triangle polygon. @@ -43,10 +43,10 @@ def random_point_in_triangle(triangle: Polygon, rng: Optional[np.random.Generato def sample_by_values( image: np.ndarray, - n_samples_per_value: Dict[int, int], - rng: Optional[np.random.Generator] = None, + n_samples_per_value: dict[int, int], + rng: np.random.Generator | None = None, replace: bool = False, -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """Sample points from image with the amount of samples specified for each value. 
:param image: A 2-dimensional numpy array @@ -61,8 +61,7 @@ def sample_by_values( raise ValueError(f"Given image has shape {image.shape} but sampling operates only on 2D images") rng = rng or np.random.default_rng() - rows = np.empty((0,), dtype=np.int16) - columns = np.empty((0,), dtype=np.int16) + rows, columns = np.empty((0,), dtype=int), np.empty((0,), dtype=int) for value, n_samples in n_samples_per_value.items(): sample_rows, sample_cols = rng.choice(np.nonzero(image == value), size=n_samples, replace=replace, axis=1) @@ -73,8 +72,8 @@ def sample_by_values( def expand_to_grids( - rows: np.ndarray, columns: np.ndarray, sample_size: Tuple[int, int] = (1, 1) -) -> Tuple[np.ndarray, np.ndarray]: + rows: np.ndarray, columns: np.ndarray, sample_size: tuple[int, int] = (1, 1) +) -> tuple[np.ndarray, np.ndarray]: """Expands sampled points into blocks and returns a pair of arrays. Each array represents a grid of indices of pixel locations, the first one row indices and the second one column indices. Each array is of shape `(N * sample_height, sample_width)`, where each element represent a row or column index in an original array @@ -106,7 +105,7 @@ def expand_to_grids( return row_grids, column_grids -def get_mask_of_samples(image_shape: Tuple[int, int], row_grid: np.ndarray, column_grid: np.ndarray) -> np.ndarray: +def get_mask_of_samples(image_shape: tuple[int, int], row_grid: np.ndarray, column_grid: np.ndarray) -> np.ndarray: """Creates a mask of counts how many times each pixel has been sampled. :param image_shape: Height and width of a sampled image. @@ -127,14 +126,14 @@ def get_mask_of_samples(image_shape: Tuple[int, int], row_grid: np.ndarray, colu return mask -class BaseSamplingTask(EOTask, metaclass=ABCMeta): # noqa: B024 +class BaseSamplingTask(EOTask, metaclass=ABCMeta): """A base class for sampling tasks""" def __init__( self, features_to_sample: FeaturesSpecification, *, - mask_of_samples: Optional[Tuple[FeatureType, str]] = None, + mask_of_samples: tuple[FeatureType, str] | None = None, ): """ :param features_to_sample: Features that will be spatially sampled according to given sampling parameters. @@ -155,12 +154,12 @@ def _apply_sampling(self, eopatch: EOPatch, row_grid: np.ndarray, column_grid: n image_shape = None for feature_type, feature_name, new_feature_name in self.features_parser.get_renamed_features(eopatch): if feature_name is not None: - data_to_sample = eopatch[feature_type][feature_name] + data_to_sample = eopatch[feature_type, feature_name] feature_shape = eopatch.get_spatial_dimension(feature_type, feature_name) image_shape = feature_shape - eopatch[feature_type][new_feature_name] = data_to_sample[..., row_grid, column_grid, :] + eopatch[feature_type, new_feature_name] = data_to_sample[..., row_grid, column_grid, :] if self.mask_of_samples is not None and image_shape is not None: mask = get_mask_of_samples(image_shape, row_grid, column_grid) @@ -181,7 +180,7 @@ def __init__( features_to_sample: FeaturesSpecification, sampling_feature: SingleFeatureSpec, fraction: _FractionType, - exclude_values: Optional[List[int]] = None, + exclude_values: list[int] | None = None, replace: bool = False, **kwargs: Any, ): @@ -222,7 +221,7 @@ def _validate_fraction_input(self, fraction: _FractionType) -> None: f"The fraction input is {fraction} but needs to be a number or a dictionary mapping labels to numbers." 
) - def _calculate_amount_per_value(self, image: np.ndarray, fraction: _FractionType) -> Dict[int, int]: + def _calculate_amount_per_value(self, image: np.ndarray, fraction: _FractionType) -> dict[int, int]: """Calculates the number of samples needed for each value present in mask according to the fraction parameter""" uniques, counts = np.unique(image, return_counts=True) available = {val: n for val, n in zip(uniques, counts) if val not in self.exclude_values} @@ -231,9 +230,7 @@ def _calculate_amount_per_value(self, image: np.ndarray, fraction: _FractionType return {val: round(n * fraction[val]) for val, n in available.items() if val in fraction} return {val: round(n * self.fraction) for val, n in available.items()} - def execute( - self, eopatch: EOPatch, *, seed: Optional[int] = None, fraction: Optional[_FractionType] = None - ) -> EOPatch: + def execute(self, eopatch: EOPatch, *, seed: int | None = None, fraction: _FractionType | None = None) -> EOPatch: """Execute random spatial sampling of specified features of eopatch :param eopatch: Input eopatch to be sampled @@ -269,7 +266,7 @@ def __init__( self, features_to_sample: FeaturesSpecification, amount: float, - sample_size: Tuple[int, int] = (1, 1), + sample_size: tuple[int, int] = (1, 1), replace: bool = False, **kwargs: Any, ): @@ -284,13 +281,6 @@ def __init__( super().__init__(features_to_sample, **kwargs) self.amount = amount - if not ( - isinstance(sample_size, tuple) - and len(sample_size) == 2 - and all(isinstance(value, int) for value in sample_size) - ): - raise ValueError(f"Parameter sample_size should be a tuple of 2 integers but {sample_size} found") - self.sample_size = tuple(sample_size) self.replace = replace @@ -298,19 +288,13 @@ def _generate_dummy_mask(self, eopatch: EOPatch) -> np.ndarray: """Generate a mask consisting entirely of `values` entries, used for sampling on whole raster""" feature_type, feature_name = self.features_parser.get_features(eopatch)[0] - if feature_name is None: - raise ValueError( - f"Encountered {feature_type} when calculating spatial dimension, please report bug to eo-learn" - " developers." 
- ) - - height, width = eopatch.get_spatial_dimension(feature_type, feature_name) + height, width = eopatch.get_spatial_dimension(feature_type, cast(str, feature_name)) height -= self.sample_size[0] - 1 width -= self.sample_size[1] - 1 return np.ones((height, width), dtype=np.uint8) - def execute(self, eopatch: EOPatch, *, seed: Optional[int] = None, amount: Optional[float] = None) -> EOPatch: + def execute(self, eopatch: EOPatch, *, seed: int | None = None, amount: float | None = None) -> EOPatch: """Execute a spatial sampling on features from a given EOPatch :param eopatch: Input eopatch to be sampled @@ -346,8 +330,8 @@ class GridSamplingTask(BaseSamplingTask): def __init__( self, features_to_sample: FeaturesSpecification, - sample_size: Tuple[int, int] = (1, 1), - stride: Tuple[int, int] = (1, 1), + sample_size: tuple[int, int] = (1, 1), + stride: tuple[int, int] = (1, 1), **kwargs: Any, ): """ @@ -367,7 +351,7 @@ def __init__( if not all(value > 0 for value in self.sample_size + self.stride): raise ValueError("Both sample_size and stride should have only positive values") - def _sample_regular_grid(self, image_shape: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]: + def _sample_regular_grid(self, image_shape: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]: """Samples points from a regular grid and returns indices of rows and columns""" rows = np.arange(0, image_shape[0] - self.sample_size[0] + 1, self.stride[0]) columns = np.arange(0, image_shape[1] - self.sample_size[1] + 1, self.stride[1]) @@ -382,13 +366,8 @@ def execute(self, eopatch: EOPatch) -> EOPatch: :return: An EOPatch with additional spatially sampled features """ feature_type, feature_name = self.features_parser.get_features(eopatch)[0] - if feature_name is None: - raise ValueError( - f"Encountered {feature_type} when calculating spatial dimension, please report bug to eo-learn" - " developers." - ) - image_shape = eopatch.get_spatial_dimension(feature_type, feature_name) + image_shape = eopatch.get_spatial_dimension(feature_type, cast(str, feature_name)) rows, columns = self._sample_regular_grid(image_shape) size_x, size_y = self.sample_size # this way it also works for lists row_grid, column_grid = expand_to_grids(rows, columns, sample_size=(size_x, size_y)) diff --git a/ml_tools/eolearn/ml_tools/tdigest.py b/ml_tools/eolearn/ml_tools/tdigest.py new file mode 100644 index 000000000..7b3ac95b7 --- /dev/null +++ b/ml_tools/eolearn/ml_tools/tdigest.py @@ -0,0 +1,199 @@ +""" +The module provides an EOTask for the computation of a T-Digest representation of an EOPatch. +Requires installation of `eolearn.ml_tools[TDIGEST]`. + +Copyright (c) 2017- Sinergise and contributors +For the full list of contributors, see the CREDITS file in the root directory of this source tree. + +This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. +""" +from __future__ import annotations + +from functools import partial +from itertools import product +from typing import Any, Callable, Generator, Iterable, Literal, Union + +import numpy as np +import tdigest as td + +from eolearn.core import EOPatch, EOTask, FeatureType +from eolearn.core.types import FeatureSpec, FeaturesSpecification + +ModeTypes = Union[Literal["standard", "timewise", "monthly", "total"], Callable] + + +class TDigestTask(EOTask): + """ + An EOTask to compute the T-Digest representation of a chosen feature of an EOPatch. 
+ It integrates the [T-Digest algorithm by Ted Dunning](https://arxiv.org/abs/1902.04023) to efficiently compute + quantiles of the underlying dataset into eo-learn. + + The output features of the tasks may be merged to compute a representation of the complete dataset. + That enables quantile based normalisation or statistical analysis of datasets larger than RAM in EO. + """ + + def __init__( + self, + in_feature: FeaturesSpecification, + out_feature: FeaturesSpecification, + mode: Literal["standard", "timewise", "monthly", "total"] | Callable = "standard", + pixelwise: bool = False, + filternan: bool = False, + ): + """ + :param in_feature: The input feature to compute the T-Digest representation for. + :param out_feature: The output feature where to save the T-Digest representation of the chosen feature. + :param mode: The mode to apply to the timestamps and bands. + * `'standard'` computes the T-Digest representation for each band accumulating timestamps. + * `'timewise'` computes the T-Digest representation for each band and timestamp of the chosen feature. + * `'monthly'` computes the T-Digest representation for each band accumulating the timestamps per month. + * | `'total'` computes the total T-Digest representation of the whole feature accumulating all timestamps, + | bands and pixels. Cannot be used with `pixelwise=True`. + * | Callable computes the T-Digest representation defined by the processing function given as mode. Receives + | the input_array of the feature, timestamps, shape and pixelwise and filternan keywords as an input. + :param pixelwise: Decider whether to compute the T-Digest representation accumulating pixels or per pixel. + Cannot be used with `mode='total'`. + :param filternan: Decider whether to filter out nan-values before computing the T-Digest. + """ + + self.mode = mode + + self.pixelwise = pixelwise + + self.filternan = filternan + + if self.pixelwise and self.mode == "total": + raise ValueError("Total mode does not support pixelwise=True.") + + self.in_feature = self.parse_features(in_feature, allowed_feature_types=partial(_is_input_ftype, mode=mode)) + self.out_feature = self.parse_features( + out_feature, allowed_feature_types=partial(_is_output_ftype, mode=mode, pixelwise=pixelwise) + ) + + if len(self.in_feature) != len(self.out_feature): + raise ValueError( + f"The number of input ({len(self.in_feature)}) and output features ({len(self.out_feature)}) must" + " match." + ) + + def execute(self, eopatch: EOPatch) -> EOPatch: + """ + Execute method that computes the TDigest of the chosen features. 
+ + :param eopatch: EOPatch in which the chosen input feature already exists + :return: The EOPatch with the T-Digest representation saved to the chosen output features + """ + + for in_feature_, out_feature_, shape in _looper( + in_feature=self.in_feature, out_feature=self.out_feature, eopatch=eopatch + ): + processing_func = self.mode if callable(self.mode) else _processing_function[self.mode] + eopatch[out_feature_] = processing_func( + input_array=eopatch[in_feature_], + timestamps=eopatch.timestamps, + shape=shape, + pixelwise=self.pixelwise, + filternan=self.filternan, + ) + + return eopatch + + +# auxiliary +def _is_input_ftype(feature_type: FeatureType, mode: ModeTypes) -> bool: + if mode == "standard": + return feature_type.is_image() + if mode in ("timewise", "monthly"): + return feature_type in [FeatureType.DATA, FeatureType.MASK] + return True + + +def _is_output_ftype(feature_type: FeatureType, mode: ModeTypes, pixelwise: bool) -> bool: + if callable(mode): + return True + + if mode == "standard": + return feature_type == (FeatureType.DATA_TIMELESS if pixelwise else FeatureType.SCALAR_TIMELESS) + + if mode in ("timewise", "monthly"): + return feature_type == (FeatureType.DATA if pixelwise else FeatureType.SCALAR) + + return feature_type == FeatureType.SCALAR_TIMELESS + + +def _looper( + in_feature: list[FeatureSpec], out_feature: list[FeatureSpec], eopatch: EOPatch +) -> Generator[tuple[FeatureSpec, FeatureSpec, np.ndarray], None, None]: + for in_feature_, out_feature_ in zip(in_feature, out_feature): + shape = np.array(eopatch[in_feature_].shape) + yield in_feature_, out_feature_, shape + + +def _process_standard( + input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any +) -> np.ndarray: + if pixelwise: + array = np.empty(shape[-3:], dtype=object) + for i, j, k in product(range(shape[-3]), range(shape[-2]), range(shape[-1])): + array[i, j, k] = _get_tdigest(input_array[..., i, j, k], filternan) + + else: + array = np.empty(shape[-1], dtype=object) + for k in range(shape[-1]): + array[k] = _get_tdigest(input_array[..., k], filternan) + + return array + + +def _process_timewise( + input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any +) -> np.ndarray: + if pixelwise: + array = np.empty(shape, dtype=object) + for time_, i, j, k in product(range(shape[0]), range(shape[1]), range(shape[2]), range(shape[3])): + array[time_, i, j, k] = _get_tdigest(input_array[time_, i, j, k], filternan) + + else: + array = np.empty(shape[[0, -1]], dtype=object) + for time_, k in product(range(shape[0]), range(shape[-1])): + array[time_, k] = _get_tdigest(input_array[time_, ..., k], filternan) + + return array + + +def _process_monthly( + input_array: np.ndarray, timestamps: Iterable, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any +) -> np.ndarray: + midx = [] + for month_ in range(12): + midx.append(np.array([timestamp.month == month_ + 1 for timestamp in timestamps])) + + if pixelwise: + array = np.empty([12, *shape[1:]], dtype=object) + for month_, i, j, k in product(range(12), range(shape[1]), range(shape[2]), range(shape[3])): + array[month_, i, j, k] = _get_tdigest(input_array[midx[month_], i, j, k], filternan) + + else: + array = np.empty([12, shape[-1]], dtype=object) + for month_, k in product(range(12), range(shape[-1])): + array[month_, k] = _get_tdigest(input_array[midx[month_], ..., k], filternan) + + return array + + +def _process_total(input_array: np.ndarray, filternan: bool, **_: Any) -> np.ndarray: + return _get_tdigest(input_array, filternan) + + +_processing_function: dict[str, Callable] = {
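+    # dispatch table mapping each mode keyword to the function computing its T-Digest representation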
"standard": _process_standard, + "timewise": _process_timewise, + "monthly": _process_monthly, + "total": _process_total, +} + + +def _get_tdigest(values: np.ndarray, filternan: bool) -> td.TDigest: + result = td.TDigest() + values_ = values.flatten() + result.batch_update(values_[~np.isnan(values_)] if filternan else values_) + return result diff --git a/ml_tools/eolearn/ml_tools/train_test_split.py b/ml_tools/eolearn/ml_tools/train_test_split.py index b74339cdb..834ee7637 100644 --- a/ml_tools/eolearn/ml_tools/train_test_split.py +++ b/ml_tools/eolearn/ml_tools/train_test_split.py @@ -9,13 +9,16 @@ from __future__ import annotations from enum import Enum -from typing import Any, List, Optional, Union +from typing import Any import numpy as np from eolearn.core import EOPatch, EOTask, FeatureType from eolearn.core.types import SingleFeatureSpec +# switching to np.random.Generator would change results +# ruff: noqa: NPY002 + class TrainTestSplitType(Enum): """An enum defining TrainTestSplitTask's methods of splitting the data into subsets""" @@ -65,13 +68,13 @@ def __init__( self, input_feature: SingleFeatureSpec, output_feature: SingleFeatureSpec, - bins: Union[float, List[Any]], + bins: float | list[Any], split_type: TrainTestSplitType = TrainTestSplitType.PER_PIXEL, - ignore_values: Optional[List[int]] = None, + ignore_values: list[int] | None = None, ): """ :param input_feature: The input feature to guide the split. - :param input_feature: The output feature where to save the mask. + :param output_feature: The output feature where to save the mask. :param bins: Cumulative probabilities of all value classes or a single float, representing a fraction. :param split_type: Value split type, either 'PER_PIXEL', 'PER_CLASS' or 'PER_VALUE'. :param ignore_values: A list of values in input_feature to ignore and not assign them to any subsets. @@ -79,23 +82,17 @@ def __init__( self.input_feature = self.parse_feature(input_feature, allowed_feature_types=[FeatureType.MASK_TIMELESS]) self.output_feature = self.parse_feature(output_feature, allowed_feature_types=[FeatureType.MASK_TIMELESS]) - if np.isscalar(bins): - bins = [bins] - - if ( - not isinstance(bins, list) - or not all(isinstance(bi, float) for bi in bins) - or np.any(np.diff(bins) <= 0) - or bins[0] <= 0 - or bins[-1] >= 1 - ): + if isinstance(bins, float): + self.bins = [bins] + else: + self.bins = list(bins) + if np.any(np.diff(self.bins) <= 0) or self.bins[0] <= 0 or self.bins[-1] >= 1: raise ValueError("bins argument should be a list of ascending floats inside an open interval (0, 1)") self.ignore_values = set() if ignore_values is None else set(ignore_values) - self.bins = bins self.split_type = TrainTestSplitType(split_type) - def execute(self, eopatch: EOPatch, *, seed: Optional[int] = None) -> EOPatch: + def execute(self, eopatch: EOPatch, *, seed: int | None = None) -> EOPatch: """ :param eopatch: input EOPatch :param seed: An argument to be passed to numpy.random.seed function. diff --git a/ml_tools/eolearn/ml_tools/utils.py b/ml_tools/eolearn/ml_tools/utils.py index b15b2fea0..57b9074d4 100644 --- a/ml_tools/eolearn/ml_tools/utils.py +++ b/ml_tools/eolearn/ml_tools/utils.py @@ -21,7 +21,7 @@ @deprecated_function( category=EODeprecationWarning, message_suffix="Please use `numpy.lib.stride_tricks.sliding_window_view` instead." 
) -def rolling_window( +def rolling_window( # noqa: C901 array: np.ndarray, window: Any = (0,), asteps: Optional[Any] = None, diff --git a/ml_tools/eolearn/tests/test_sampling.py b/ml_tools/eolearn/tests/test_sampling.py index 41795c842..cb81d66b8 100644 --- a/ml_tools/eolearn/tests/test_sampling.py +++ b/ml_tools/eolearn/tests/test_sampling.py @@ -11,7 +11,6 @@ import numpy as np import pytest from numpy.testing import assert_array_equal -from pytest import approx from shapely.geometry import Point, Polygon from eolearn.core import EOPatch, EOTask, FeatureType @@ -21,7 +20,7 @@ @pytest.mark.parametrize( - "triangle, expected_points", + ("triangle", "expected_points"), [ ( Polygon([[-10, -12], [5, 10], [15, 4]]), @@ -66,7 +65,7 @@ def small_image_fixture() -> np.ndarray: @pytest.mark.parametrize( - "image, n_samples", + ("image", "n_samples"), [ (np.ones((100,)), {1: 100}), (np.ones((100, 100, 3)), {1: 100}), @@ -82,7 +81,7 @@ def test_sample_by_values_errors(image: np.ndarray, n_samples: Dict[int, int]) - @pytest.mark.parametrize("seed", range(5)) @pytest.mark.parametrize( - "n_samples, replace", + ("n_samples", "replace"), [ ({0: 100, 1: 200, 2: 30}, False), ({1: 200}, False), @@ -100,7 +99,7 @@ def test_sample_by_values(small_image: np.ndarray, seed: int, n_samples: Dict[in @pytest.mark.parametrize( - "rows, columns", + ("rows", "columns"), [ (np.array([1, 1, 2, 3, 4]), np.array([2, 3, 1, 1, 4])), ], @@ -135,7 +134,7 @@ def test_get_mask_of_samples(small_image: np.ndarray, n_samples: Dict[int, int]) def eopatch_fixture(small_image: np.ndarray) -> EOPatch: config = PatchGeneratorConfig(raster_shape=small_image.shape, depth_range=(5, 6), num_timestamps=10) patch = generate_eopatch([(FeatureType.DATA, "bands")], config=config) - patch.mask_timeless["raster"] = small_image.reshape(small_image.shape + (1,)) + patch.mask_timeless["raster"] = small_image.reshape((*small_image.shape, 1)) return patch @@ -189,8 +188,8 @@ def test_object_sampling_reproducibility(eopatch: EOPatch, seed: int, block_task @pytest.mark.parametrize( - "fraction, replace", - [[2, False], [-0.5, True], [{1: 0.5, 3: 0.4, 5: 1.2}, False], [{1: 0.5, 3: -0.4, 5: 1.2}, True], [(1, 0.4), True]], + ("fraction", "replace"), + [(2, False), (-0.5, True), ({1: 0.5, 3: 0.4, 5: 1.2}, False), ({1: 0.5, 3: -0.4, 5: 1.2}, True), ((1, 0.4), True)], ) def test_fraction_sampling_errors(fraction: Union[float, Dict[int, float]], replace: bool) -> None: with pytest.raises(ValueError): @@ -259,7 +258,7 @@ def test_fraction_sampling_input_fraction( for val, count in full.items(): if val not in exclude: - assert samples[val] == approx(count * fraction_task.fraction, abs=1) + assert samples[val] == pytest.approx(count * fraction_task.fraction, abs=1) @pytest.mark.parametrize("seed", range(3)) @@ -283,7 +282,7 @@ def test_fraction_sampling_input_dict(fraction_task: FractionSamplingTask, seed: exclude = fraction_task.exclude_values or [] # get rid of pesky None assert set(exclude).isdisjoint(set(sample_values)) assert set(sample_values).issubset(set(fraction_task.fraction)) - assert all(count == approx(full[val] * fraction_task.fraction[val], abs=1) for val, count in samples.items()) + assert all(count == pytest.approx(full[val] * fraction_task.fraction[val], abs=1) for val, count in samples.items()) @pytest.mark.parametrize("seed", range(3)) @@ -317,7 +316,7 @@ def grid_task_fixture(request) -> EOTask: @pytest.mark.parametrize("grid_task", [[(1, 1), (1, 1)], [(2, 3), (5, 3)], [(6, 5), (3, 3)]], indirect=True) -def 
test_grid_sampling_task(test_eopatch: EOPatch, grid_task: EOTask) -> None: +def test_grid_sampling_task(test_eopatch: EOPatch, grid_task: GridSamplingTask) -> None: # expected_shape calculated sample_size = grid_task.sample_size expected_shape = list(test_eopatch.data["BANDS-S2-L1C"].shape) @@ -336,11 +335,10 @@ def test_grid_sampling_task(test_eopatch: EOPatch, grid_task: EOTask) -> None: assert np.sum(eopatch[SAMPLE_MASK]) == height * width -@pytest.mark.parametrize("grid_task", [[(1, 1), (1, 1)], [(2, 3), (5, 3)], [(6, 5), (3, 3)]], indirect=True) -def test_grid_sampling_task_reproducibility(test_eopatch: EOPatch, grid_task: EOTask) -> None: - test_eopatch2 = test_eopatch.copy(deep=True) - eopatch = grid_task.execute(test_eopatch) - eopatch2 = grid_task.execute(test_eopatch2) +@pytest.mark.parametrize("grid_task", [[(1, 1), (1, 1)], [(2, 3), (5, 3)]], indirect=True) +def test_grid_sampling_task_reproducibility(test_eopatch: EOPatch, grid_task: GridSamplingTask) -> None: + eopatch1 = grid_task.execute(copy.copy(test_eopatch)) + eopatch2 = grid_task.execute(copy.copy(test_eopatch)) - assert eopatch is not eopatch2 - assert eopatch == eopatch2 + assert eopatch1 == eopatch2 + assert eopatch1 is not eopatch2 diff --git a/ml_tools/eolearn/tests/test_train_split.py b/ml_tools/eolearn/tests/test_train_split.py index a89fb2dad..4a8849f8f 100644 --- a/ml_tools/eolearn/tests/test_train_split.py +++ b/ml_tools/eolearn/tests/test_train_split.py @@ -17,13 +17,12 @@ INPUT_FEATURE = (FeatureType.MASK_TIMELESS, "TEST") OUTPUT_FEATURE = (FeatureType.MASK_TIMELESS, "TEST_TRAIN_MASK") -INPUT_FEATURE_CONFIG = PatchGeneratorConfig(raster_shape=(1000, 1000), depth_range=(3, 4)) +INPUT_FEATURE_CONFIG = PatchGeneratorConfig(raster_shape=(300, 300), depth_range=(3, 4)) @pytest.mark.parametrize( - "bad_arg, bad_kwargs", + ("bad_arg", "bad_kwargs"), [ - (None, {}), (1.5, {}), ([0.5, 0.3], {}), ([0.5], {"split_type": None}), @@ -35,7 +34,7 @@ def test_bad_args(bad_arg: Any, bad_kwargs: Any) -> None: TrainTestSplitTask(INPUT_FEATURE, OUTPUT_FEATURE, bad_arg, **bad_kwargs) -@pytest.fixture(name="eopatch1", scope="function") +@pytest.fixture(name="eopatch1") def eopatch1_fixture() -> EOPatch: return generate_eopatch(INPUT_FEATURE, config=INPUT_FEATURE_CONFIG) diff --git a/ml_tools/requirements-tdigest.txt b/ml_tools/requirements-tdigest.txt new file mode 100644 index 000000000..3dd522fdf --- /dev/null +++ b/ml_tools/requirements-tdigest.txt @@ -0,0 +1 @@ +tdigest==0.5.2.2 diff --git a/ml_tools/setup.py b/ml_tools/setup.py index 7c0cf3f8c..03047a3ab 100644 --- a/ml_tools/setup.py +++ b/ml_tools/setup.py @@ -7,9 +7,7 @@ def get_long_description(): this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - return long_description + return f.read() def parse_requirements(file): @@ -29,7 +27,7 @@ def get_version(): setup( name="eo-learn-ml-tools", - python_requires=">=3.7", + python_requires=">=3.8", version=get_version(), description="A collection of machine learning EOTasks and utilities", long_description=get_long_description(), @@ -47,7 +45,10 @@ def get_version(): packages=find_packages(), include_package_data=True, install_requires=parse_requirements("requirements.txt"), - extras_require={"PLOTTING": parse_requirements("requirements-plotting.txt")}, + extras_require={ + "PLOTTING": parse_requirements("requirements-plotting.txt"), + "TDIGEST": parse_requirements("requirements-tdigest.txt"), + }, 
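+        # extras can be installed via e.g. `pip install eo-learn-ml-tools[TDIGEST]`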
zip_safe=False, classifiers=[ "Development Status :: 5 - Production/Stable", @@ -60,10 +61,10 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing", diff --git a/pyproject.toml b/pyproject.toml index 23c4d81ca..68e17c231 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,12 +2,78 @@ line-length = 120 preview = true -[tool.isort] -profile = "black" -known_first_party = "sentinelhub" -known_absolute = "eolearn" -sections = ["FUTURE","STDLIB","THIRDPARTY","FIRSTPARTY","ABSOLUTE","LOCALFOLDER"] -line_length = 120 +[tool.ruff] +line-length = 120 +target-version = "py38" +select = [ + "F", # pyflakes + "E", # pycodestyle + "W", # pycodestyle + "C90", # mccabe + "N", # naming + "YTT", # flake8-2020 + "B", # bugbear + "A", # built-ins + "COM", # commas + "C4", # comprehensions + "T10", # debugger statements + "ISC", # implicit string concatenation + "ICN", # import conventions + "G", # logging format + "PIE", # flake8-pie + "T20", # print statements + "PT", # pytest style + "RET", # returns + "SLF", # private member access + "SIM", # simplifications + "ARG", # unused arguments + "PD", # pandas + "PGH", # pygrep hooks (useless noqa comments, eval statements etc.) + "FLY", # flynt + "RUF", # ruff rules + "NPY", # numpy + "I", # isort + "UP", # pyupgrade + "FA", # checks where future import of annotations would make types nicer +] +fix = true +fixable = [ + "I", # sort imports + "F401", # remove redundant imports + "UP007", # use new-style union type annotations + "UP006", # use new-style built-in type annotations + "UP037", # remove quotes around types when not necessary + "FA100", # import future annotations where necessary (not autofixable ATM) +] +ignore = [ + "C408", # complains about `dict()` calls, we use them to avoid too many " in the code + "SIM117", # wants to always combine `with` statements, gets ugly for us + "SIM108", # tries to aggressively inline `if`, not always readable + "A003", # complains when ATTRIBUTES shadow builtins, we have objects that implement `filter` and such + "COM812", # trailing comma missing, fights with black + "PD011", # suggests `.to_numpy` instead of `.values`, also does this for non-pandas objects... + # potentially fixable + "B904", # wants `raise ... 
from None` instead of just `raise ...` + "B028", # always demands a stacklevel argument when warning + "PT011", # complains about `pytest.raises(ValueError)` but we use it a lot + "UP024", # wants to replace IOError with OSError +] +per-file-ignores = { "__init__.py" = ["F401"] } +exclude = [".git", "__pycache__", "build", "dist"] + + +[tool.ruff.isort] +section-order = [ + "future", + "standard-library", + "third-party", + "our-packages", + "first-party", + "local-folder", +] +known-first-party = ["eolearn"] +sections = { our-packages = ["sentinelhub"] } + [tool.pylint.format] max-line-length = 120 @@ -21,14 +87,14 @@ disable = [ "invalid-unary-operand-type", "unspecified-encoding", "unnecessary-ellipsis", - "use-dict-literal" + "use-dict-literal", ] [tool.pylint.design] max-args = 15 max-branches = 15 max-attributes = 20 -max-locals = 20 +max-locals = 21 min-public-methods = 0 [tool.pylint.similarities] @@ -44,9 +110,7 @@ overgeneral-exceptions = "builtins.Exception" max-nested-blocks = 7 [tool.pytest.ini_options] -markers = [ - "sh_integration: marks integration tests with Sentinel Hub service" -] +markers = ["sh_integration: marks integration tests with Sentinel Hub service"] [tool.coverage.run] source = [ @@ -57,23 +121,20 @@ source = [ "io", "mask", "ml_tools", - "visualization" + "visualization", ] [tool.coverage.report] -omit = [ - "*/setup.py", - "*/tests/*", - "*/__init__.py" -] +omit = ["*/setup.py", "*/tests/*", "*/__init__.py"] [tool.nbqa.addopts] -flake8 = [ - "--extend-ignore=E402" -] - -[tool.nbqa.exclude] -flake8 = "examples/core/CoreOverview.ipynb" +ruff = ["--extend-ignore=E402,T201,B015,B018,NPY002,UP,FA"] +# E402 -> imports on top +# T201 -> print found +# B015 & B018 -> useless expression (used to show values in ipynb) +# NPY002 -> use RNG instead of old numpy.random +# UP -> suggestions for new-style annotations (future import might confuse readers) +# FA -> necessary future annotations import [tool.mypy] follow_imports = "normal" diff --git a/requirements-dev.txt b/requirements-dev.txt index 51f2c8498..928639c88 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,3 @@ -codecov hypothesis moto mypy>=0.990 diff --git a/setup.py b/setup.py index 52830b638..9a3b897f0 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,8 @@ def parse_requirements(file): setup( name="eo-learn", - python_requires=">=3.7", - version="1.4.1", + python_requires=">=3.8", + version="1.4.2", description="Earth observation processing framework for machine learning in Python", long_description=get_long_description(), long_description_content_type="text/markdown", @@ -38,14 +38,14 @@ def parse_requirements(file): packages=[], include_package_data=True, install_requires=[ - "eo-learn-core==1.4.1", - "eo-learn-coregistration==1.4.1", - "eo-learn-features==1.4.1", - "eo-learn-geometry==1.4.1", - "eo-learn-io==1.4.1", - "eo-learn-mask==1.4.1", - "eo-learn-ml-tools==1.4.1", - "eo-learn-visualization==1.4.1", + "eo-learn-core==1.4.2", + "eo-learn-coregistration==1.4.2", + "eo-learn-features==1.4.2", + "eo-learn-geometry==1.4.2", + "eo-learn-io==1.4.2", + "eo-learn-mask==1.4.2", + "eo-learn-ml-tools==1.4.2", + "eo-learn-visualization==1.4.2", ], extras_require={"DEV": parse_requirements("requirements-dev.txt")}, zip_safe=False, @@ -60,10 +60,10 @@ def parse_requirements(file): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 
3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing", diff --git a/visualization/MANIFEST.in b/visualization/MANIFEST.in index 24451c2bc..84f45995e 100644 --- a/visualization/MANIFEST.in +++ b/visualization/MANIFEST.in @@ -2,4 +2,5 @@ include requirements*.txt include eolearn/visualization/report_templates/report.html include LICENSE include README.md +include eolearn/visualization/py.typed exclude eolearn/tests/* diff --git a/visualization/eolearn/tests/test_eopatch.py b/visualization/eolearn/tests/test_eopatch.py index de2c77603..71aae4c6f 100644 --- a/visualization/eolearn/tests/test_eopatch.py +++ b/visualization/eolearn/tests/test_eopatch.py @@ -23,7 +23,7 @@ def eopatch_fixture(): @pytest.mark.parametrize( - "feature, params", + ("feature", "params"), [ ((FeatureType.DATA, "BANDS-S2-L1C"), {"rgb": [3, 2, 1]}), ((FeatureType.DATA, "BANDS-S2-L1C"), {"times": [7, 14, 67], "channels": slice(4, 8)}), @@ -45,7 +45,7 @@ def eopatch_fixture(): (FeatureType.BBOX, {}), ], ) -@pytest.mark.sh_integration # python 3.7 dose not support matpotlib 3.6 +@pytest.mark.sh_integration() def test_eopatch_plot(eopatch, feature, params): """A simple test of EOPatch plotting for different features.""" # We reduce width and height otherwise running matplotlib.pyplot.subplots in combination with pytest would diff --git a/visualization/eolearn/visualization/__init__.py b/visualization/eolearn/visualization/__init__.py index 82a29cae4..9a191c134 100644 --- a/visualization/eolearn/visualization/__init__.py +++ b/visualization/eolearn/visualization/__init__.py @@ -4,4 +4,4 @@ from .eopatch import PlotBackend, PlotConfig -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/visualization/eolearn/visualization/eoexecutor.py b/visualization/eolearn/visualization/eoexecutor.py index 25704b332..6c9c8451d 100644 --- a/visualization/eolearn/visualization/eoexecutor.py +++ b/visualization/eolearn/visualization/eoexecutor.py @@ -6,6 +6,8 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. 
""" +from __future__ import annotations + import base64 import datetime as dt import importlib @@ -13,7 +15,7 @@ import os import warnings from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Tuple, cast +from typing import Any, cast import fs import graphviz @@ -97,12 +99,12 @@ def _create_dependency_graph(self) -> str: dot = self.eoexecutor.workflow.dependency_graph() return base64.b64encode(dot.pipe()).decode() - def _get_exception_stats(self) -> List[Tuple[str, str, List[Tuple[str, int]]]]: + def _get_exception_stats(self) -> list[tuple[str, str, list[tuple[str, int]]]]: """Creates aggregated stats about exceptions""" formatter = HtmlFormatter() lexer = pygments.lexers.get_lexer_by_name("python", stripall=True) - exception_stats: DefaultDict[str, DefaultDict[str, int]] = defaultdict(lambda: defaultdict(lambda: 0)) + exception_stats: defaultdict[str, defaultdict[str, int]] = defaultdict(lambda: defaultdict(lambda: 0)) for workflow_results in self.eoexecutor.execution_results: if not workflow_results.error_node_uid: @@ -117,8 +119,8 @@ def _get_exception_stats(self) -> List[Tuple[str, str, List[Tuple[str, int]]]]: return self._to_ordered_stats(exception_stats) def _to_ordered_stats( - self, exception_stats: DefaultDict[str, DefaultDict[str, int]] - ) -> List[Tuple[str, str, List[Tuple[str, int]]]]: + self, exception_stats: defaultdict[str, defaultdict[str, int]] + ) -> list[tuple[str, str, list[tuple[str, int]]]]: """Exception stats get ordered by nodes in their execution order in workflows. Exception stats that happen for the same node get ordered by number of occurrences in a decreasing order. """ @@ -134,10 +136,10 @@ def _to_ordered_stats( return ordered_exception_stats - def _get_node_descriptions(self) -> List[Dict[str, Any]]: + def _get_node_descriptions(self) -> list[dict[str, Any]]: """Prepares a list of node names and initialization parameters of their tasks""" descriptions = [] - name_counts: Dict[str, int] = defaultdict(lambda: 0) + name_counts: dict[str, int] = defaultdict(lambda: 0) for node in self.eoexecutor.workflow.get_nodes(): node_name = node.get_name(name_counts[node.get_name()]) @@ -148,7 +150,7 @@ def _get_node_descriptions(self) -> List[Dict[str, Any]]: "name": f"{node_name} ({node.uid})", "uid": node.uid, "args": { - key: value.replace("<", "<").replace(">", ">") # type: ignore + key: value.replace("<", "<").replace(">", ">") # type: ignore[attr-defined] for key, value in node.task.private_task_config.init_args.items() }, } @@ -156,7 +158,7 @@ def _get_node_descriptions(self) -> List[Dict[str, Any]]: return descriptions - def _render_task_sources(self, formatter: pygments.formatter.Formatter) -> Dict[str, Any]: + def _render_task_sources(self, formatter: pygments.formatter.Formatter) -> dict[str, Any]: """Renders source code of EOTasks""" lexer = pygments.lexers.get_lexer_by_name("python", stripall=True) sources = {} @@ -211,9 +213,8 @@ def _get_template(self) -> Template: env = Environment(loader=FileSystemLoader(templates_dir)) env.filters["datetime"] = self._format_datetime env.globals.update(timedelta=self._format_timedelta) - template = env.get_template(self.eoexecutor.REPORT_FILENAME) - return template + return env.get_template(self.eoexecutor.REPORT_FILENAME) @staticmethod def _format_datetime(value: dt.datetime) -> str: diff --git a/visualization/eolearn/visualization/eopatch.py b/visualization/eolearn/visualization/eopatch.py index a0da4c0be..f423e9008 100644 --- a/visualization/eolearn/visualization/eopatch.py +++ 
b/visualization/eolearn/visualization/eopatch.py @@ -12,7 +12,7 @@ import itertools as it from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -20,9 +20,10 @@ from pyproj import CRS from eolearn.core import EOPatch, FeatureType +from eolearn.core.constants import TIMESTAMP_COLUMN from eolearn.core.types import SingleFeatureSpec - -from .eopatch_base import BaseEOPatchVisualization, BasePlotConfig +from eolearn.core.utils.common import is_discrete_type +from eolearn.core.utils.parsing import parse_feature class PlotBackend(Enum): @@ -31,7 +32,7 @@ class PlotBackend(Enum): MATPLOTLIB = "matplotlib" -def plot_eopatch(*args: Any, backend: Union[PlotBackend, str] = PlotBackend.MATPLOTLIB, **kwargs: Any) -> object: +def plot_eopatch(*args: Any, backend: PlotBackend | str = PlotBackend.MATPLOTLIB, **kwargs: Any) -> object: """The main `EOPatch` plotting function. It pr :param args: Positional arguments to be propagated to a plotting backend. @@ -48,9 +49,13 @@ def plot_eopatch(*args: Any, backend: Union[PlotBackend, str] = PlotBackend.MATP @dataclass -class PlotConfig(BasePlotConfig): +class PlotConfig: """Advanced plotting configurations + :param rgb_factor: A factor by which to scale RGB images to make them look better. + :param timestamp_column: A name of a column containing timestamps in a `GeoDataFrame` feature. If set to `None` it + will plot temporal vector features as if they were timeless. + :param geometry_column: A name of a column containing geometries in a `GeoDataFrame` feature. :param subplot_width: A width of each subplot in a grid :param subplot_height: A height of each subplot in a grid :param subplot_kwargs: A dictionary of parameters that will be passed to `matplotlib.pyplot.subplots` function. @@ -61,39 +66,65 @@ class PlotConfig(BasePlotConfig): box. """ - subplot_width: Union[float, int] = 8 - subplot_height: Union[float, int] = 8 + rgb_factor: float | None = 3.5 + timestamp_column: str | None = TIMESTAMP_COLUMN + geometry_column: str = "geometry" + subplot_width: float | int = 8 + subplot_height: float | int = 8 interpolation: str = "none" - subplot_kwargs: Dict[str, object] = field(default_factory=dict) + subplot_kwargs: dict[str, object] = field(default_factory=dict) show_title: bool = True - title_kwargs: Dict[str, object] = field(default_factory=dict) - label_kwargs: Dict[str, object] = field(default_factory=dict) - bbox_kwargs: Dict[str, object] = field(default_factory=dict) + title_kwargs: dict[str, object] = field(default_factory=dict) + label_kwargs: dict[str, object] = field(default_factory=dict) + bbox_kwargs: dict[str, object] = field(default_factory=dict) -class MatplotlibVisualization(BaseEOPatchVisualization): +class MatplotlibVisualization: """EOPatch visualization using `matplotlib` framework.""" - config: PlotConfig - def __init__( self, eopatch: EOPatch, feature: SingleFeatureSpec, *, - axes: Optional[np.ndarray] = None, - config: Optional[PlotConfig] = None, - **kwargs: Any, + axes: np.ndarray | None = None, + config: PlotConfig | None = None, + times: list[int] | slice | None = None, + channels: list[int] | slice | None = None, + channel_names: list[str] | None = None, + rgb: tuple[int, int, int] | None = None, ): """ :param eopatch: An EOPatch with a feature to plot. :param feature: A feature from the given EOPatch to plot. :param axes: A grid of axes on which to write plots. 
If not provided it will create a new grid. :param config: A configuration object with advanced plotting parameters. - :param kwargs: Parameters to be passed to the base class. + :param times: A list or a slice of indices on temporal axis to be used for plotting. If not provided all + indices will be used. + :param channels: A list or a slice of indices on channels axis to be used for plotting. If not provided all + indices will be used. + :param channel_names: Names of channels of the last dimension in the given raster feature. + :param rgb: If provided, it should be a list of 3 indices of RGB channels to be plotted. It will plot only RGB + images with these channels. This only works for raster features with spatial dimension. """ - config = config or PlotConfig() - super().__init__(eopatch, feature, config=config, **kwargs) + self.eopatch = eopatch + self.feature = parse_feature(feature) + feature_type, _ = self.feature + self.config = config or PlotConfig() + + if times is not None and not feature_type.is_temporal(): + raise ValueError("Parameter times can only be provided for temporal features.") + self.times = times + + self.channels = channels + self.channel_names = None if channel_names is None else [str(name) for name in channel_names] + + if rgb and not (feature_type.is_spatial() and feature_type.is_array()): + raise ValueError("Parameter rgb can only be provided for plotting spatial raster features.") + self.rgb = rgb + + if self.channels and self.rgb: + raise ValueError("Only one of parameters channels and rgb can be provided.") if axes is not None and not isinstance(axes, np.ndarray): axes = np.array([np.array([axes])]) # type: ignore[unreachable] @@ -102,7 +133,7 @@ def __init__( def plot(self) -> np.ndarray: """Plots the given feature""" feature_type, feature_name = self.feature - data, timestamps = self.collect_and_prepare_feature() + data, timestamps = self.collect_and_prepare_feature(self.eopatch) if feature_type is FeatureType.BBOX: return self._plot_bbox() @@ -126,8 +157,60 @@ def plot(self) -> np.ndarray: return self._plot_bar(data, title=feature_name) return self._plot_time_series(data, timestamps=timestamps, title=feature_name) + def collect_and_prepare_feature(self, eopatch: EOPatch) -> tuple[Any, list[dt.datetime]]: + """Collects a feature from EOPatch and modifies it according to plotting parameters""" + feature_type, _ = self.feature + data = eopatch[self.feature] + timestamps = eopatch.timestamps + + if feature_type.is_array(): + if self.times is not None: + data = data[self.times, ...] + if timestamps: + timestamps = list(np.array(timestamps)[self.times]) + + if self.channels is not None: + data = data[..., self.channels] + + if feature_type.is_spatial() and self.rgb: + data = self._prepare_rgb_data(data) + + number_of_plot_columns = 1 if self.rgb else data.shape[-1] + if self.channel_names and len(self.channel_names) != number_of_plot_columns: + raise ValueError( + f"Provided {len(self.channel_names)} channel names but attempting to make plots with " + f"{number_of_plot_columns} columns for the given feature channels." 
+ ) + + if feature_type.is_vector() and self.times is not None: + data = self._filter_temporal_dataframe(data) + + return data, timestamps + + def _prepare_rgb_data(self, data: np.ndarray) -> np.ndarray: + """Prepares data array for RGB plotting""" + data = data[..., self.rgb] + + if self.config.rgb_factor is not None: + data = data * self.config.rgb_factor + + if is_discrete_type(data.dtype): + data = np.clip(data, 0, 255) + else: + data = np.clip(data, 0.0, 1.0) + + return data + + def _filter_temporal_dataframe(self, dataframe: GeoDataFrame) -> GeoDataFrame: + """Prepares a list of unique timestamps from the dataframe, applies filter on them and returns a new + dataframe with rows that only contain filtered timestamps.""" + unique_timestamps = dataframe[self.config.timestamp_column].unique() + filtered_timestamps = np.sort(unique_timestamps)[self.times] + filtered_rows = dataframe[self.config.timestamp_column].isin(filtered_timestamps) + return dataframe[filtered_rows] + def _plot_raster_grid( - self, raster: np.ndarray, timestamps: Optional[List[dt.datetime]] = None, title: Optional[str] = None + self, raster: np.ndarray, timestamps: list[dt.datetime] | None = None, title: str | None = None ) -> np.ndarray: """Plots a grid of raster images""" rows, _, _, columns = raster.shape @@ -155,7 +238,7 @@ def _plot_raster_grid( return axes def _plot_time_series( - self, series: np.ndarray, timestamps: Optional[List[dt.datetime]] = None, title: Optional[str] = None + self, series: np.ndarray, timestamps: list[dt.datetime] | None = None, title: str | None = None ) -> np.ndarray: """Plots time series feature.""" axes = self._provide_axes(nrows=1, ncols=1, title=title) @@ -171,7 +254,7 @@ def _plot_time_series( axis.legend() return axes - def _plot_bar(self, values: np.ndarray, title: Optional[str] = None) -> np.ndarray: + def _plot_bar(self, values: np.ndarray, title: str | None = None) -> np.ndarray: """Make a bar plot from values.""" axes = self._provide_axes(nrows=1, ncols=1, title=title) axis = axes.flatten()[0] @@ -182,7 +265,7 @@ def _plot_bar(self, values: np.ndarray, title: Optional[str] = None) -> np.ndarr return axes def _plot_vector_feature( - self, dataframe: GeoDataFrame, timestamp_column: Optional[str] = None, title: Optional[str] = None + self, dataframe: GeoDataFrame, timestamp_column: str | None = None, title: str | None = None ) -> np.ndarray: """Plots a GeoDataFrame vector feature""" rows = len(dataframe[timestamp_column].unique()) if timestamp_column else 1 @@ -205,7 +288,7 @@ def _plot_vector_feature( return axes - def _plot_bbox(self, axes: Optional[np.ndarray] = None, target_crs: Optional[CRS] = None) -> np.ndarray: + def _plot_bbox(self, axes: np.ndarray | None = None, target_crs: CRS | None = None) -> np.ndarray: """Plot a bounding box""" bbox = self.eopatch.bbox if bbox is None: @@ -231,9 +314,7 @@ def _plot_bbox(self, axes: Optional[np.ndarray] = None, target_crs: Optional[CRS return axes - def _provide_axes( - self, *, nrows: int, ncols: int, title: Optional[str] = None, **subplot_kwargs: Any - ) -> np.ndarray: + def _provide_axes(self, *, nrows: int, ncols: int, title: str | None = None, **subplot_kwargs: Any) -> np.ndarray: """Either provides an existing grid of axes or creates new one""" if self.axes is not None: return self.axes @@ -258,6 +339,6 @@ def _provide_axes( return axes - def _get_label_kwargs(self) -> Dict[str, object]: + def _get_label_kwargs(self) -> dict[str, object]: """Provides `matplotlib` arguments for writing labels in plots.""" return {"fontsize": 
12, **self.config.label_kwargs} diff --git a/visualization/eolearn/visualization/eopatch_base.py b/visualization/eolearn/visualization/eopatch_base.py deleted file mode 100644 index 04345f607..000000000 --- a/visualization/eolearn/visualization/eopatch_base.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -This module implements base objects for `EOPatch` visualizations. - -Copyright (c) 2017- Sinergise and contributors -For the full list of contributors, see the CREDITS file in the root directory of this source tree. - -This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. -""" -from __future__ import annotations - -import abc -import datetime as dt -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union - -import numpy as np -from geopandas import GeoDataFrame - -from eolearn.core import EOPatch -from eolearn.core.constants import TIMESTAMP_COLUMN -from eolearn.core.types import SingleFeatureSpec -from eolearn.core.utils.common import is_discrete_type -from eolearn.core.utils.parsing import parse_feature - - -@dataclass -class BasePlotConfig: - """A base class for advanced plotting configuration parameters. - - :param rgb_factor: A factor by which to scale RGB images to make them look better. - :param timestamp_column: A name of a column containing timestamps in a `GeoDataFrame` feature. If set to `None` it - will plot temporal vector features as if they were timeless. - :param geometry_column: A name of a column containing geometries in a `GeoDataFrame` feature. - """ - - rgb_factor: Optional[float] = 3.5 - timestamp_column: Optional[str] = TIMESTAMP_COLUMN - geometry_column: str = "geometry" - - -class BaseEOPatchVisualization(metaclass=abc.ABCMeta): - """A base class for EOPatch visualization""" - - def __init__( - self, - eopatch: EOPatch, - feature: SingleFeatureSpec, - *, - config: BasePlotConfig, - times: Union[List[int], slice, None] = None, - channels: Union[List[int], slice, None] = None, - channel_names: Optional[List[str]] = None, - rgb: Optional[Tuple[int, int, int]] = None, - ): - """ - :param eopatch: An EOPatch with a feature to plot. - :param feature: A feature from the given EOPatch to plot. - :param config: A configuration object with advanced plotting parameters. - :param times: A list or a slice of indices on temporal axis to be used for plotting. If not provided all - indices will be used. - :param channels: A list or a slice of indices on channels axis to be used for plotting. If not provided all - indices will be used. - :param channel_names: Names of channels of the last dimension in the given raster feature. - :param rgb: If provided, it should be a list of 3 indices of RGB channels to be plotted. It will plot only RGB - images with these channels. This only works for raster features with spatial dimension. 
- """ - self.eopatch = eopatch - self.feature = parse_feature(feature) - feature_type, _ = self.feature - self.config = config - - if times is not None and not feature_type.is_temporal(): - raise ValueError("Parameter times can only be provided for temporal features.") - self.times = times - - self.channels = channels - self.channel_names = None if channel_names is None else [str(name) for name in channel_names] - - if rgb and len(rgb) != 3: - raise ValueError(f"Parameter rgb should be a list of 3 indices but got {rgb}") - if rgb and not (feature_type.is_spatial() and feature_type.is_array()): - raise ValueError("Parameter rgb can only be provided for plotting spatial raster features.") - self.rgb = rgb - - if self.channels and self.rgb: - raise ValueError("Only one of parameters channels and rgb can be provided.") - - @abc.abstractmethod - def plot(self) -> object: - """Plots the given feature""" - - def collect_and_prepare_feature(self) -> Tuple[Any, List[dt.datetime]]: - """Collects a feature from EOPatch and modifies it according to plotting parameters""" - feature_type, _ = self.feature - data = self.eopatch[self.feature] - timestamps = self.eopatch.timestamps - - if feature_type.is_array(): - if self.times is not None: - data = data[self.times, ...] - if timestamps: - timestamps = list(np.array(timestamps)[self.times]) - - if self.channels is not None: - data = data[..., self.channels] - - if feature_type.is_spatial() and self.rgb: - data = self._prepare_rgb_data(data) - - number_of_plot_columns = 1 if self.rgb else data.shape[-1] - if self.channel_names and len(self.channel_names) != number_of_plot_columns: - raise ValueError( - f"Provided {len(self.channel_names)} channel names but attempting to make plots with " - f"{number_of_plot_columns} columns for the given feature channels." - ) - - if feature_type.is_vector() and self.times is not None: - data = self._filter_temporal_dataframe(data) - - return data, timestamps - - def _prepare_rgb_data(self, data: np.ndarray) -> np.ndarray: - """Prepares data array for RGB plotting""" - data = data[..., self.rgb] - - if self.config.rgb_factor is not None: - data = data * self.config.rgb_factor - - if is_discrete_type(data.dtype): - data = np.clip(data, 0, 255) - else: - data = np.clip(data, 0.0, 1.0) - - return data - - def _filter_temporal_dataframe(self, dataframe: GeoDataFrame) -> GeoDataFrame: - """Prepares a list of unique timestamps from the dataframe, applies filter on them and returns a new - dataframe with rows that only contain filtered timestamps.""" - unique_timestamps = dataframe[self.config.timestamp_column].unique() - filtered_timestamps = np.sort(unique_timestamps)[self.times] - filtered_rows = dataframe[self.config.timestamp_column].isin(filtered_timestamps) - return dataframe[filtered_rows] diff --git a/visualization/eolearn/visualization/eoworkflow.py b/visualization/eolearn/visualization/eoworkflow.py index 5a2e8d935..fb40fd395 100644 --- a/visualization/eolearn/visualization/eoworkflow.py +++ b/visualization/eolearn/visualization/eoworkflow.py @@ -6,7 +6,9 @@ This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree. 
""" -from typing import Dict, List, Optional, Sequence +from __future__ import annotations + +from typing import Sequence from graphviz import Digraph @@ -22,7 +24,7 @@ def __init__(self, nodes: Sequence[EONode]): """ self.nodes = nodes - def dependency_graph(self, filename: Optional[str] = None) -> Digraph: + def dependency_graph(self, filename: str | None = None) -> Digraph: """Visualize the computational graph. :param filename: Filename of the output image together with file extension. Supported formats: `png`, `jpg`, @@ -54,10 +56,10 @@ def get_dot(self) -> Digraph: return dot @staticmethod - def _get_node_uid_to_dot_name_mapping(nodes: Sequence[EONode]) -> Dict[str, str]: + def _get_node_uid_to_dot_name_mapping(nodes: Sequence[EONode]) -> dict[str, str]: """Creates mapping between EONode classes and names used in DOT graph. To do that, it has to collect nodes with the same name and assign them different indices.""" - dot_name_to_nodes: Dict[str, List[EONode]] = {} + dot_name_to_nodes: dict[str, list[EONode]] = {} for node in nodes: dot_name_to_nodes[node.get_name()] = dot_name_to_nodes.get(node.get_name(), []) dot_name_to_nodes[node.get_name()].append(node) diff --git a/visualization/eolearn/visualization/py.typed b/visualization/eolearn/visualization/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/visualization/setup.py b/visualization/setup.py index 7805257f1..e554008ed 100644 --- a/visualization/setup.py +++ b/visualization/setup.py @@ -7,9 +7,7 @@ def get_long_description(): this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - return long_description + return f.read() def parse_requirements(file): @@ -29,7 +27,7 @@ def get_version(): setup( name="eo-learn-visualization", - python_requires=">=3.7", + python_requires=">=3.8", version=get_version(), description="A collection of visualization utilities", long_description=get_long_description(), @@ -60,10 +58,10 @@ def get_version(): "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: GIS", "Topic :: Scientific/Engineering :: Image Processing",