Skip to content

Commit

Permalink
Merge branch 'main' into albireox-issue-69
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Dec 9, 2024
2 parents f372c69 + dbf2ad5 commit 5daa219
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 18 deletions.
86 changes: 73 additions & 13 deletions python/valis/db/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

# all resuable queries go here

from contextlib import contextmanager
import itertools
import packaging
import uuid
from typing import Sequence, Union, Generator

import astropy.units as u
import deepmerge
import peewee
from peewee import Case
from astropy.coordinates import SkyCoord
from sdssdb.peewee.sdss5db import apogee_drpdb as apo
from sdssdb.peewee.sdss5db import boss_drp as boss
Expand All @@ -26,7 +29,7 @@


def append_pipes(query: peewee.ModelSelect, table: str = 'stacked',
observed: bool = True) -> peewee.ModelSelect:
observed: bool = True, release: str = None) -> peewee.ModelSelect:
""" Joins a query to the SDSSidToPipes table
Joines an existing query to the SDSSidToPipes table and returns
Expand Down Expand Up @@ -57,21 +60,45 @@ def append_pipes(query: peewee.ModelSelect, table: str = 'stacked',
if table not in {'stacked', 'flat'}:
raise ValueError('table must be either "stacked" or "flat"')

model = vizdb.SDSSidStacked if table == 'stacked' else vizdb.SDSSidFlat
qq = query.select_extend(vizdb.SDSSidToPipes.in_boss,
vizdb.SDSSidToPipes.in_apogee,
vizdb.SDSSidToPipes.in_bvs,
vizdb.SDSSidToPipes.in_astra,
vizdb.SDSSidToPipes.has_been_observed,
vizdb.SDSSidToPipes.release,
vizdb.SDSSidToPipes.obs,
vizdb.SDSSidToPipes.mjd).\
join(vizdb.SDSSidToPipes, on=(model.sdss_id == vizdb.SDSSidToPipes.sdss_id),
attr='pipes').distinct(vizdb.SDSSidToPipes.sdss_id)
# Run initial query as a temporary table.
temp = create_temporary_table(query, indices=['sdss_id'])

qq = temp.select(temp.__star__,
vizdb.SDSSidToPipes.in_boss,
vizdb.SDSSidToPipes.in_apogee,
vizdb.SDSSidToPipes.in_bvs,
vizdb.SDSSidToPipes.in_astra,
vizdb.SDSSidToPipes.has_been_observed,
vizdb.SDSSidToPipes.release,
vizdb.SDSSidToPipes.obs,
vizdb.SDSSidToPipes.mjd).\
join(vizdb.SDSSidToPipes, on=(temp.c.sdss_id == vizdb.SDSSidToPipes.sdss_id)).\
distinct(temp.c.sdss_id)

if observed:
qq = qq.where(vizdb.SDSSidToPipes.has_been_observed == observed)

if release:
# get the release
rel = vizdb.Releases.select().where(vizdb.Releases.release==release).first()

# if a release has no cutoff info, then force the cutoff to 0, query will return nothing
# to fix this we want mjd cutoffs by survey for all older releases
if not rel.mjd_cutoff_apo and not rel.mjd_cutoff_lco:
rel.mjd_cutoff_apo = 0
rel.mjd_cutoff_lco = 0

# create the mjd cutoff condition
qq = qq.where(vizdb.SDSSidToPipes.mjd <= Case(
vizdb.SDSSidToPipes.obs,
(
('apo', rel.mjd_cutoff_apo),
('lco', rel.mjd_cutoff_lco)
),
None
)
)

return qq


Expand Down Expand Up @@ -264,7 +291,8 @@ def carton_program_map(key: str = 'program') -> dict:

def carton_program_search(name: str,
name_type: str,
query: peewee.ModelSelect | None = None) -> peewee.ModelSelect:
query: peewee.ModelSelect | None = None,
limit: int | None = None) -> peewee.ModelSelect:
""" Perform a search on either carton or program
Parameters
Expand All @@ -276,6 +304,8 @@ def carton_program_search(name: str,
query : ModelSelect
An initial query to extend. If ``None``, a new query with all the unique
``sdss_id``s is created.
limit : int
Limit the number of results returned.
Returns
-------
Expand All @@ -286,6 +316,13 @@ def carton_program_search(name: str,
if query is None:
query = vizdb.SDSSidStacked.select(vizdb.SDSSidStacked).distinct()

# NOTE: These setting seem to help when querying some cartons or programs, mainly
# those with small number of targets, and in some cases with these the query
# actually applies the LIMIT more efficiently, but it's not a perfect solution.
vizdb.database.execute_sql('SET enable_gathermerge = off;')
vizdb.database.execute_sql('SET parallel_tuple_cost = 100;')
vizdb.database.execute_sql('SET enable_bitmapscan = off;')

query = (query.join(
vizdb.SDSSidFlat,
on=(vizdb.SDSSidFlat.sdss_id == vizdb.SDSSidStacked.sdss_id))
Expand All @@ -295,6 +332,9 @@ def carton_program_search(name: str,
.join(targetdb.Carton)
.where(getattr(targetdb.Carton, name_type) == name))

if limit:
query = query.limit(limit)

return query

def get_targets_obs(release: str, obs: str, spectrograph: str) -> peewee.ModelSelect:
Expand Down Expand Up @@ -931,3 +971,23 @@ def get_target_by_altid(id: str | int, idtype: str = None) -> peewee.ModelSelect

# get the sdss_id metadata info
return get_targets_by_sdss_id(res.sdss_id)


def create_temporary_table(query: peewee.ModelSelect,
indices: list[str] | None = None) -> Generator[None, None, peewee.Table]:
"""Create a temporary table from a query."""

table_name = uuid.uuid4().hex[0:8]

table = peewee.Table(table_name)
table.bind(vizdb.database)

query.create_table(table_name, temporary=True)

if indices:
for index in indices:
vizdb.database.execute_sql(f'CREATE INDEX ON "{table_name}" ({index})')

vizdb.database.execute_sql(f'ANALYZE "{table_name}"')

return table
23 changes: 18 additions & 5 deletions python/valis/routes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SearchModel(BaseModel):
program: Optional[str] = Field(None, description='The program name', example='bhm_rm')
carton: Optional[str] = Field(None, description='The carton name', example='bhm_rm_core')
observed: Optional[bool] = Field(True, description='Flag to only include targets that have been observed', example=True)
limit: Optional[int] = Field(None, description='Limit the number of returned targets', example=100)

class MainResponse(SDSSModel):
""" Combined model from all individual query models """
Expand Down Expand Up @@ -106,9 +107,16 @@ async def main_search(self, body: SearchModel):
query = carton_program_search(body.program or body.carton,
'program' if body.program else 'carton',
query=query)

# DANGER!!! This limit applies *before* the append_pipes call. If the
# append_pipes call includes observed=True we may have limited things in
# such a way that only unobserved or very few targets are returned.
if body.limit:
query = query.limit(body.limit)

# append query to pipes
if query:
query = append_pipes(query, observed=body.observed)
query = append_pipes(query, observed=body.observed, release=self.release)

# query iterator
res = query.dicts().iterator() if query else []
Expand All @@ -127,7 +135,7 @@ async def cone_search(self,
""" Perform a cone search """

res = cone_search(ra, dec, radius, units=units)
r = append_pipes(res, observed=observed)
r = append_pipes(res, observed=observed, release=self.release)
# return sorted by distance
# doing this here due to the append_pipes distinct
return sorted(r.dicts().iterator(), key=lambda x: x['distance'])
Expand Down Expand Up @@ -211,12 +219,17 @@ async def carton_program(self,
Query(enum=['carton', 'program'],
description='Specify search on carton or program',
example='carton')] = 'carton',
observed: Annotated[bool, Query(description='Flag to only include targets that have been observed', example=True)] = True):
observed: Annotated[bool, Query(description='Flag to only include targets that have been observed', example=True)] = True,
limit: Annotated[int | None, Query(description='Limit the number of returned targets', example=100)] = None):
""" Perform a search on carton or program """
with database.atomic():
database.execute_sql('SET LOCAL enable_seqscan=false;')
query = carton_program_search(name, name_type)
if limit is False:
# This tweak seems to do more harm than good when limit is passed.
database.execute_sql('SET LOCAL enable_seqscan=false;')

query = carton_program_search(name, name_type, limit=limit)
query = append_pipes(query, observed=observed)

return query.dicts().iterator()

@router.get('/obs', summary='Return targets with spectrum at observatory',
Expand Down

0 comments on commit 5daa219

Please sign in to comment.