Skip to content

Commit

Permalink
Add ability to wrap building detection input images in WarpedVRT to t…
Browse files Browse the repository at this point in the history
…ransform them to a unified CRS, resolution, and origin.

PiperOrigin-RevId: 713539966
  • Loading branch information
jzxu authored and copybara-github committed Jan 9, 2025
1 parent 4eda411 commit 54d94b1
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 63 deletions.
11 changes: 2 additions & 9 deletions src/detect_buildings_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
from skai import detect_buildings
from skai import extract_tiles
from skai import read_raster
from skai import utils

import tensorflow as tf

Expand Down Expand Up @@ -138,15 +137,9 @@ def main(args):
gdf = gpd.read_file(f)
aoi = gdf.geometry.values[0]
gdal_env = read_raster.parse_gdal_env(FLAGS.gdal_env)
image_paths = utils.expand_file_patterns(FLAGS.image_paths)
for image_path in image_paths:
if not read_raster.raster_is_tiled(image_path):
raise ValueError(f'Raster "{image_path}" is not tiled.')

vrt_paths = read_raster.build_vrts(
image_paths, os.path.join(temp_dir, 'image'), 0.5, FLAGS.mosaic_images
vrt_paths = read_raster.prepare_building_detection_input_images(
FLAGS.image_paths, os.path.join(temp_dir, 'vrts'), gdal_env
)

tiles = []
for path in vrt_paths:
tiles.extend(
Expand Down
19 changes: 14 additions & 5 deletions src/skai/extract_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,19 @@ def get_tiles_for_aoi(image_path: str,
Yields:
A grid of tiles that covers the AOI.
Raises:
RuntimeError: If the image file does not exist.
"""
if not rasterio.shutil.exists(image_path):
raise RuntimeError(f'File {image_path} does not exist')

with rasterio.Env(**gdal_env):
image = rasterio.open(image_path)
x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi)
yield from get_tiles(
image_path, x_min, y_min, x_max, y_max, tile_size, margin
)
x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi)
yield from get_tiles(
image_path, x_min, y_min, x_max, y_max, tile_size, margin
)


class ExtractTilesAsExamplesFn(beam.DoFn):
Expand All @@ -227,7 +233,10 @@ def _get_raster(
raster, rgb_bands = self._rasters.get(image_path, (None, None))
if raster is None:
with rasterio.Env(**self._gdal_env):
raster = rasterio.open(image_path)
try:
raster = rasterio.open(image_path)
except rasterio.errors.RasterioIOError as error:
raise ValueError(f'Error opening raster {image_path}') from error
rgb_bands = read_raster.get_rgb_indices(raster)
self._rasters[image_path] = (raster, rgb_bands)
return raster, rgb_bands
Expand Down
180 changes: 148 additions & 32 deletions src/skai/read_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dataclasses
import functools
import logging
import math
import os
import re
import shutil
Expand All @@ -26,10 +27,13 @@

import affine
import apache_beam as beam
import geopandas as gpd
import numpy as np
import pandas as pd
import pyproj
import rasterio
import rasterio.plot
import rasterio.shutil
import rasterio.warp
import rtree
import shapely.geometry
Expand Down Expand Up @@ -762,6 +766,31 @@ def _run_gdalbuildvrt(
extents: If not None, sets the extents of the VRT. Should be x_min, x_max,
y_min, y_max.
"""
# First verify that all images have the same projections and number of bands.
# VRTs do not support images with different projections and different numbers
# of bands.
# Input images with different resolutions are supported.
raster = rasterio.open(image_paths[0])
expected_crs = raster.crs
expected_band_count = raster.count
if expected_crs.units_factor[0] not in ('meter', 'metre'):
# Requiring meters may be too strict but is simpler. If other linear units
# such as feet are absolutely required, we can support them as well.
raise ValueError(
'The only supported linear unit is "meter", but found'
f' {expected_crs.units_factor[0]}'
)
for path in image_paths[1:]:
raster = rasterio.open(path)
if raster.crs != expected_crs:
raise ValueError(
f'Expecting CRS {expected_crs}, got {raster.crs}'
)
if raster.count != expected_band_count:
raise ValueError(
f'Expecting {expected_band_count} bands, got {raster.count}'
)

# GDAL doesn't recognize gs:// prefixes. Instead it wants /vsigs/ prefixes.
gdal_image_paths = [
p.replace('gs://', '/vsigs/') if p.startswith('gs://') else p
Expand Down Expand Up @@ -796,11 +825,78 @@ def _run_gdalbuildvrt(
shutil.copyfileobj(source, dest)


def _get_unified_warped_vrt_options(
image_paths: list[str], resolution: float
) -> dict[str, Any]:
"""Gets options for a WarpedVRT that projects images into unified space.
Input images can have arbitrary boundaries, CRS, and resolution. The WarpedVRT
will project them to the same boundaries, CRS, and resolution.
Args:
image_paths: Input image paths.
resolution: Desired output resolution.
Returns:
Dictionary of WarpedVRT constructor options.
"""
image_bounds = []
for image_path in image_paths:
r = rasterio.open(image_path)
image_bounds.append(
gpd.GeoDataFrame(
geometry=[shapely.geometry.box(*r.bounds)], crs=r.crs
).to_crs('EPSG:4326')
)
combined = pd.concat(image_bounds)
utm_crs = combined.estimate_utm_crs()
left, bottom, right, top = combined.to_crs(
utm_crs
).geometry.unary_union.bounds
width = int(math.ceil((right - left) / resolution))
height = int(math.ceil((top - bottom) / resolution))
transform = affine.Affine(resolution, 0.0, left, 0.0, -resolution, top)
return {
'resampling': rasterio.enums.Resampling.cubic,
'crs': utm_crs,
'transform': transform,
'width': width,
'height': height,
}


def _build_warped_vrt(
    image_path: str,
    vrt_path: str,
    vrt_options: dict[str, Any],
    gdal_env: dict[str, str],
) -> None:
  """Creates a WarpedVRT file from an image.

  The VRT is first serialized to a local temporary file and then copied to
  vrt_path through tf.io.gfile.

  Args:
    image_path: Path to source image.
    vrt_path: VRT file output path.
    vrt_options: Options for VRT creation. Passed as keyword arguments to the
      WarpedVRT constructor.
    gdal_env: GDAL environment configuration.
  """
  with rasterio.Env(**gdal_env):
    # Open the source dataset in a context manager so that its handle is
    # released once the VRT has been written.
    with rasterio.open(image_path) as raster:
      with rasterio.vrt.WarpedVRT(raster, **vrt_options) as vrt:
        with tempfile.TemporaryDirectory() as temp_dir:
          temp_vrt_path = os.path.join(temp_dir, 'temp.vrt')
          rasterio.shutil.copy(vrt, temp_vrt_path, driver='VRT')
          with open(temp_vrt_path, 'rb') as source, tf.io.gfile.GFile(
              vrt_path, 'wb'
          ) as dest:
            shutil.copyfileobj(source, dest)


def build_vrts(
image_paths: list[str],
vrt_prefix: str,
resolution: float,
mosaic_images: bool,
gdal_env: dict[str, str],
) -> list[str]:
"""Builds VRTs from a list of image paths.
Expand All @@ -810,48 +906,68 @@ def build_vrts(
resolution: VRT resolution in meters per pixel.
mosaic_images: If true, build a single VRT containing all images. If false,
build an individual VRT per input image.
gdal_env: GDAL environment configuration.
Returns:
A list of paths of the generated VRTs.
"""
# First verify that all images have the same projections and number of bands.
# VRTs do not support images with different projections and different numbers
# of bands.
# Input images with different resolutions are supported.
raster = rasterio.open(image_paths[0])
expected_crs = raster.crs
expected_band_count = raster.count
x_bounds = [raster.bounds.left, raster.bounds.right]
y_bounds = [raster.bounds.bottom, raster.bounds.top]
if expected_crs.units_factor[0] not in ('meter', 'metre'):
# Requiring meters may be too strict but is simpler. If other linear units
# such as feet are absolutely required, we can support them as well.
raise ValueError(
'The only supported linear unit is "meter", but found'
f' {expected_crs.units_factor[0]}'
)
for path in image_paths[1:]:
raster = rasterio.open(path)
if raster.crs != expected_crs:
raise ValueError(
f'Expecting CRS {expected_crs}, got {raster.crs}'
)
if raster.count != expected_band_count:
raise ValueError(
f'Expecting {expected_band_count} bands, got {raster.count}'
)
x_bounds.extend((raster.bounds.left, raster.bounds.right))
y_bounds.extend((raster.bounds.bottom, raster.bounds.top))

extents = [min(x_bounds), min(y_bounds), max(x_bounds), max(y_bounds)]
vrt_paths = []
if mosaic_images:
vrt_path = f'{vrt_prefix}-00000-of-00001.vrt'
vrt_path = f'{vrt_prefix}.vrt'
_run_gdalbuildvrt(image_paths, vrt_path, resolution, None)
vrt_paths.append(vrt_path)
else:
warped_vrt_options = _get_unified_warped_vrt_options(
image_paths, resolution
)
for i, image_path in enumerate(image_paths):
vrt_path = f'{vrt_prefix}-{i:05d}-of-{len(image_paths):05d}.vrt'
_run_gdalbuildvrt([image_path], vrt_path, resolution, extents)
_build_warped_vrt(image_path, vrt_path, warped_vrt_options, gdal_env)
vrt_paths.append(vrt_path)
return vrt_paths


def prepare_building_detection_input_images(
    image_patterns: list[str],
    vrt_dir: str,
    gdal_env: dict[str, str],
    resolution: float = 0.5,
) -> list[str]:
  """Prepares input images for the building detection pipeline.

  This function performs two operations:

  1. For each image pattern that matches multiple files, the files are mosaic'ed
     together by wrapping them in a regular VRT.
  2. For all input images, including mosaic'ed images, this function wraps a
     WarpedVRT around it to transform the image into the correct CRS and
     resolution (0.5 meter by default).

  Args:
    image_patterns: Input image patterns. Must contain at least one pattern.
    vrt_dir: Directory to store VRTs in. Created if it does not exist.
    gdal_env: GDAL environment variables.
    resolution: Resolution passed to build_vrts for both the mosaic VRTs and
      the final VRTs.

  Returns:
    List of VRTs.

  Raises:
    FileNotFoundError: If any of the image patterns does not match any files.
    ValueError: If image_patterns is empty or any matched raster is not tiled.
  """
  if not image_patterns:
    raise ValueError('At least one image pattern is required.')

  mosaic_dir = os.path.join(vrt_dir, 'mosaics')
  wrapped_paths = []
  for i, pattern in enumerate(image_patterns):
    image_paths = utils.expand_file_patterns([pattern])
    if not image_paths:
      raise FileNotFoundError(f'{pattern} did not match any files.')
    for image_path in image_paths:
      if not raster_is_tiled(image_path):
        raise ValueError(f'Raster "{image_path}" is not tiled.')
    if len(image_paths) == 1:
      # A single image needs no mosaic; it is warped directly below.
      wrapped_paths.append(image_paths[0])
    else:
      # makedirs also creates parent directories and is a no-op when the
      # directory already exists, so no separate existence check is needed.
      tf.io.gfile.makedirs(mosaic_dir)
      vrt_prefix = os.path.join(mosaic_dir, f'mosaic-{i:05d}')
      wrapped_paths.extend(
          build_vrts(image_paths, vrt_prefix, resolution, True, gdal_env)
      )

  # The final VRTs are written directly under vrt_dir. When no pattern needed a
  # mosaic, nothing above has created it, so make sure it exists before the
  # output files are opened for writing.
  tf.io.gfile.makedirs(vrt_dir)
  return build_vrts(
      wrapped_paths,
      os.path.join(vrt_dir, 'input'),
      resolution,
      False,
      gdal_env,
  )
Loading

0 comments on commit 54d94b1

Please sign in to comment.