Add DL4GAMAlps dataset #2508
base: main
pyproject.toml

@@ -87,6 +87,8 @@ datasets = [
     "h5py>=3.6",
     # laspy 2+ required for laspy.read
     "laspy>=2",
+    # netcdf4 1.5.4+ required for xarray.open_dataset with engine="netcdf4"
+    "netcdf4>=1.5.4",
Review comment on lines +90 to +91 (with a suggested change): Older versions might work, but are annoying to test in CI due to lack of wheels.

     # opencv-python 4.5.4+ required for Python 3.10 wheels
     "opencv-python>=4.5.4",
     # pandas 2+ required for parquet extra

@@ -97,6 +99,8 @@ datasets = [
     "scikit-image>=0.19",
     # scipy 1.7.2+ required for Python 3.10 wheels
     "scipy>=1.7.2",
+    # xarray 2023.9+ required for xarray.open_dataset

Review comment: Pretty sure xarray.open_dataset existed before 2023, what error do you see for older versions?

"xarray>=2023.9", | ||||||||||
] | ||||||||||
docs = [ | ||||||||||
# ipywidgets 7+ required by nbsphinx | ||||||||||
|

Review comment (conversation marked as resolved): Also need to update …
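
For context, a minimal round-trip through the backend these pins gate (the file name is illustrative; this mirrors how the dataset's NetCDF patches are read):

    # Sketch: write a tiny dataset, then re-open it with the netcdf4 engine.
    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'dem': (('y', 'x'), np.zeros((4, 4), dtype=np.float32))})
    ds.to_netcdf('patch.nc')

    patch = xr.open_dataset('patch.nc', engine='netcdf4')
    print(patch['dem'].shape)  # (4, 4)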

requirements/datasets.txt

@@ -1,8 +1,10 @@
 # datasets
 h5py==3.12.1
 laspy==2.5.4
+netcdf4==1.7.2
 opencv-python==4.11.0.86
 pandas[parquet]==2.2.3
 pycocotools==2.0.8
 scikit-image==0.25.0
 scipy==1.15.1
+xarray==2024.11.0

tests/data/dl4gam_alps/data.py (new file)

#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
||
import hashlib | ||
import shutil | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
|
||
# define the patch size | ||
PATCH_SIZE = 16 | ||
|
||
# create a random generator | ||
rg = np.random.RandomState(42) | ||
|
||
|
||
def create_dummy_sample(fp: str | Path) -> None: | ||
# create the random S2 bands data; make the last two bands as binary masks | ||
band_data = rg.randint( | ||
low=0, high=10000, dtype=np.int16, size=(15, PATCH_SIZE, PATCH_SIZE) | ||
) | ||
band_data[-2:] = (band_data[-2:] > 5000).astype(np.int16) | ||
|
||
data_dict = { | ||
'band_data': { | ||
'dims': ('band', 'y', 'x'), | ||
'data': band_data, | ||
'attrs': { | ||
'long_name': [ | ||
'B1', | ||
'B2', | ||
'B3', | ||
'B4', | ||
'B5', | ||
'B6', | ||
'B7', | ||
'B8', | ||
'B8A', | ||
'B9', | ||
'B10', | ||
'B11', | ||
'B12', | ||
'CLOUDLESS_MASK', | ||
'FILL_MASK', | ||
], | ||
'_FillValue': -9999, | ||
}, | ||
}, | ||
'mask_all_g_id': { # glaciers mask (with -1 for no-glacier and GLACIER_ID for glacier) | ||
'dims': ('y', 'x'), | ||
'data': rg.choice([-1, 8, 9, 30, 35], size=(PATCH_SIZE, PATCH_SIZE)).astype( | ||
np.int32 | ||
), | ||
'attrs': {'_FillValue': -1}, | ||
}, | ||
'mask_debris': { | ||
'dims': ('y', 'x'), | ||
'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) > 0.5).astype(np.int8), | ||
'attrs': {'_FillValue': -1}, | ||
}, | ||
} | ||
|
||
# add the additional variables | ||
for v in [ | ||
'dem', | ||
'slope', | ||
'aspect', | ||
'planform_curvature', | ||
'profile_curvature', | ||
'terrain_ruggedness_index', | ||
'dhdt', | ||
'v', | ||
]: | ||
data_dict[v] = { | ||
'dims': ('y', 'x'), | ||
'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) * 100).astype(np.float32), | ||
'attrs': {'_FillValue': -9999}, | ||
} | ||
|
||
# create the xarray dataset and save it | ||
nc = xr.Dataset.from_dict(data_dict) | ||
nc.to_netcdf(fp) | ||
|
||
|
||


def create_splits_df(fp: str | Path) -> pd.DataFrame:
    # create a dataframe with the splits for the 4 glaciers
    splits_df = pd.DataFrame(
        {
            'entry_id': ['g_0008', 'g_0009', 'g_0030', 'g_0035'],
            'split_1': ['fold_train', 'fold_train', 'fold_valid', 'fold_test'],
            'split_2': ['fold_train', 'fold_valid', 'fold_train', 'fold_test'],
            'split_3': ['fold_train', 'fold_valid', 'fold_test', 'fold_train'],
            'split_4': ['fold_test', 'fold_valid', 'fold_train', 'fold_train'],
            'split_5': ['fold_test', 'fold_train', 'fold_train', 'fold_valid'],
        }
    )

    splits_df.to_csv(fp, index=False)
    print(f'Splits dataframe saved to {fp}')
    return splits_df


if __name__ == '__main__':
    # prepare the paths
    fp_splits = Path('splits.csv')
    fp_dir_ds_small = Path('dataset_small')
    fp_dir_ds_large = Path('dataset_large')

    # cleanup
    fp_splits.unlink(missing_ok=True)
    fp_dir_ds_small.with_suffix('.tar.gz').unlink(missing_ok=True)
    fp_dir_ds_large.with_suffix('.tar.gz').unlink(missing_ok=True)
    shutil.rmtree(fp_dir_ds_small, ignore_errors=True)
    shutil.rmtree(fp_dir_ds_large, ignore_errors=True)

    # create the splits dataframe
    split_df = create_splits_df(fp_splits)

    # create the two dataset versions (small and large) with 1 and 2 patches per glacier, respectively
    for fp_dir, num_patches in zip([fp_dir_ds_small, fp_dir_ds_large], [1, 2]):
        for glacier_id in split_df.entry_id:
            for i in range(num_patches):
                fp = fp_dir / glacier_id / f'{glacier_id}_patch_{i}.nc'
                fp.parent.mkdir(parents=True, exist_ok=True)
                create_dummy_sample(fp=fp)

    # archive the datasets
    for fp_dir in [fp_dir_ds_small, fp_dir_ds_large]:
        shutil.make_archive(str(fp_dir), 'gztar', fp_dir)

    # compute checksums
    for fp in [
        fp_dir_ds_small.with_suffix('.tar.gz'),
        fp_dir_ds_large.with_suffix('.tar.gz'),
        fp_splits,
    ]:
        with open(fp, 'rb') as f:
            md5 = hashlib.md5(f.read()).hexdigest()
        print(f'md5 for {fp}: {md5}')
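
A quick sanity check of one generated patch (a sketch, assuming the script was run and its outputs are in the working directory):

    # Open a generated patch and confirm the expected layout:
    # 15 bands of PATCH_SIZE x PATCH_SIZE plus the per-pixel masks.
    import xarray as xr

    patch = xr.open_dataset('dataset_small/g_0008/g_0008_patch_0.nc')
    assert patch['band_data'].shape == (15, 16, 16)
    assert 'mask_all_g_id' in patch and 'mask_debris' in patch
    print(sorted(patch.data_vars))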

tests/data/dl4gam_alps/splits.csv (new file)

entry_id,split_1,split_2,split_3,split_4,split_5
g_0008,fold_train,fold_train,fold_train,fold_test,fold_test
g_0009,fold_train,fold_valid,fold_valid,fold_valid,fold_train
g_0030,fold_valid,fold_train,fold_test,fold_train,fold_train
g_0035,fold_test,fold_test,fold_train,fold_train,fold_valid
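
Each split_<i> column assigns every glacier to a fold for one cross-validation iteration; presumably the dataset resolves split='train' with cv_iter=1 to the 'fold_train' entries of column split_1. A sketch of that lookup (the exact lookup performed by DL4GAMAlps is an assumption):

    # Select the glaciers in the training fold of CV iteration 1
    # (column and fold names follow the CSV above).
    import pandas as pd

    splits_df = pd.read_csv('splits.csv')
    cv_iter = 1
    train_ids = splits_df.loc[splits_df[f'split_{cv_iter}'] == 'fold_train', 'entry_id']
    print(list(train_ids))  # ['g_0008', 'g_0009']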

New test file

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, DL4GAMAlps

pytest.importorskip('xarray', minversion='2023.9')
pytest.importorskip('netCDF4', minversion='1.5.4')
class TestDL4GAMAlps:
    @pytest.fixture(
        params=zip(
            ['train', 'val', 'test'],
            [1, 3, 5],
            ['small', 'small', 'large'],
            [DL4GAMAlps.rgb_bands, DL4GAMAlps.rgb_nir_swir_bands, DL4GAMAlps.all_bands],
            [None, ['dem'], DL4GAMAlps.valid_extra_features],
        )
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> DL4GAMAlps:
        r_url = Path('tests', 'data', 'dl4gam_alps')
        download_metadata = {
            'dataset_small': {
                'url': str(r_url / 'dataset_small.tar.gz'),
                'checksum': '35f85360b943caa8661d9fb573b0f0b5',
            },
            'dataset_large': {
                'url': str(r_url / 'dataset_large.tar.gz'),
                'checksum': '636be5be35b8bd1e7771e9010503e4bc',
            },
            'splits_csv': {
                'url': str(r_url / 'splits.csv'),
                'checksum': '973367465c8ab322d0cf544a345b02f5',
            },
        }

        monkeypatch.setattr(DL4GAMAlps, 'download_metadata', download_metadata)
        root = tmp_path
        split, cv_iter, version, bands, extra_features = request.param
        transforms = nn.Identity()
        return DL4GAMAlps(
            root,
            split,
            cv_iter,
            version,
            bands,
            extra_features,
            transforms,
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: DL4GAMAlps) -> None:
        x = dataset[0]
        assert isinstance(x, dict)

        var_names = ['image', 'mask_glacier', 'mask_debris', 'mask_clouds_and_shadows']
        if dataset.extra_features:
            var_names += list(dataset.extra_features)
        for v in var_names:
            assert v in x
            assert isinstance(x[v], torch.Tensor)

            # check if all variables have the same spatial dimensions as the image
            assert x['image'].shape[-2:] == x[v].shape[-2:]

        # check the first dimension of the image tensor
        assert x['image'].shape[0] == len(dataset.bands)

    def test_len(self, dataset: DL4GAMAlps) -> None:
        num_glaciers_per_fold = 2 if dataset.split == 'train' else 1
        num_patches_per_glacier = 1 if dataset.version == 'small' else 2
        assert len(dataset) == num_glaciers_per_fold * num_patches_per_glacier

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            DL4GAMAlps(tmp_path)

    def test_already_downloaded_and_extracted(self, dataset: DL4GAMAlps) -> None:
        DL4GAMAlps(root=dataset.root, download=False, version=dataset.version)

    def test_already_downloaded_but_not_yet_extracted(self, tmp_path: Path) -> None:
        fp_archive = Path('tests', 'data', 'dl4gam_alps', 'dataset_small.tar.gz')
        shutil.copyfile(fp_archive, Path(str(tmp_path), fp_archive.name))
        fp_splits = Path('tests', 'data', 'dl4gam_alps', 'splits.csv')
        shutil.copyfile(fp_splits, Path(str(tmp_path), fp_splits.name))
        DL4GAMAlps(root=str(tmp_path), download=False)

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            DL4GAMAlps(split='foo')

    def test_plot(self, dataset: DL4GAMAlps) -> None:
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        sample = dataset[0]
        sample['prediction'] = torch.clone(sample['mask_glacier'])
        dataset.plot(sample, suptitle='Test with prediction')
        plt.close()

Review comment: Confirmed the license.
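
For reference, the fixture's positional arguments imply end-user instantiation along these lines (a sketch; keyword names are inferred from the fixture and the attributes used in the tests, not verified against the final signature):

    # Hypothetical usage mirroring the test fixture; the real dataset
    # downloads its archives rather than reading them from tests/data.
    from torchgeo.datasets import DL4GAMAlps

    ds = DL4GAMAlps(
        root='data/dl4gam_alps',
        split='train',
        cv_iter=1,
        version='small',
        bands=DL4GAMAlps.rgb_bands,
        extra_features=['dem'],
        download=True,
    )
    sample = ds[0]
    print(sample['image'].shape, sample['mask_glacier'].shape)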