diff --git a/python/valis/db/queries.py b/python/valis/db/queries.py index 1209a33..fe74cc0 100644 --- a/python/valis/db/queries.py +++ b/python/valis/db/queries.py @@ -6,6 +6,7 @@ import itertools import packaging +import uuid from typing import Sequence, Union, Generator import astropy.units as u @@ -57,17 +58,21 @@ def append_pipes(query: peewee.ModelSelect, table: str = 'stacked', if table not in {'stacked', 'flat'}: raise ValueError('table must be either "stacked" or "flat"') + # Run initial query as a temporary table. + temp = create_temporary_table(query, indices=['sdss_id']) + model = vizdb.SDSSidStacked if table == 'stacked' else vizdb.SDSSidFlat - qq = query.select_extend(vizdb.SDSSidToPipes.in_boss, - vizdb.SDSSidToPipes.in_apogee, - vizdb.SDSSidToPipes.in_bvs, - vizdb.SDSSidToPipes.in_astra, - vizdb.SDSSidToPipes.has_been_observed, - vizdb.SDSSidToPipes.release, - vizdb.SDSSidToPipes.obs, - vizdb.SDSSidToPipes.mjd).\ - join(vizdb.SDSSidToPipes, on=(model.sdss_id == vizdb.SDSSidToPipes.sdss_id), - attr='pipes').distinct(vizdb.SDSSidToPipes.sdss_id) + qq = temp.select(temp.__star__, + vizdb.SDSSidToPipes.in_boss, + vizdb.SDSSidToPipes.in_apogee, + vizdb.SDSSidToPipes.in_bvs, + vizdb.SDSSidToPipes.in_astra, + vizdb.SDSSidToPipes.has_been_observed, + vizdb.SDSSidToPipes.release, + vizdb.SDSSidToPipes.obs, + vizdb.SDSSidToPipes.mjd).\ + join(vizdb.SDSSidToPipes, on=(temp.c.sdss_id == vizdb.SDSSidToPipes.sdss_id)).\ + distinct(temp.c.sdss_id) if observed: qq = qq.where(vizdb.SDSSidToPipes.has_been_observed == observed) @@ -264,7 +269,8 @@ def carton_program_map(key: str = 'program') -> dict: def carton_program_search(name: str, name_type: str, - query: peewee.ModelSelect | None = None) -> peewee.ModelSelect: + query: peewee.ModelSelect | None = None, + limit: int | None = None) -> peewee.ModelSelect: """ Perform a search on either carton or program Parameters @@ -276,6 +282,8 @@ def carton_program_search(name: str, query : ModelSelect An initial query to extend. If ``None``, a new query with all the unique ``sdss_id``s is created. + limit : int + Limit the number of results returned. Returns ------- @@ -284,7 +292,7 @@ def carton_program_search(name: str, """ if query is None: - query = vizdb.SDSSidStacked.select(vizdb.SDSSidStacked).distinct() + query = vizdb.SDSSidStacked.select(vizdb.SDSSidStacked) query = (query.join( vizdb.SDSSidFlat, @@ -295,6 +303,9 @@ def carton_program_search(name: str, .join(targetdb.Carton) .where(getattr(targetdb.Carton, name_type) == name)) + if limit: + query = query.limit(limit) + return query def get_targets_obs(release: str, obs: str, spectrograph: str) -> peewee.ModelSelect: @@ -931,3 +942,20 @@ def get_target_by_altid(id: str | int, idtype: str = None) -> peewee.ModelSelect # get the sdss_id metadata info return get_targets_by_sdss_id(res.sdss_id) + + +def create_temporary_table(query: peewee.ModelSelect, indices: list[str] | None = None) -> peewee.Table: + """Create a temporary table from a query.""" + + table_name = uuid.uuid4().hex[0:8] + + table = peewee.Table(table_name) + table.bind(vizdb.database) + + query.create_table(table_name, temporary=True) + + if indices: + for index in indices: + vizdb.database.execute_sql(f'CREATE INDEX ON "{table_name}" ({index})') + + return table diff --git a/python/valis/routes/query.py b/python/valis/routes/query.py index 5a7d25c..9c2e24a 100644 --- a/python/valis/routes/query.py +++ b/python/valis/routes/query.py @@ -203,12 +203,17 @@ async def carton_program(self, Query(enum=['carton', 'program'], description='Specify search on carton or program', example='carton')] = 'carton', - observed: Annotated[bool, Query(description='Flag to only include targets that have been observed', example=True)] = True): + observed: Annotated[bool, Query(description='Flag to only include targets that have been observed', example=True)] = True, + limit: Annotated[int | None, Query(description='Limit the number of returned targets', example=100)] = None): """ Perform a search on carton or program """ with database.atomic(): - database.execute_sql('SET LOCAL enable_seqscan=false;') - query = carton_program_search(name, name_type) + if limit is False: + # This tweak seems to do more harm than good when limit is passed. + database.execute_sql('SET LOCAL enable_seqscan=false;') + + query = carton_program_search(name, name_type, limit=limit) query = append_pipes(query, observed=observed) + return query.dicts().iterator() @router.get('/obs', summary='Return targets with spectrum at observatory',