diff --git a/xee/ext.py b/xee/ext.py index af472ca..63a6dd3 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -125,30 +125,17 @@ def __init__( if n_images != -1: self.image_collection = image_collection.limit(n_images) + self.projection = projection + self.geometry = geometry self.primary_dim_name = primary_dim_name or 'time' self.primary_dim_property = primary_dim_property or 'system:time_start' - n_images, props, img_info, imgs_list = ee.List([ - self.image_collection.size(), - self.image_collection.toDictionary(), - self.image_collection.first(), - ( - self.image_collection.reduceColumns( - ee.Reducer.toList(), ['system:id'] - ).get('list') - ), - ]).getInfo() - - self.n_images = n_images - self._props = props + self.n_images = self.get_info['size'] + self._props = self.get_info['props'] # Metadata should apply to all imgs. - self._img_info: types.ImageInfo = img_info - self.image_ids = imgs_list + self._img_info: types.ImageInfo = self.get_info['first'] - proj = {} - # TODO(alxr): See if we can factor this into the above getInfo() call. - if isinstance(projection, ee.Projection): - proj = projection.getInfo() + proj = self.get_info.get('projection', {}) self.crs_arg = crs or proj.get('crs', proj.get('wkt', 'EPSG:4326')) self.crs = CRS(self.crs_arg) @@ -180,14 +167,16 @@ def __init__( # used for all internal `computePixels()` calls. try: if isinstance(geometry, ee.Geometry): - x_min_0, y_min_0, x_max_0, y_max_0 = geometry_to_bounds(geometry) + x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds( + self.get_info['bounds'] + ) else: x_min_0, y_min_0, x_max_0, y_max_0 = self.crs.area_of_use.bounds except AttributeError: # `area_of_use` is probable `None`. Parse the geometry from the first - # image instead. - x_min_0, y_min_0, x_max_0, y_max_0 = geometry_to_bounds( - self.image_collection.first().geometry() + # image instead (calculated in self.get_info()) + x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds( + self.get_info['bounds'] ) # We add and subtract the scale to solve an off-by-one error. With this # adjustment, we achieve parity with a pure `computePixels()` call. @@ -208,6 +197,58 @@ def __init__( self.preferred_chunks = self._assign_preferred_chunks() + @functools.cached_property + def get_info(self) -> dict[str, Any]: + """Make all getInfo() calls to EE at once.""" + + rpcs = [ + ('size', self.image_collection.size()), + ('props', self.image_collection.toDictionary()), + ('first', self.image_collection.first()), + ] + + if isinstance(self.projection, ee.Projection): + rpcs.append(('projection', self.projection)) + + if isinstance(self.geometry, ee.Geometry): + rpcs.append(('bounds', self.geometry.bounds())) + else: + rpcs.append(('bounds', self.image_collection.first().geometry().bounds())) + + # TODO(#29, #30): This RPC call takes the longest time to compute. This + # requires a full scan of the images in the collection, which happens on the + # EE backend. This is essential because we want the primary dimension of the + # opened dataset to be something relevant to the data, like time (start + # time) as opposed to a random index number. + # + # One optimization that could prove really fruitful: read the first and last + # (few) values of the primary dim (read: time) and interpolate the rest + # client-side. Ideally, this would live behind a xarray-backend-specific + # feature flag, since it's not guaranteed that data is this consistent. + columns = ['system:id', self.primary_dim_property] + rpcs.append( + ('properties', + ( + self.image_collection + .reduceColumns(ee.Reducer.toList().repeat(len(columns)), columns) + .get('list')) + ) + ) + + info = ee.List([rpc for _, rpc in rpcs]).getInfo() + + return dict(zip((name for name, _ in rpcs), info)) + + @property + def image_collection_properties(self) -> tuple[list[str], list[str]]: + system_ids, primary_coord = self.get_info['properties'] + return (system_ids, primary_coord) + + @property + def image_ids(self) -> list[str]: + image_ids, _ = self.image_collection_properties + return image_ids + def _assign_index_chunks( self, input_chunk_store: dict[Any, Any] ) -> dict[Any, Any]: @@ -294,23 +335,17 @@ def get_attrs(self) -> utils.Frozen[Any, Any]: def _get_primary_coordinates(self) -> list[Any]: """Gets the primary dimension coordinate values from an ImageCollection.""" - coods_list = ( - self.image_collection - .reduceColumns(ee.Reducer.toList(), [self.primary_dim_property]) - .get('list') - ) + _, primary_coords = self.image_collection_properties - coords = coods_list.getInfo() - - if not coords: + if not primary_coords: raise ValueError( f'No {self.primary_dim_property!r} values found in the' " 'ImageCollection'" ) if self.primary_dim_property in ['system:time_start', 'system:time_end']: # Convert elements in primary_dim_list to np.datetime64 - coords = [np.datetime64(time, 'ms') for time in coords] - return coords + primary_coords = [np.datetime64(time, 'ms') for time in primary_coords] + return primary_coords def get_variables(self) -> utils.Frozen[str, xarray.Variable]: vars_ = [(name, self.open_store_variable(name)) for name in self._bands()] @@ -390,9 +425,7 @@ def _parse_dtype(data_type: types.DataType): return np.dtype(dt) -def geometry_to_bounds(geom: ee.Geometry) -> types.Bounds: - """Finds the CRS bounds from a ee.Geometry polygon.""" - bounds = geom.bounds().getInfo() +def _ee_bounds_to_bounds(bounds: ee.Bounds) -> types.Bounds: coords = np.array(bounds['coordinates'], dtype=np.float32)[0] x_min, y_min, x_max, y_max = ( min(coords[:, 0]), @@ -403,6 +436,12 @@ def geometry_to_bounds(geom: ee.Geometry) -> types.Bounds: return x_min, y_min, x_max, y_max +def geometry_to_bounds(geom: ee.Geometry) -> types.Bounds: + """Finds the CRS bounds from a ee.Geometry polygon.""" + bounds = geom.bounds().getInfo() + return _ee_bounds_to_bounds(bounds) + + class _GetComputedPixels: """Wrapper around `ee.data.computePixels()` to make retries simple.""" @@ -763,7 +802,7 @@ def open_dataset( upon opening. By default, data is opened with `EPSG:4326'. scale (optional): The scale in the `crs` or `projection`'s units of measure -- either meters or degrees. This defines the scale that all - data is represented in upon opening. By default, the scale is 1° when + data is represented in upon opening. By default, the scale is 1° when the CRS is in degrees or 10,000 when in meters. projection (optional): Specify an `ee.Projection` object to define the `scale` and `crs` (or other coordinate reference system) with which to diff --git a/xee/micro_benchmarks.py b/xee/micro_benchmarks.py index 92d5d56..464c2eb 100644 --- a/xee/micro_benchmarks.py +++ b/xee/micro_benchmarks.py @@ -18,6 +18,7 @@ to run. """ +import cProfile import os import tempfile import timeit @@ -32,6 +33,7 @@ REPEAT = 10 LOOPS = 1 +PROFILE = False def init_ee_for_tests(): @@ -74,6 +76,8 @@ def main(_: list[str]) -> None: init_ee_for_tests() print(f'[{REPEAT} time(s) with {LOOPS} loop(s) each.]') for fn in ['open_dataset()', 'open_and_chunk()', 'open_and_write()']: + if PROFILE: + cProfile.run(fn) timer = timeit.Timer(fn, globals=globals()) res = timer.repeat(REPEAT, number=LOOPS) avg, std, best, worst = np.mean(res), np.std(res), np.min(res), np.max(res)