Skip to content

Commit

Permalink
Xee performance: Improving benchmarks.
Browse files Browse the repository at this point in the history
Makes improvements towards #29. Here, we group together all ee.getInfo() calls into one RPC call. In addition, this PR helped identify the underlying reason why Xee is slow (see #30).

Status Quo:
```
open_dataset():avg=51.30,std=10.21,best=39.41,worst=68.27
open_and_chunk():avg=52.94,std=7.01,best=43.06,worst=63.03
open_and_write():avg=113.94,std=27.35,best=90.03,worst=173.90
```

After:

```
open_dataset():avg=39.82,std=8.67,best=25.24,worst=55.54
open_and_chunk():avg=36.46,std=11.96,best=25.71,worst=59.83
open_and_write():avg=91.48,std=4.74,best=86.33,worst=104.08
```

PiperOrigin-RevId: 570480601
  • Loading branch information
Xee authors authored and alxmrs committed Oct 3, 2023
1 parent 024263e commit 00ccfc8
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 37 deletions.
113 changes: 76 additions & 37 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,30 +125,17 @@ def __init__(
if n_images != -1:
self.image_collection = image_collection.limit(n_images)

self.projection = projection
self.geometry = geometry
self.primary_dim_name = primary_dim_name or 'time'
self.primary_dim_property = primary_dim_property or 'system:time_start'

n_images, props, img_info, imgs_list = ee.List([
self.image_collection.size(),
self.image_collection.toDictionary(),
self.image_collection.first(),
(
self.image_collection.reduceColumns(
ee.Reducer.toList(), ['system:id']
).get('list')
),
]).getInfo()

self.n_images = n_images
self._props = props
self.n_images = self.get_info['size']
self._props = self.get_info['props']
# Metadata should apply to all imgs.
self._img_info: types.ImageInfo = img_info
self.image_ids = imgs_list
self._img_info: types.ImageInfo = self.get_info['first']

proj = {}
# TODO(alxr): See if we can factor this into the above getInfo() call.
if isinstance(projection, ee.Projection):
proj = projection.getInfo()
proj = self.get_info.get('projection', {})

self.crs_arg = crs or proj.get('crs', proj.get('wkt', 'EPSG:4326'))
self.crs = CRS(self.crs_arg)
Expand Down Expand Up @@ -180,14 +167,16 @@ def __init__(
# used for all internal `computePixels()` calls.
try:
if isinstance(geometry, ee.Geometry):
x_min_0, y_min_0, x_max_0, y_max_0 = geometry_to_bounds(geometry)
x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds(
self.get_info['bounds']
)
else:
x_min_0, y_min_0, x_max_0, y_max_0 = self.crs.area_of_use.bounds
except AttributeError:
# `area_of_use` is probable `None`. Parse the geometry from the first
# image instead.
x_min_0, y_min_0, x_max_0, y_max_0 = geometry_to_bounds(
self.image_collection.first().geometry()
# image instead (calculated in self.get_info())
x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds(
self.get_info['bounds']
)
# We add and subtract the scale to solve an off-by-one error. With this
# adjustment, we achieve parity with a pure `computePixels()` call.
Expand All @@ -208,6 +197,58 @@ def __init__(

self.preferred_chunks = self._assign_preferred_chunks()

@functools.cached_property
def get_info(self) -> dict[str, Any]:
"""Make all getInfo() calls to EE at once."""

rpcs = [
('size', self.image_collection.size()),
('props', self.image_collection.toDictionary()),
('first', self.image_collection.first()),
]

if isinstance(self.projection, ee.Projection):
rpcs.append(('projection', self.projection))

if isinstance(self.geometry, ee.Geometry):
rpcs.append(('bounds', self.geometry.bounds()))
else:
rpcs.append(('bounds', self.image_collection.first().geometry().bounds()))

# TODO(#29, #30): This RPC call takes the longest time to compute. This
# requires a full scan of the images in the collection, which happens on the
# EE backend. This is essential because we want the primary dimension of the
# opened dataset to be something relevant to the data, like time (start
# time) as opposed to a random index number.
#
# One optimization that could prove really fruitful: read the first and last
# (few) values of the primary dim (read: time) and interpolate the rest
# client-side. Ideally, this would live behind a xarray-backend-specific
# feature flag, since it's not guaranteed that data is this consistent.
columns = ['system:id', self.primary_dim_property]
rpcs.append(
('properties',
(
self.image_collection
.reduceColumns(ee.Reducer.toList().repeat(len(columns)), columns)
.get('list'))
)
)

info = ee.List([rpc for _, rpc in rpcs]).getInfo()

return dict(zip((name for name, _ in rpcs), info))

@property
def image_collection_properties(self) -> tuple[list[str], list[str]]:
system_ids, primary_coord = self.get_info['properties']
return (system_ids, primary_coord)

@property
def image_ids(self) -> list[str]:
image_ids, _ = self.image_collection_properties
return image_ids

def _assign_index_chunks(
self, input_chunk_store: dict[Any, Any]
) -> dict[Any, Any]:
Expand Down Expand Up @@ -294,23 +335,17 @@ def get_attrs(self) -> utils.Frozen[Any, Any]:

def _get_primary_coordinates(self) -> list[Any]:
"""Gets the primary dimension coordinate values from an ImageCollection."""
coods_list = (
self.image_collection
.reduceColumns(ee.Reducer.toList(), [self.primary_dim_property])
.get('list')
)
_, primary_coords = self.image_collection_properties

coords = coods_list.getInfo()

if not coords:
if not primary_coords:
raise ValueError(
f'No {self.primary_dim_property!r} values found in the'
" 'ImageCollection'"
)
if self.primary_dim_property in ['system:time_start', 'system:time_end']:
# Convert elements in primary_dim_list to np.datetime64
coords = [np.datetime64(time, 'ms') for time in coords]
return coords
primary_coords = [np.datetime64(time, 'ms') for time in primary_coords]
return primary_coords

def get_variables(self) -> utils.Frozen[str, xarray.Variable]:
vars_ = [(name, self.open_store_variable(name)) for name in self._bands()]
Expand Down Expand Up @@ -390,9 +425,7 @@ def _parse_dtype(data_type: types.DataType):
return np.dtype(dt)


def geometry_to_bounds(geom: ee.Geometry) -> types.Bounds:
"""Finds the CRS bounds from a ee.Geometry polygon."""
bounds = geom.bounds().getInfo()
def _ee_bounds_to_bounds(bounds: ee.Bounds) -> types.Bounds:
coords = np.array(bounds['coordinates'], dtype=np.float32)[0]
x_min, y_min, x_max, y_max = (
min(coords[:, 0]),
Expand All @@ -403,6 +436,12 @@ def geometry_to_bounds(geom: ee.Geometry) -> types.Bounds:
return x_min, y_min, x_max, y_max


def geometry_to_bounds(geom: ee.Geometry) -> types.Bounds:
"""Finds the CRS bounds from a ee.Geometry polygon."""
bounds = geom.bounds().getInfo()
return _ee_bounds_to_bounds(bounds)


class _GetComputedPixels:
"""Wrapper around `ee.data.computePixels()` to make retries simple."""

Expand Down Expand Up @@ -763,7 +802,7 @@ def open_dataset(
upon opening. By default, data is opened with `EPSG:4326'.
scale (optional): The scale in the `crs` or `projection`'s units of
measure -- either meters or degrees. This defines the scale that all
data is represented in upon opening. By default, the scale is 1° when
data is represented in upon opening. By default, the scale is 1° when
the CRS is in degrees or 10,000 when in meters.
projection (optional): Specify an `ee.Projection` object to define the
`scale` and `crs` (or other coordinate reference system) with which to
Expand Down
4 changes: 4 additions & 0 deletions xee/micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
to run.
"""

import cProfile
import os
import tempfile
import timeit
Expand All @@ -32,6 +33,7 @@

REPEAT = 10
LOOPS = 1
PROFILE = False


def init_ee_for_tests():
Expand Down Expand Up @@ -74,6 +76,8 @@ def main(_: list[str]) -> None:
init_ee_for_tests()
print(f'[{REPEAT} time(s) with {LOOPS} loop(s) each.]')
for fn in ['open_dataset()', 'open_and_chunk()', 'open_and_write()']:
if PROFILE:
cProfile.run(fn)
timer = timeit.Timer(fn, globals=globals())
res = timer.repeat(REPEAT, number=LOOPS)
avg, std, best, worst = np.mean(res), np.std(res), np.min(res), np.max(res)
Expand Down

0 comments on commit 00ccfc8

Please sign in to comment.