From 5dd0f68af4c44aac3c65c1b2b15f0a9bbf65a37c Mon Sep 17 00:00:00 2001 From: David Bitner Date: Mon, 14 Nov 2022 19:18:10 -0600 Subject: [PATCH] test for using iterators more --- tifeatures/dbmodel.py | 22 ++-- tifeatures/dependencies.py | 24 ++++- tifeatures/factory.py | 197 +++++++++++++++++----------------- tifeatures/layer.py | 212 +++++++++++++++++++------------------ 4 files changed, 238 insertions(+), 217 deletions(-) diff --git a/tifeatures/dbmodel.py b/tifeatures/dbmodel.py index b2082c2..910879c 100644 --- a/tifeatures/dbmodel.py +++ b/tifeatures/dbmodel.py @@ -117,21 +117,13 @@ def id_column_info(self) -> Column: # type: ignore def columns(self, properties: Optional[List[str]] = None) -> List[str]: """Return table columns optionally filtered to only include columns from properties.""" - cols = [c.name for c in self.properties] - if properties is not None: - if self.id_column and self.id_column not in properties: - properties.append(self.id_column) - - geom_col = self.get_geometry_column() - if geom_col: - properties.append(geom_col.name) - - cols = [col for col in cols if col in properties] - - if len(cols) < 1: - raise TypeError("No columns selected") - - return cols + nongeo = [ + c.name for c in self.properties if c.type not in ["geometry", "geography"] + ] + if properties is None: + return nongeo + else: + return [c for c in nongeo if c in properties] def get_column(self, property_name: str) -> Optional[Column]: """Return column info.""" diff --git a/tifeatures/dependencies.py b/tifeatures/dependencies.py index 1d60807..42a022e 100644 --- a/tifeatures/dependencies.py +++ b/tifeatures/dependencies.py @@ -7,12 +7,12 @@ from pygeofilter.parsers.cql2_json import parse as cql2_json_parser from pygeofilter.parsers.cql2_text import parse as cql2_text_parser -from tifeatures.errors import InvalidBBox -from tifeatures.layer import CollectionLayer +from tifeatures.dbmodel import GeometryColumn +from tifeatures.errors import InvalidBBox, 
InvalidGeometryColumnName from tifeatures.layer import Table as TableLayer from tifeatures.resources import enums -from fastapi import HTTPException, Path, Query +from fastapi import Depends, HTTPException, Path, Query from starlette.requests import Request @@ -20,7 +20,7 @@ def CollectionParams( request: Request, collectionId: str = Path(..., description="Collection identifier"), -) -> CollectionLayer: +) -> TableLayer: """Return Layer Object.""" # Check function_catalog function_catalog = getattr(request.app.state, "tifeatures_function_catalog", {}) @@ -251,3 +251,19 @@ def sortby_query( ): """Sortby dependency.""" return sortby + + +def geom_col( + collection: TableLayer = Depends(CollectionParams), + geom_column: Optional[str] = Query( + None, + description="Select geometry column.", + alias="geom-column", + ), +) -> GeometryColumn: + """Geometry Column Dependency.""" + geom_col = collection.get_geometry_column(geom_column) + + if geom_column and geom_column.lower() != "none" and not geom_col: + raise InvalidGeometryColumnName(f"Invalid Geometry Column: {geom_column}.") + return geom_col diff --git a/tifeatures/factory.py b/tifeatures/factory.py index e9b391c..f3fb5b4 100644 --- a/tifeatures/factory.py +++ b/tifeatures/factory.py @@ -3,12 +3,14 @@ import csv import json from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional +from typing import Any, Callable, Dict, List, Optional import jinja2 +import orjson from pygeofilter.ast import AstType from tifeatures import model +from tifeatures.dbmodel import GeometryColumn from tifeatures.dependencies import ( CollectionParams, ItemOutputType, @@ -18,6 +20,7 @@ bbox_query, datetime_query, filter_query, + geom_col, ids_query, properties_query, sortby_query, @@ -45,7 +48,7 @@ ) -def create_csv_rows(data: Iterable[Dict]) -> Generator[str, None, None]: +async def create_csv_rows(data): """Creates an iterator that returns lines of csv from an iterable of 
dicts.""" class DummyWriter: @@ -56,7 +59,7 @@ def write(self, line: str): return line # Get the first row and construct the column names - row = next(data) # type: ignore + row = await data.__anext__() # type: ignore fieldnames = row.keys() writer = csv.DictWriter(DummyWriter(), fieldnames=fieldnames) @@ -67,10 +70,42 @@ def write(self, line: str): yield writer.writerow(row) # Write all remaining rows - for row in data: + async for row in data: yield writer.writerow(row) +def create_geojson_feature(item: Dict, geometry_column: Optional[str]) -> Dict: + """Creates an iterator that returns geojson features from an iterable of dicts.""" + geom = item.pop("tifeatures_geom") + if geom: + geomout = geom + else: + geomout = None + id = item.pop("tifeatures_id") + return { + "type": "Feature", + "id": id, + "geometry": geomout, + "properties": item, + } + + +def create_wkt_feature(item: Dict, geometry_column: Optional[str]) -> Dict: + """Creates an iterator that returns geojson features from an iterable of dicts.""" + geom = item.pop("tifeatures_geom") + if geom: + geomout = geom + else: + geomout = None + id = item.pop("tifeatures_id") + return { + "type": "Feature", + "id": id, + "geometry": geomout, + "properties": item, + } + + @dataclass class Endpoints: """Endpoints Factory.""" @@ -543,7 +578,7 @@ def queryables( @self.router.get( "/collections/{collectionId}/items", - response_model=model.Items, + # response_model=model.Items, response_model_exclude_none=True, response_class=GeoJSONResponse, responses={ @@ -568,11 +603,7 @@ async def items( properties: Optional[List[str]] = Depends(properties_query), cql_filter: Optional[AstType] = Depends(filter_query), sortby: Optional[str] = Depends(sortby_query), - geom_column: Optional[str] = Query( - None, - description="Select geometry column.", - alias="geom-column", - ), + geom_column: Optional[GeometryColumn] = Depends(geom_col), datetime_column: Optional[str] = Query( None, description="Select datetime column.", @@ 
-625,53 +656,48 @@ async def items( if key.lower() not in exclude and key.lower() in table_property ] - items, matched_items = await collection.features( - request.app.state.pool, - ids_filter=ids_filter, - bbox_filter=bbox_filter, - datetime_filter=datetime_filter, - properties_filter=properties_filter, - cql_filter=cql_filter, + if geom_column: + geom_col_name = geom_column.name + else: + geom_col_name = None + + _from = collection._from() + _where = collection._where( + ids=ids_filter, + datetime=datetime_filter, + bbox=bbox_filter, + properties=properties_filter, + cql=cql_filter, + geom=geom_col_name, + dt=datetime_filter, + ) + + features = collection._features( + pool=request.app.state.pool, + _from=_from, + _where=_where, sortby=sortby, properties=properties, limit=limit, offset=offset, - geom=geom_column, - dt=datetime_column, + geometry_column=geom_column, bbox_only=bbox_only, simplify=simplify, ) - if output_type in ( + if output_type in [ MediaType.csv, MediaType.json, MediaType.ndjson, - ): - if items and items[0].geometry is not None: - rows = ( - { - "collectionId": collection.id, - "itemId": f.id, - **f.properties, - "geometry": f.geometry.wkt, - } - for f in items - ) - - else: - rows = ( - { - "collectionId": collection.id, - "itemId": f.id, - **f.properties, - } - for f in items - ) - + ]: + items = ( + create_wkt_feature(feature, geom_col_name) + async for feature in features + ) # CSV Response if output_type == MediaType.csv: return StreamingResponse( - create_csv_rows(rows), + create_csv_rows(items), media_type=MediaType.csv, headers={ "Content-Disposition": "attachment;filename=items.csv" @@ -680,18 +706,35 @@ async def items( # JSON Response if output_type == MediaType.json: - return JSONResponse([row for row in rows]) + return JSONResponse([item async for item in items]) # NDJSON Response if output_type == MediaType.ndjson: return StreamingResponse( - (json.dumps(row) + "\n" for row in rows), + (orjson.dumps(item) + b"\n" async for item in 
items), media_type=MediaType.ndjson, headers={ "Content-Disposition": "attachment;filename=items.ndjson" }, ) + matched_items = await collection._features_count( + pool=request.app.state.pool, _from=_from, _where=_where + ) + items = ( + create_geojson_feature(feature, geom_col_name) + async for feature in features + ) + + if output_type == MediaType.geojsonseq: + return StreamingResponse( + (orjson.dumps(item) async for item in items), + media_type=MediaType.geojsonseq, + headers={ + "Content-Disposition": "attachment;filename=items.geojson" + }, + ) + qs = "?" + str(request.query_params) if request.query_params else "" links = [ model.Link( @@ -711,7 +754,9 @@ async def items( ), ] + items = [item async for item in items] items_returned = len(items) + print(f"items_returned {items_returned}") if (matched_items - items_returned) > offset: next_offset = offset + items_returned @@ -753,64 +798,26 @@ async def items( ), ) - data = model.Items( - id=collection.id, - title=collection.title or collection.id, - description=collection.description or collection.title or collection.id, - numberMatched=matched_items, - numberReturned=items_returned, - links=links, - features=[ - model.Item( - **{ - **feature.dict(), - "links": [ - model.Link( - title="Collection", - href=self.url_for( - request, - "collection", - collectionId=collection.id, - ), - rel="collection", - type=MediaType.json, - ), - model.Link( - title="Item", - href=self.url_for( - request, - "item", - collectionId=collection.id, - itemId=feature.id, - ), - rel="item", - type=MediaType.json, - ), - ], - } - ) - for feature in items - ], - ) + data = { + "id": collection.id, + "title": collection.title or collection.id, + "description": collection.description + or collection.title + or collection.id, + "numberMatched": matched_items, + "numberReturned": items_returned, + "links": [link.dict() for link in links], + "features": items, + } # HTML Response if output_type == MediaType.html: return 
self._create_html_response( request, - data.json(exclude_none=True), + orjson.dumps(data), template_name="items", ) - # GeoJSONSeq Response - elif output_type == MediaType.geojsonseq: - return StreamingResponse( - data.json_seq(exclude_none=True), - media_type=MediaType.geojsonseq, - headers={ - "Content-Disposition": "attachment;filename=items.geojson" - }, - ) - # Default to GeoJSON Response return data diff --git a/tifeatures/layer.py b/tifeatures/layer.py index 5f47a4e..a53f87c 100644 --- a/tifeatures/layer.py +++ b/tifeatures/layer.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Any, ClassVar, Dict, List, Optional, Tuple -from buildpg import asyncpg, clauses +from buildpg import RawDangerous, asyncpg, clauses from buildpg import funcs as pg_funcs from buildpg import logic, render from ciso8601 import parse_rfc3339 @@ -24,6 +24,7 @@ ) from tifeatures.filter.evaluate import to_filter from tifeatures.filter.filters import bbox_to_wkt +from tifeatures.resources.enums import MediaType # Links to geojson schema geojson_schema = { @@ -120,8 +121,45 @@ def bounds_default(cls, values): return values - def _select(self, properties: Optional[List[str]]): - return clauses.Select(self.columns(properties)) + def _select( + self, + properties: Optional[List[str]], + geometry_column: Optional[GeometryColumn], + bbox_only: Optional[bool], + simplify: Optional[float], + media_type: MediaType = MediaType.geojson, + ): + columns = self.columns(properties) + if columns: + sel = clauses.Select(columns) + RawDangerous(",") + else: + sel = RawDangerous("SELECT ") + + if self.id_column: + sel = sel + logic.V(self.id_column) + RawDangerous(" AS tifeatures_id, ") + else: + sel = sel + RawDangerous(" ROW_NUMBER () OVER () AS tifeatures_id, ") + + geom = self._geom(geometry_column, bbox_only, simplify) + if media_type in [MediaType.geojson, MediaType.geojsonseq]: + if geom: + sel = ( + sel + + pg_funcs.cast(logic.Func("st_asgeojson", geom), "json") + + 
RawDangerous(" AS tifeatures_geom ") + ) + else: + sel = sel + RawDangerous(" NULL::json AS tifeatures_geom ") + else: + if geom: + sel = ( + sel + + logic.Func("st_asgeojson", geom) + + RawDangerous(" AS tifeatures_geom ") + ) + else: + sel = sel + RawDangerous(" NULL::text AS tifeatures_geom ") + return sel def _select_count(self): return clauses.Select(pg_funcs.count("*")) @@ -136,7 +174,7 @@ def _geom( simplify: Optional[float], ): if geometry_column is None: - return pg_funcs.cast(None, "jsonb") + return pg_funcs.cast(None, "geometry") g = logic.V(geometry_column.name) g = pg_funcs.cast(g, "geometry") @@ -153,9 +191,7 @@ def _geom( simplify, ) - g = logic.Func("ST_AsGeoJson", g) - - return pg_funcs.cast(g, "jsonb") + return g def _where( self, @@ -328,31 +364,47 @@ def _features_query( + clauses.Offset(offset or 0) ) - def _features_count_query( + async def _features_count(self, pool: asyncpg.BuildPgPool, _from, _where): + """Build features COUNT query.""" + c = self._select_count() + _from + _where + q, p = render(":c", c=c) + async with pool.acquire() as conn: + count = await conn.fetchval(q, *p) + return count + + async def _features( self, - *, - ids_filter: Optional[List[str]] = None, - bbox_filter: Optional[List[float]] = None, - datetime_filter: Optional[List[str]] = None, - properties_filter: Optional[List[Tuple[str, str]]] = None, - cql_filter: Optional[AstType] = None, - geom: str = None, - dt: str = None, + pool: asyncpg.BuildPgPool, + _from, + _where, + sortby: Optional[str] = None, + properties: Optional[List[str]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + geometry_column=None, + bbox_only=None, + simplify=None, ): """Build features COUNT query.""" - return ( - self._select_count() - + self._from() - + self._where( - ids=ids_filter, - datetime=datetime_filter, - bbox=bbox_filter, - properties=properties_filter, - cql=cql_filter, - geom=geom, - dt=dt, + c = ( + self._select( + properties=properties, + 
geometry_column=geometry_column, + bbox_only=bbox_only, + simplify=simplify, ) + + _from + + _where + + self._sortby(sortby) + + clauses.Limit(limit or 10) + + clauses.Offset(offset or 0) ) + q, p = render(":c", c=c) + print(q, p) + async with pool.acquire() as conn: + for r in await conn.fetch(q, *p): + properties = dict(r) + yield properties async def query( self, @@ -373,81 +425,35 @@ async def query( simplify: Optional[float] = None, ) -> Tuple[FeatureCollection, int]: """Build and run Pg query.""" - if geom and geom.lower() != "none" and not self.get_geometry_column(geom): + geom_col = self.get_geometry_column(geom) + + if geom and geom.lower() != "none" and not geom_col: raise InvalidGeometryColumnName(f"Invalid Geometry Column: {geom}.") - sql_query = """ - WITH - features AS ( - :features_q - ), - total_count AS ( - :count_q - ) - SELECT json_build_object( - 'type', 'FeatureCollection', - 'features', - ( - SELECT - json_agg( - json_build_object( - 'type', 'Feature', - 'id', :id_column, - 'geometry', :geometry_q, - 'properties', to_jsonb( features.* ) - :geom_columns::text[] - ) - ) - FROM features - ), - 'total_count', - ( - SELECT count FROM total_count - ) - ) - ; - """ - id_column = logic.V(self.id_column) or pg_funcs.cast(None, "text") - geom_columns = [g.name for g in self.geometry_columns] - q, p = render( - sql_query, - features_q=self._features_query( - ids_filter=ids_filter, - bbox_filter=bbox_filter, - datetime_filter=datetime_filter, - properties_filter=properties_filter, - cql_filter=cql_filter, - sortby=sortby, - properties=properties, - geom=geom, - dt=dt, - limit=limit, - offset=offset, - ), - count_q=self._features_count_query( - ids_filter=ids_filter, - bbox_filter=bbox_filter, - datetime_filter=datetime_filter, - properties_filter=properties_filter, - cql_filter=cql_filter, - geom=geom, - dt=dt, - ), - id_column=id_column, - geometry_q=self._geom( - geometry_column=self.get_geometry_column(geom), - bbox_only=bbox_only, - simplify=simplify, 
- ), - geom_columns=geom_columns, - ) - async with pool.acquire() as conn: - items = await conn.fetchval(q, *p) + if geom_col: + geom_col_name = geom_col.name + else: + geom_col_name = None - return ( - FeatureCollection(features=items.get("features") or []), - items["total_count"], + _from = self._from() + _where = self._where() + + # Get count + + pgfeatures = self._features( + pool=pool, + _from=_from, + _where=_where, + geom_col_name=geom_col_name, + id_column=self.id_column, + sortby=sortby, + properties=properties, + limit=limit, + offset=offset, ) + return pgfeatures + async def features( self, pool: asyncpg.BuildPgPool, @@ -460,7 +466,7 @@ async def features( limit: Optional[int] = None, offset: Optional[int] = None, **kwargs: Any, - ) -> Tuple[FeatureCollection, int]: + ): """Return a FeatureCollection and the number of matched items.""" return await self.query( pool=pool, @@ -481,7 +487,7 @@ async def feature( item_id: str, properties: Optional[List[str]] = None, **kwargs: Any, - ) -> Optional[Feature]: + ): """Return a Feature.""" feature_collection, _ = await self.query( pool=pool, @@ -489,10 +495,10 @@ async def feature( properties=properties, **kwargs, ) - if len(feature_collection): - return feature_collection.features[0] - - return None + try: + return await feature_collection.__anext__() + except: + return None @property def queryables(self) -> Dict: