diff --git a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py
index d509fe4ec..0d9a4d671 100644
--- a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py
+++ b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py
@@ -1,5 +1,6 @@
 """Ensure there is N PV systems per example"""
 import logging
+from typing import Optional
 
 import numpy as np
 import xarray as xr
@@ -13,7 +14,14 @@
 class EnsureNPVSystemsPerExampleIterDataPipe(IterDataPipe):
     """Ensure there is N PV systems per example"""
 
-    def __init__(self, source_datapipe: IterDataPipe, n_pv_systems_per_example: int, seed=None):
+    def __init__(
+        self,
+        source_datapipe: IterDataPipe,
+        n_pv_systems_per_example: int,
+        seed=None,
+        method: str = "random",
+        locations_datapipe: Optional[IterDataPipe] = None,
+    ):
         """
         Ensure there is N PV systems per example
 
@@ -21,21 +29,55 @@ def __init__(self, source_datapipe: IterDataPipe, n_pv_systems_per_example: int,
             source_datapipe: Datapipe of PV data
             n_pv_systems_per_example: Number of PV systems to have in example
             seed: Random seed for choosing
+            method: method for picking PV systems. Can be 'random' or 'closest'
+            locations_datapipe: Datapipe of locations, one location per example.
+                Can be None as it's only needed for 'closest'
         """
         self.source_datapipe = source_datapipe
         self.n_pv_systems_per_example = n_pv_systems_per_example
         self.rng = np.random.default_rng(seed=seed)
+        self.method = method
+        self.locations_datapipe = locations_datapipe
+
+        # Raise rather than assert: asserts are stripped under `python -O`.
+        if method not in ("random", "closest"):
+            raise ValueError(f"method must be 'random' or 'closest', got {method!r}")
+        if method == "closest" and locations_datapipe is None:
+            raise ValueError(
+                "If you are selecting the closest PV systems, "
+                "then a locations datapipe is needed"
+            )
 
     def __iter__(self):
-        for xr_data in self.source_datapipe:
+        if self.method == "closest":
+            # Draw exactly one location per example so the two streams stay aligned,
+            # even for examples that do not need reducing. zip() accepts any iterable
+            # and stops at the shorter stream.
+            examples = zip(self.source_datapipe, self.locations_datapipe)
+        else:
+            examples = ((xr_data, None) for xr_data in self.source_datapipe)
+
+        for xr_data, location in examples:
             if len(xr_data.pv_system_id) > self.n_pv_systems_per_example:
                 logger.debug(f"Reducing PV systems to {self.n_pv_systems_per_example}")
-                # More PV systems are available than we need. Reduce by randomly sampling:
-                subset_of_pv_system_ids = self.rng.choice(
-                    xr_data.pv_system_id,
-                    size=self.n_pv_systems_per_example,
-                    replace=False,
-                )
+                # More PV systems are available than we need. Reduce using `self.method`:
+                if self.method == "random":
+                    subset_of_pv_system_ids = self.rng.choice(
+                        xr_data.pv_system_id,
+                        size=self.n_pv_systems_per_example,
+                        replace=False,
+                    )
+                else:
+                    # method == "closest": squared distance of each PV system from the
+                    # location; the square root is not needed just to rank by distance.
+                    delta_x = xr_data.x_osgb - location.x
+                    delta_y = xr_data.y_osgb - location.y
+                    r2 = delta_x**2 + delta_y**2
+
+                    # order and select closest
+                    r2 = r2.sortby(r2)
+                    subset_of_pv_system_ids = r2.pv_system_id[: self.n_pv_systems_per_example]
+
                 xr_data = xr_data.sel(pv_system_id=subset_of_pv_system_ids)
             elif len(xr_data.pv_system_id) < self.n_pv_systems_per_example:
                 logger.debug("Padding out PV systems")
diff --git a/tests/transform/xarray/test_create_pv_image.py b/tests/transform/xarray/pv/test_create_pv_image.py
similarity index 100%
rename from tests/transform/xarray/test_create_pv_image.py
rename to tests/transform/xarray/pv/test_create_pv_image.py
diff --git a/tests/transform/xarray/pv/test_ensure_n_pv_systems_per_example.py b/tests/transform/xarray/pv/test_ensure_n_pv_systems_per_example.py
new file mode 100644
index 000000000..8c764585a
--- /dev/null
+++ b/tests/transform/xarray/pv/test_ensure_n_pv_systems_per_example.py
@@ -0,0 +1,57 @@
+import numpy as np
+import pytest
+
+from ocf_datapipes.transform.xarray import EnsureNPVSystemsPerExample
+from ocf_datapipes.utils.consts import Location
+
+
+def test_ensure_n_pv_systems_per_example_expand(passiv_datapipe):
+
+    data_before = next(iter(passiv_datapipe))
+
+    passiv_datapipe = EnsureNPVSystemsPerExample(passiv_datapipe, n_pv_systems_per_example=12)
+    data_after = next(iter(passiv_datapipe))
+
+    assert len(data_before[0, :]) == 2
+    assert len(data_after[0, :]) == 12
+
+
+def test_ensure_n_pv_systems_per_example_random(passiv_datapipe):
+
+    data_before = next(iter(passiv_datapipe))
+
+    passiv_datapipe = EnsureNPVSystemsPerExample(passiv_datapipe, n_pv_systems_per_example=1)
+    data_after = next(iter(passiv_datapipe))
+
+    assert len(data_before[0, :]) == 2
+    assert len(data_after[0, :]) == 1
+
+
+def test_ensure_n_pv_systems_per_example_closest_error(passiv_datapipe):
+
+    with pytest.raises(ValueError):
+        _ = EnsureNPVSystemsPerExample(
+            passiv_datapipe, n_pv_systems_per_example=1, method="closest"
+        )
+
+
+def test_ensure_n_pv_systems_per_example_closest(passiv_datapipe):
+
+    # make fake location datapipe
+    location = Location(x=2.687e05, y=6.267e05)
+    location_datapipe = iter([location])
+
+    data_before = next(iter(passiv_datapipe))
+
+    passiv_datapipe = EnsureNPVSystemsPerExample(
+        passiv_datapipe,
+        n_pv_systems_per_example=1,
+        method="closest",
+        locations_datapipe=location_datapipe,
+    )
+
+    data_after = next(iter(passiv_datapipe))
+
+    assert len(data_before[0, :]) == 2
+    assert len(data_after[0, :]) == 1
+    assert data_after.pv_system_id[0] == 9960
diff --git a/tests/transform/xarray/test_fill_night_time_nans_with_zeros.py b/tests/transform/xarray/pv/test_fill_night_time_nans_with_zeros.py
similarity index 100%
rename from tests/transform/xarray/test_fill_night_time_nans_with_zeros.py
rename to tests/transform/xarray/pv/test_fill_night_time_nans_with_zeros.py
diff --git a/tests/transform/xarray/test_pv_power_rolling_window.py b/tests/transform/xarray/pv/test_pv_power_rolling_window.py
similarity index 100%
rename from tests/transform/xarray/test_pv_power_rolling_window.py
rename to tests/transform/xarray/pv/test_pv_power_rolling_window.py
diff --git a/tests/transform/xarray/test_pv_remove_zero_data.py b/tests/transform/xarray/pv/test_pv_remove_zero_data.py
similarity index 100%
rename from tests/transform/xarray/test_pv_remove_zero_data.py
rename to tests/transform/xarray/pv/test_pv_remove_zero_data.py