Skip to content

Commit

Permalink
Task/geospatial tests (#3841)
Browse files Browse the repository at this point in the history
Co-authored-by: dgboss <[email protected]>
Test suite for geospatial functions introduced in #3800
  • Loading branch information
conbrad authored Aug 13, 2024
1 parent ea10ff8 commit db4a4e0
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 25 deletions.
Binary file added api/app/tests/utils/snow_masked_hfi20240810.tif
Binary file not shown.
108 changes: 108 additions & 0 deletions api/app/tests/utils/test_geospatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
import pytest
from osgeo import gdal
import numpy as np

from app.utils.geospatial import raster_mul, warp_to_match_extent

fixture_path = os.path.join(os.path.dirname(__file__), "snow_masked_hfi20240810.tif")


def get_test_tpi_raster(hfi_ds: gdal.Dataset, fill_value: int):
    """
    Build an in-memory, single-band Byte raster shaped like ``hfi_ds`` and filled with one value.

    :param hfi_ds: dataset whose size, geotransform and projection are copied
    :param fill_value: value written to every pixel of the new raster
    :return: the new in-memory dataset; band 1 has its nodata value set to 0
    """
    mem_driver = gdal.GetDriverByName("MEM")
    tpi_ds: gdal.Dataset = mem_driver.Create("memory", hfi_ds.RasterXSize, hfi_ds.RasterYSize, 1, gdal.GDT_Byte)

    # Georeference the new raster identically to the HFI raster
    tpi_ds.SetGeoTransform(hfi_ds.GetGeoTransform())
    tpi_ds.SetProjection(hfi_ds.GetProjection())

    # The HFI band is only read to give full_like an array of the right shape
    hfi_pixels = hfi_ds.GetRasterBand(1).ReadAsArray()
    constant_pixels = np.full_like(hfi_pixels, fill_value)

    band = tpi_ds.GetRasterBand(1)
    band.SetNoDataValue(0)
    band.WriteArray(constant_pixels)
    return tpi_ds


def get_tpi_raster_wrong_shape():
    """
    Build a 1x1 in-memory Byte raster holding the single value 1.

    Used by the tests as a raster whose dimensions differ from the HFI fixture.

    :return: the in-memory dataset; band 1 has its nodata value set to 0
    """
    single_pixel = np.array([[1]])

    tiny_ds: gdal.Dataset = gdal.GetDriverByName("MEM").Create("memory", 1, 1, 1, gdal.GDT_Byte)
    band = tiny_ds.GetRasterBand(1)
    band.SetNoDataValue(0)
    band.WriteArray(single_pixel)
    return tiny_ds


def test_zero_case():
    """Multiplying the HFI fixture by an all-zero TPI raster must give an all-zero raster of the same shape."""
    hfi_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
    tpi_ds: gdal.Dataset = get_test_tpi_raster(hfi_ds, 0)

    masked_raster = raster_mul(tpi_ds, hfi_ds)
    masked_data = masked_raster.GetRasterBand(1).ReadAsArray()

    assert masked_data.shape == hfi_ds.GetRasterBand(1).ReadAsArray().shape
    # Assert the truth value directly rather than comparing it to True
    assert np.all(masked_data == 0)

    # Drop the references so GDAL can release the datasets
    hfi_ds = None
    tpi_ds = None


def test_identity_case():
    """Multiplying the HFI fixture by an all-ones TPI raster must give the binarised HFI raster."""
    hfi_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
    tpi_ds: gdal.Dataset = get_test_tpi_raster(hfi_ds, 1)

    masked_raster = raster_mul(tpi_ds, hfi_ds)
    masked_data = masked_raster.GetRasterBand(1).ReadAsArray()
    hfi_data = hfi_ds.GetRasterBand(1).ReadAsArray()

    # Build the expected result by binarising the HFI data in place:
    # pixels >= 1 become 1, pixels < 1 become 0.
    hfi_data[hfi_data >= 1] = 1
    hfi_data[hfi_data < 1] = 0

    assert masked_data.shape == hfi_data.shape
    # Assert the truth value directly rather than comparing it to True
    assert np.all(masked_data == hfi_data)

    # Drop the references so GDAL can release the datasets
    hfi_ds = None
    tpi_ds = None


def test_wrong_dimensions():
    """raster_mul must raise ValueError when the two rasters differ in size."""
    reference_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
    mismatched_ds: gdal.Dataset = get_tpi_raster_wrong_shape()

    # A 1x1 raster cannot line up with the HFI fixture
    with pytest.raises(ValueError):
        raster_mul(mismatched_ds, reference_ds)

    # Drop the references so GDAL can release the datasets
    reference_ds = None
    mismatched_ds = None


@pytest.mark.skip(reason="enable once gdal is updated past version 3.4")
def test_warp_to_match_dimension():
    """Warping the 1x1 raster onto the HFI fixture must yield an output with the fixture's dimensions."""
    reference_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
    tiny_ds: gdal.Dataset = get_tpi_raster_wrong_shape()

    # Pre-allocate an in-memory destination sized like the reference raster
    mem_driver = gdal.GetDriverByName("MEM")
    warped_ds: gdal.Dataset = mem_driver.Create("memory", reference_ds.RasterXSize, reference_ds.RasterYSize, 1, gdal.GDT_Byte)

    warp_to_match_extent(tiny_ds, reference_ds, warped_ds)
    warped_pixels = warped_ds.GetRasterBand(1).ReadAsArray()
    reference_pixels = reference_ds.GetRasterBand(1).ReadAsArray()

    assert reference_pixels.shape == warped_pixels.shape
    assert reference_ds.RasterXSize == warped_ds.RasterXSize
    assert reference_ds.RasterYSize == warped_ds.RasterYSize

    # Drop the references so GDAL can release the datasets
    reference_ds = None
    tiny_ds = None
48 changes: 23 additions & 25 deletions api/app/utils/geospatial.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,64 @@
from dataclasses import dataclass
import logging
from typing import Any, Optional
from osgeo import gdal


logger = logging.getLogger(__name__)


def warp_to_match_extent(source_ds: gdal.Dataset, ds_to_match: gdal.Dataset, output_path: str) -> gdal.Dataset:
    """
    Warp the source dataset to match the extent and projection of the other dataset.

    :param source_ds: the dataset raster to warp
    :param ds_to_match: the reference dataset raster to match the source against
    :param output_path: output path of the resulting raster; handed to gdal.Warp as its destination
    :return: warped raster dataset
    """
    # All target parameters come from the reference dataset, not from the source
    match_geotransform = ds_to_match.GetGeoTransform()
    x_res = match_geotransform[1]
    # Negated to obtain a positive resolution; assumes a north-up reference raster
    # (negative pixel height) - TODO confirm
    y_res = -match_geotransform[5]
    minx = match_geotransform[0]
    maxy = match_geotransform[3]
    maxx = minx + match_geotransform[1] * ds_to_match.RasterXSize
    miny = maxy + match_geotransform[5] * ds_to_match.RasterYSize
    extent = [minx, miny, maxx, maxy]

    # Warp to match input option parameters
    return gdal.Warp(output_path, source_ds, dstSRS=ds_to_match.GetProjection(), outputBounds=extent, xRes=x_res, yRes=y_res, resampleAlg=gdal.GRA_NearestNeighbour)


def raster_mul(tpi_raster: gdal.Dataset, hfi_raster: gdal.Dataset, chunk_size=256) -> gdal.Dataset:
def raster_mul(tpi_ds: gdal.Dataset, hfi_ds: gdal.Dataset, chunk_size=256) -> gdal.Dataset:
"""
Multiply rasters together by reading in chunks of pixels at a time to avoid loading
the rasters into memory all at once.
:param tpi_raster: Classified TPI raster to multiply against the classified HFI raster
:param hfi_raster: Classified HFI raster to multiply against the classified TPI raster
:param tpi_ds: Classified TPI dataset raster to multiply against the classified HFI dataset raster
:param hfi_ds: Classified HFI dataset raster to multiply against the classified TPI dataset raster
:raises ValueError: Raised if the dimensions of the rasters don't match
:return: Multiplied raster result as a raster dataset
"""
# Get raster dimensions
x_size = tpi_raster.RasterXSize
y_size = tpi_raster.RasterYSize
x_size = tpi_ds.RasterXSize
y_size = tpi_ds.RasterYSize

# Check if the dimensions of both rasters match
if x_size != hfi_raster.RasterXSize or y_size != hfi_raster.RasterYSize:
if x_size != hfi_ds.RasterXSize or y_size != hfi_ds.RasterYSize:
raise ValueError("The dimensions of the two rasters do not match.")

# Get the geotransform and projection from the first raster
geotransform = tpi_raster.GetGeoTransform()
projection = tpi_raster.GetProjection()
geotransform = tpi_ds.GetGeoTransform()
projection = tpi_ds.GetProjection()

# Create the output raster
driver = gdal.GetDriverByName("MEM")
out_raster: gdal.Dataset = driver.Create("memory", x_size, y_size, 1, gdal.GDT_Byte)
out_ds: gdal.Dataset = driver.Create("memory", x_size, y_size, 1, gdal.GDT_Byte)

# Set the geotransform and projection
out_raster.SetGeoTransform(geotransform)
out_raster.SetProjection(projection)
out_ds.SetGeoTransform(geotransform)
out_ds.SetProjection(projection)

tpi_raster_band = tpi_raster.GetRasterBand(1)
hfi_raster_band = hfi_raster.GetRasterBand(1)
tpi_raster_band = tpi_ds.GetRasterBand(1)
hfi_raster_band = hfi_ds.GetRasterBand(1)

# Process in chunks
for y in range(0, y_size, chunk_size):
Expand All @@ -80,8 +78,8 @@ def raster_mul(tpi_raster: gdal.Dataset, hfi_raster: gdal.Dataset, chunk_size=25
tpi_chunk *= hfi_chunk

# Write the result to the output raster
out_raster.GetRasterBand(1).WriteArray(tpi_chunk, x, y)
out_ds.GetRasterBand(1).WriteArray(tpi_chunk, x, y)
tpi_chunk = None
hfi_chunk = None

return out_raster
return out_ds

0 comments on commit db4a4e0

Please sign in to comment.