From 54d94b1870c6c7d724c6628ded0ebf8cb861478e Mon Sep 17 00:00:00 2001 From: Joseph Xu Date: Wed, 8 Jan 2025 22:33:55 -0800 Subject: [PATCH] Add ability to wrap building detection input images in WarpedVRT to transform them to a unified CRS, resolution, and origin. PiperOrigin-RevId: 713539966 --- src/detect_buildings_main.py | 11 +-- src/skai/extract_tiles.py | 19 +++- src/skai/read_raster.py | 180 +++++++++++++++++++++++++++------- src/skai/read_raster_test.py | 181 +++++++++++++++++++++++++++++++---- 4 files changed, 328 insertions(+), 63 deletions(-) diff --git a/src/detect_buildings_main.py b/src/detect_buildings_main.py index ff24d2d..747a198 100644 --- a/src/detect_buildings_main.py +++ b/src/detect_buildings_main.py @@ -54,7 +54,6 @@ from skai import detect_buildings from skai import extract_tiles from skai import read_raster -from skai import utils import tensorflow as tf @@ -138,15 +137,9 @@ def main(args): gdf = gpd.read_file(f) aoi = gdf.geometry.values[0] gdal_env = read_raster.parse_gdal_env(FLAGS.gdal_env) - image_paths = utils.expand_file_patterns(FLAGS.image_paths) - for image_path in image_paths: - if not read_raster.raster_is_tiled(image_path): - raise ValueError(f'Raster "{image_path}" is not tiled.') - - vrt_paths = read_raster.build_vrts( - image_paths, os.path.join(temp_dir, 'image'), 0.5, FLAGS.mosaic_images + vrt_paths = read_raster.prepare_building_detection_input_images( + FLAGS.image_paths, os.path.join(temp_dir, 'vrts'), gdal_env ) - tiles = [] for path in vrt_paths: tiles.extend( diff --git a/src/skai/extract_tiles.py b/src/skai/extract_tiles.py index 7c3ee59..8489807 100644 --- a/src/skai/extract_tiles.py +++ b/src/skai/extract_tiles.py @@ -203,13 +203,19 @@ def get_tiles_for_aoi(image_path: str, Yields: A grid of tiles that covers the AOI. + + Raises: + RuntimeError: If the image file does not exist. """ + if not rasterio.shutil.exists(image_path): + raise RuntimeError(f'File {image_path} does not exist') + with rasterio.Env(**gdal_env): image = rasterio.open(image_path) - x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi) - yield from get_tiles( - image_path, x_min, y_min, x_max, y_max, tile_size, margin - ) + x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi) + yield from get_tiles( + image_path, x_min, y_min, x_max, y_max, tile_size, margin + ) class ExtractTilesAsExamplesFn(beam.DoFn): @@ -227,7 +233,10 @@ def _get_raster( raster, rgb_bands = self._rasters.get(image_path, (None, None)) if raster is None: with rasterio.Env(**self._gdal_env): - raster = rasterio.open(image_path) + try: + raster = rasterio.open(image_path) + except rasterio.errors.RasterioIOError as error: + raise ValueError(f'Error opening raster {image_path}') from error rgb_bands = read_raster.get_rgb_indices(raster) self._rasters[image_path] = (raster, rgb_bands) return raster, rgb_bands diff --git a/src/skai/read_raster.py b/src/skai/read_raster.py index ee3b81f..4acd701 100644 --- a/src/skai/read_raster.py +++ b/src/skai/read_raster.py @@ -16,6 +16,7 @@ import dataclasses import functools import logging +import math import os import re import shutil @@ -26,10 +27,13 @@ import affine import apache_beam as beam +import geopandas as gpd import numpy as np +import pandas as pd import pyproj import rasterio import rasterio.plot +import rasterio.shutil import rasterio.warp import rtree import shapely.geometry @@ -762,6 +766,31 @@ def _run_gdalbuildvrt( extents: If not None, sets the extents of the VRT. Should by x_min, x_max, y_min, y_max. """ + # First verify that all images have the same projections and number of bands. + # VRTs do not support images with different projections and different numbers + # of bands. + # Input images with different resolutions are supported. + raster = rasterio.open(image_paths[0]) + expected_crs = raster.crs + expected_band_count = raster.count + if expected_crs.units_factor[0] not in ('meter', 'metre'): + # Requiring meters may be too strict but is simpler. If other linear units + # such as feet are absolutely required, we can support them as well. + raise ValueError( + 'The only supported linear unit is "meter", but found' + f' {expected_crs.units_factor[0]}' + ) + for path in image_paths[1:]: + raster = rasterio.open(path) + if raster.crs != expected_crs: + raise ValueError( + f'Expecting CRS {expected_crs}, got {raster.crs}' + ) + if raster.count != expected_band_count: + raise ValueError( + f'Expecting {expected_band_count} bands, got {raster.count}' + ) + # GDAL doesn't recognize gs:// prefixes. Instead it wants /vsigs/ prefixes. gdal_image_paths = [ p.replace('gs://', '/vsigs/') if p.startswith('gs://') else p @@ -796,11 +825,78 @@ def _run_gdalbuildvrt( shutil.copyfileobj(source, dest) +def _get_unified_warped_vrt_options( + image_paths: list[str], resolution: float +) -> dict[str, Any]: + """Gets options for a WarpedVRT that projects images into unified space. + + Input images can have arbitrary boundaries, CRS, and resolution. The WarpedVRT + will project them to the same boundaries, CRS, and resolution. + + Args: + image_paths: Input image paths. + resolution: Desired output resolution. + + Returns: + Dictionary of WarpedVRT constructor options. + """ + image_bounds = [] + for image_path in image_paths: + r = rasterio.open(image_path) + image_bounds.append( + gpd.GeoDataFrame( + geometry=[shapely.geometry.box(*r.bounds)], crs=r.crs + ).to_crs('EPSG:4326') + ) + combined = pd.concat(image_bounds) + utm_crs = combined.estimate_utm_crs() + left, bottom, right, top = combined.to_crs( + utm_crs + ).geometry.unary_union.bounds + width = int(math.ceil((right - left) / resolution)) + height = int(math.ceil((top - bottom) / resolution)) + transform = affine.Affine(resolution, 0.0, left, 0.0, -resolution, top) + return { + 'resampling': rasterio.enums.Resampling.cubic, + 'crs': utm_crs, + 'transform': transform, + 'width': width, + 'height': height, + } + + +def _build_warped_vrt( + image_path: str, + vrt_path: str, + vrt_options: dict[str, Any], + gdal_env: dict[str, str], +) -> None: + """Creates a WarpedVRT file from an image. + + Args: + image_path: Path to source image. + vrt_path: VRT file output path. + vrt_options: Options for VRT creation. + gdal_env: GDAL environment configuration. + """ + with rasterio.Env(**gdal_env): + raster = rasterio.open(image_path) + with rasterio.vrt.WarpedVRT(raster, **vrt_options) as vrt: + with tempfile.TemporaryDirectory() as temp_dir: + temp_vrt_path = os.path.join(temp_dir, 'temp.vrt') + rasterio.shutil.copy(vrt, temp_vrt_path, driver='VRT') + with open(temp_vrt_path, 'rb') as source, tf.io.gfile.GFile( + vrt_path, 'wb' + ) as dest: + shutil.copyfileobj(source, dest) + + def build_vrts( image_paths: list[str], vrt_prefix: str, resolution: float, mosaic_images: bool, + gdal_env: dict[str, str], ) -> list[str]: """Builds VRTs from a list of image paths. @@ -810,48 +906,68 @@ def build_vrts( resolution: VRT resolution in meters per pixel. mosaic_images: If true, build a single VRT containing all images. If false, build an individual VRT per input image. + gdal_env: GDAL environment configuration. Returns: A list of paths of the generated VRTs. """ - # First verify that all images have the same projections and number of bands. - # VRTs do not support images with different projections and different numbers - # of bands. - # Input images with different resolutions are supported. - raster = rasterio.open(image_paths[0]) - expected_crs = raster.crs - expected_band_count = raster.count - x_bounds = [raster.bounds.left, raster.bounds.right] - y_bounds = [raster.bounds.bottom, raster.bounds.top] - if expected_crs.units_factor[0] not in ('meter', 'metre'): - # Requiring meters may be too strict but is simpler. If other linear units - # such as feet are absolutely required, we can support them as well. - raise ValueError( - 'The only supported linear unit is "meter", but found' - f' {expected_crs.units_factor[0]}' - ) - for path in image_paths[1:]: - raster = rasterio.open(path) - if raster.crs != expected_crs: - raise ValueError( - f'Expecting CRS {expected_crs}, got {raster.crs}' - ) - if raster.count != expected_band_count: - raise ValueError( - f'Expecting {expected_band_count} bands, got {raster.count}' - ) - x_bounds.extend((raster.bounds.left, raster.bounds.right)) - y_bounds.extend((raster.bounds.bottom, raster.bounds.top)) - - extents = [min(x_bounds), min(y_bounds), max(x_bounds), max(y_bounds)] vrt_paths = [] if mosaic_images: - vrt_path = f'{vrt_prefix}-00000-of-00001.vrt' + vrt_path = f'{vrt_prefix}.vrt' _run_gdalbuildvrt(image_paths, vrt_path, resolution, None) vrt_paths.append(vrt_path) else: + warped_vrt_options = _get_unified_warped_vrt_options( + image_paths, resolution + ) for i, image_path in enumerate(image_paths): vrt_path = f'{vrt_prefix}-{i:05d}-of-{len(image_paths):05d}.vrt' - _run_gdalbuildvrt([image_path], vrt_path, resolution, extents) + _build_warped_vrt(image_path, vrt_path, warped_vrt_options, gdal_env) vrt_paths.append(vrt_path) return vrt_paths + + +def prepare_building_detection_input_images( + image_patterns: list[str], vrt_dir: str, gdal_env: dict[str, str] +) -> list[str]: + """Prepares input images for the building detection pipeline. + + This function performs two operations: + 1. For each image pattern that matches multiple files, the files are mosaic'ed + together by wrapping them in a regular VRT. + 2. For all input images, including mosaic'ed images, this function wraps a + WarpedVRT around it to transform the image into the correct CRS and + resolution (0.5 meter). + + Args: + image_patterns: Input image patterns. + vrt_dir: Directory to store VRTs in. + gdal_env: GDAL environment variables. + + Returns: + List of VRTs. + + Raises: + FileNotFoundError: If any of the image patterns does not match any files. + """ + wrapped_paths = [] + for i, pattern in enumerate(image_patterns): + image_paths = utils.expand_file_patterns([pattern]) + if not image_paths: + raise FileNotFoundError(f'{pattern} did not match any files.') + for image_path in image_paths: + if not raster_is_tiled(image_path): + raise ValueError(f'Raster "{image_path}" is not tiled.') + if len(image_paths) == 1: + wrapped_paths.append(image_paths[0]) + else: + mosaic_dir = os.path.join(vrt_dir, 'mosaics') + if not tf.io.gfile.exists(mosaic_dir): + tf.io.gfile.makedirs(mosaic_dir) + vrt_prefix = os.path.join(vrt_dir, 'mosaics', f'mosaic-{i:05d}') + wrapped_paths.extend( + build_vrts(image_paths, vrt_prefix, 0.5, True, gdal_env) + ) + return build_vrts( + wrapped_paths, os.path.join(vrt_dir, 'input'), 0.5, False, gdal_env + ) diff --git a/src/skai/read_raster_test.py b/src/skai/read_raster_test.py index cc14051..1cffd98 100644 --- a/src/skai/read_raster_test.py +++ b/src/skai/read_raster_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import pathlib import tempfile from absl.testing import absltest @@ -35,14 +36,15 @@ def _create_test_image_tiff_file_with_position_size( - west: float, - north: float, + x: float, + y: float, width: int, height: int, num_channels: int, crs: str, colorinterps: list[ColorInterp], tags: list[dict[str, str]], + resolution: float = 0.1, ): image = np.random.randint( 0, 256, (height, width, num_channels), dtype=np.uint8 @@ -55,7 +57,9 @@ def _create_test_image_tiff_file_with_position_size( 'count': num_channels, 'dtype': 'uint8', 'crs': crs, - 'transform': rasterio.transform.from_origin(west, north, 0.1, 0.1), + 'transform': rasterio.transform.from_origin( + x, y, resolution, resolution + ), } _, image_path = tempfile.mkstemp( dir=absltest.TEST_TMPDIR.value, suffix='.tiff' @@ -77,6 +81,36 @@ def _create_test_image_tiff_file(colorinterps: list[ColorInterp]): ) +def _create_test_image_at_path( + path: str, + x: float, + y: float, + width: int, + height: int, + num_channels: int, + crs: str, + resolution: float, +): + image = np.random.randint( + 0, 256, (height, width, num_channels), dtype=np.uint8 + ) + + profile = { + 'driver': 'GTiff', + 'height': height, + 'width': width, + 'count': num_channels, + 'dtype': 'uint8', + 'crs': crs, + 'transform': rasterio.transform.from_origin( + x, y, resolution, resolution + ), + } + with rasterio.open(path, 'w', **profile) as dst: + for i in range(num_channels): + dst.write(image[..., i], i + 1) + + def _create_buildings_file( coordinates: list[tuple[float, float]], output_path: str ) -> gpd.GeoDataFrame: @@ -309,8 +343,8 @@ def test_get_rgb_indices_missing_blue(self): def test_get_rgb_indices_band_name_tags(self): image_path = _create_test_image_tiff_file_with_position_size( - west=10, - north=20, + x=10, + y=20, width=100, height=100, num_channels=4, @@ -345,8 +379,8 @@ def test_convert_image_to_uint8(self): def test_build_mosaic_vrt(self): image1_path = _create_test_image_tiff_file_with_position_size( - west=10, - north=20, + x=10, + y=20, width=100, height=100, num_channels=3, @@ -355,8 +389,8 @@ def test_build_mosaic_vrt(self): tags={}, ) image2_path = _create_test_image_tiff_file_with_position_size( - west=20, - north=20, + x=20, + y=20, width=100, height=100, num_channels=3, @@ -369,6 +403,7 @@ def test_build_mosaic_vrt(self): pathlib.Path(absltest.TEST_TMPDIR.value) / 'image', 0.5, True, + {}, ) vrt_raster = rasterio.open(vrt_paths[0]) vrt_image = vrt_raster.read() @@ -379,22 +414,22 @@ def test_build_mosaic_vrt(self): def test_build_individual_vrts(self): image1_path = _create_test_image_tiff_file_with_position_size( - west=-125, - north=40, + x=581370, + y=4141960, width=200, height=100, num_channels=3, - crs='EPSG:26910', + crs='EPSG:32610', colorinterps=[ColorInterp.red, ColorInterp.green, ColorInterp.blue], tags={}, ) image2_path = _create_test_image_tiff_file_with_position_size( - west=-120, - north=50, + x=581375, + y=4141970, width=200, height=100, num_channels=3, - crs='EPSG:26910', + crs='EPSG:32610', colorinterps=[ColorInterp.red, ColorInterp.green, ColorInterp.blue], tags={}, ) @@ -403,6 +438,7 @@ def test_build_individual_vrts(self): pathlib.Path(absltest.TEST_TMPDIR.value) / 'image', 0.5, False, + {}, ) self.assertLen(vrt_paths, 2) for vrt_path in vrt_paths: @@ -411,10 +447,121 @@ def test_build_individual_vrts(self): self.assertEqual(3, vrt_raster.count) self.assertEqual((0.5, 0.5), vrt_raster.res) # Image size is divided by 5 because resolution lowered from 0.1 to 0.5. - self.assertEqual((3, 40, 50), vrt_image.shape) + self.assertEqual((3, 41, 51), vrt_image.shape) # All images should have the same bounds. The bounds is the union of the # bounds of the individual images. - self.assertEqual((-125, 30, -100, 50), tuple(vrt_raster.bounds)) + np.testing.assert_allclose( + (581370, 4141949.5, 581395.5, 4141970), tuple(vrt_raster.bounds) + ) + + def test_build_individual_vrts_diff_crs(self): + image1_path = _create_test_image_tiff_file_with_position_size( + x=581370, + y=4141960, + width=200, + height=100, + num_channels=3, + crs='EPSG:32610', + colorinterps=[ColorInterp.red, ColorInterp.green, ColorInterp.blue], + tags={}, + ) + image2_path = _create_test_image_tiff_file_with_position_size( + x=-122.08034382, + y=37.42096571, + width=200, + height=100, + num_channels=3, + crs='EPSG:4326', + colorinterps=[ColorInterp.red, ColorInterp.green, ColorInterp.blue], + tags={}, + resolution=0.0000010198723355, # About 0.1m + ) + vrt_paths = read_raster.build_vrts( + [image1_path, image2_path], + pathlib.Path(absltest.TEST_TMPDIR.value) / 'image', + 0.5, + False, + {}, + ) + self.assertLen(vrt_paths, 2) + for vrt_path in vrt_paths: + vrt_raster = rasterio.open(vrt_path) + vrt_image = vrt_raster.read() + self.assertEqual(3, vrt_raster.count) + self.assertEqual((0.5, 0.5), vrt_raster.res) + # Image size is divided by 5 because resolution lowered from 0.1 to 0.5. + self.assertEqual((3, 41, 47), vrt_image.shape) + # All images should have the same bounds. The bounds is the union of the + # bounds of the individual images. + np.testing.assert_allclose( + (581370, 4141949.675, 581393.5, 4141970.175), tuple(vrt_raster.bounds) + ) + + def test_prepare_building_detection_input_images(self): + images_dir = self.create_tempdir() + _create_test_image_at_path( + path=os.path.join(images_dir, 'image1.tif'), + x=581370, + y=4141960, + width=200, + height=100, + num_channels=3, + crs='EPSG:26910', + resolution=1.0, + ) + _create_test_image_at_path( + path=os.path.join(images_dir, 'image2.tif'), + x=581380, + y=4141970, + width=200, + height=100, + num_channels=3, + crs='EPSG:26910', + resolution=0.5, + ) + _create_test_image_at_path( + path=os.path.join(images_dir, 'standalone_image.tif'), + x=-122.08034382, + y=37.42096571, + width=500, + height=500, + num_channels=3, + crs='EPSG:4326', + resolution=0.000006, + ) + image_paths = [ + os.path.join(images_dir, 'image*.tif'), + os.path.join(images_dir, 'standalone_image.tif'), + ] + vrt_dir = self.create_tempdir() + vrt_paths = read_raster.prepare_building_detection_input_images( + image_paths, vrt_dir, {} + ) + self.assertLen(vrt_paths, 2) + vrt1 = rasterio.open(vrt_paths[0]) + self.assertEqual(vrt1.crs.to_epsg(), 32610) + self.assertEqual(vrt1.res, (0.5, 0.5)) + vrt2 = rasterio.open(vrt_paths[1]) + self.assertEqual(vrt2.crs.to_epsg(), 32610) + self.assertEqual(vrt2.res, (0.5, 0.5)) + self.assertEqual(vrt1.width, vrt2.width) + self.assertEqual(vrt1.height, vrt2.height) + self.assertEqual(vrt1.transform, vrt2.transform) + + self.assertCountEqual( + os.listdir(vrt_dir), + [ + 'mosaics', + 'input-00000-of-00002.vrt', + 'input-00001-of-00002.vrt', + ], + ) + self.assertCountEqual( + os.listdir(os.path.join(vrt_dir, 'mosaics')), + [ + 'mosaic-00000.vrt', + ], + ) if __name__ == '__main__':