diff --git a/locustfile.py b/locustfile.py new file mode 100644 index 0000000..a2cac13 --- /dev/null +++ b/locustfile.py @@ -0,0 +1,128 @@ +import random +from locust import HttpUser, task, between + + +# A set of files to run with [locust.io](https://locust.io/) for performance testing the app. +# pip install locust +# then run "locust -f locustfile.py" and open http://localhost:8089/ in your browser + + +class FastAPIUser(HttpUser): + release = 'IPL3' + sdssids = [23326, 54392544, 57651832, 57832526, 61731453, 85995134, 56055457] + wait_time = between(1, 5) # Simulate user think time between requests + + @task + def query_main(self): + url = "/query/main" + headers = {'Content-Type': 'application/json'} + params = {'release': self.release} + payload1 = { + 'ra': random.uniform(0, 360), + 'dec': random.uniform(-90, 90), + 'radius': random.uniform(0.01, 0.2), + 'units': 'degree', + 'observed': True + } + payload2 = { + "id": random.choice(self.sdssids), + } + payload3 = { + 'ra': random.uniform(0, 360), + 'dec': random.uniform(-90, 90), + 'radius': random.uniform(0.01, 0.2), + 'units': 'degree', + 'observed': True, + 'program': 'bhm_rm', + 'carton': 'bhm_rm_core' + } + payload = random.choice([payload1, payload2, payload3]) + with self.client.post(url, headers=headers, params=params, json=payload, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"POST {url} failed: {response.text}") + + @task + def query_cone(self): + url = "/query/cone" + params = { + 'ra': random.uniform(0, 360), + 'dec': random.uniform(-90, 90), + 'radius': random.uniform(0.01, 0.5), + 'units': 'degree', + 'observed': random.choice([True, False]), + 'release': self.release + } + with self.client.get(url, params=params, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"GET {url} failed: {response.text}") + + @task + def query_carton(self): + url = '/query/carton-program' + params = { + 'name': 'manual_mwm_tess_ob', + 'name_type': 'carton', + 'observed': True, + 'release': self.release + } + with self.client.get(url, params=params, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"GET {url} failed: {response.text}") + + @task + def get_spectrum(self): + sdss_id = 23326 + url = f"/target/spectra/{sdss_id}" + params = { + 'product': 'specLite', + 'ext': 'BOSS/APO', + 'release': self.release + } + with self.client.get(url, params=params, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"GET {url} failed: {response.text}") + + @task + def get_catalogs(self): + sdss_id = random.choice(self.sdssids) + url = f"/target/catalogs/{sdss_id}" + params = {'release': self.release} + with self.client.get(url, params=params, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"GET {url} failed: {response.text}") + + @task + def get_parents(self): + catalog = 'gaia_dr3_source' + sdss_id = 129047350 + url = f"/target/parents/{catalog}/{sdss_id}" + params = { + 'catalogid': 63050396587194280, + 'release': self.release + } + with self.client.get(url, params=params, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"GET {url} failed: {response.text}") + + @task + def get_cartons(self): + sdss_id = random.choice(self.sdssids) + url = f"/target/cartons/{sdss_id}" + params = {'release': self.release} + with self.client.get(url, params=params, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"GET {url} failed: {response.text}") + + @task + def get_pipelines(self): + sdss_id = random.choice(self.sdssids) + url = f"/target/pipelines/{sdss_id}" + params = { + 'release': self.release + } + with self.client.get(url, params=params, catch_response=True) as response: + if response.status_code != 200: + response.failure(f"GET {url} failed: {response.text}") + +# if __name__ == "__main__": +# run_single_user(FastAPIUser) \ No newline at end of file diff --git a/python/valis/db/queries.py b/python/valis/db/queries.py index 7f19c9b..2275d06 100644 --- a/python/valis/db/queries.py +++ b/python/valis/db/queries.py @@ -4,13 +4,16 @@ # all resuable queries go here +from contextlib import contextmanager import itertools import packaging +import uuid from typing import Sequence, Union, Generator import astropy.units as u import deepmerge import peewee +from peewee import Case from astropy.coordinates import SkyCoord from sdssdb.peewee.sdss5db import apogee_drpdb as apo from sdssdb.peewee.sdss5db import boss_drp as boss @@ -26,7 +29,7 @@ def append_pipes(query: peewee.ModelSelect, table: str = 'stacked', - observed: bool = True) -> peewee.ModelSelect: + observed: bool = True, release: str = None) -> peewee.ModelSelect: """ Joins a query to the SDSSidToPipes table Joines an existing query to the SDSSidToPipes table and returns @@ -57,21 +60,45 @@ def append_pipes(query: peewee.ModelSelect, table: str = 'stacked', if table not in {'stacked', 'flat'}: raise ValueError('table must be either "stacked" or "flat"') - model = vizdb.SDSSidStacked if table == 'stacked' else vizdb.SDSSidFlat - qq = query.select_extend(vizdb.SDSSidToPipes.in_boss, - vizdb.SDSSidToPipes.in_apogee, - vizdb.SDSSidToPipes.in_bvs, - vizdb.SDSSidToPipes.in_astra, - vizdb.SDSSidToPipes.has_been_observed, - vizdb.SDSSidToPipes.release, - vizdb.SDSSidToPipes.obs, - vizdb.SDSSidToPipes.mjd).\ - join(vizdb.SDSSidToPipes, on=(model.sdss_id == vizdb.SDSSidToPipes.sdss_id), - attr='pipes').distinct(vizdb.SDSSidToPipes.sdss_id) + # Run initial query as a temporary table. + temp = create_temporary_table(query, indices=['sdss_id']) + + qq = temp.select(temp.__star__, + vizdb.SDSSidToPipes.in_boss, + vizdb.SDSSidToPipes.in_apogee, + vizdb.SDSSidToPipes.in_bvs, + vizdb.SDSSidToPipes.in_astra, + vizdb.SDSSidToPipes.has_been_observed, + vizdb.SDSSidToPipes.release, + vizdb.SDSSidToPipes.obs, + vizdb.SDSSidToPipes.mjd).\ + join(vizdb.SDSSidToPipes, on=(temp.c.sdss_id == vizdb.SDSSidToPipes.sdss_id)).\ + distinct(temp.c.sdss_id) if observed: qq = qq.where(vizdb.SDSSidToPipes.has_been_observed == observed) + if release: + # get the release + rel = vizdb.Releases.select().where(vizdb.Releases.release==release).first() + + # if a release has no cutoff info, then force the cutoff to 0, query will return nothing + # to fix this we want mjd cutoffs by survey for all older releases + if not rel.mjd_cutoff_apo and not rel.mjd_cutoff_lco: + rel.mjd_cutoff_apo = 0 + rel.mjd_cutoff_lco = 0 + + # create the mjd cutoff condition + qq = qq.where(vizdb.SDSSidToPipes.mjd <= Case( + vizdb.SDSSidToPipes.obs, + ( + ('apo', rel.mjd_cutoff_apo), + ('lco', rel.mjd_cutoff_lco) + ), + None + ) + ) + return qq @@ -264,7 +291,8 @@ def carton_program_map(key: str = 'program') -> dict: def carton_program_search(name: str, name_type: str, - query: peewee.ModelSelect | None = None) -> peewee.ModelSelect: + query: peewee.ModelSelect | None = None, + limit: int | None = None) -> peewee.ModelSelect: """ Perform a search on either carton or program Parameters @@ -276,6 +304,8 @@ def carton_program_search(name: str, query : ModelSelect An initial query to extend. If ``None``, a new query with all the unique ``sdss_id``s is created. + limit : int + Limit the number of results returned. Returns ------- @@ -286,6 +316,13 @@ def carton_program_search(name: str, if query is None: query = vizdb.SDSSidStacked.select(vizdb.SDSSidStacked).distinct() + # NOTE: These setting seem to help when querying some cartons or programs, mainly + # those with small number of targets, and in some cases with these the query + # actually applies the LIMIT more efficiently, but it's not a perfect solution. + vizdb.database.execute_sql('SET enable_gathermerge = off;') + vizdb.database.execute_sql('SET parallel_tuple_cost = 100;') + vizdb.database.execute_sql('SET enable_bitmapscan = off;') + query = (query.join( vizdb.SDSSidFlat, on=(vizdb.SDSSidFlat.sdss_id == vizdb.SDSSidStacked.sdss_id)) @@ -295,6 +332,9 @@ def carton_program_search(name: str, .join(targetdb.Carton) .where(getattr(targetdb.Carton, name_type) == name)) + if limit: + query = query.limit(limit) + return query def get_targets_obs(release: str, obs: str, spectrograph: str) -> peewee.ModelSelect: @@ -831,3 +871,123 @@ def starfields(model: peewee.ModelSelect) -> peewee.NodeList: pw_ver = peewee.__version__ oldver = packaging.version.parse(pw_ver) < packaging.version.parse('3.17.1') return model.star if oldver else model.__star__ + + +def get_sdssid_by_altid(id: str | int, idtype: str = None) -> peewee.ModelSelect: + """ Get an sdss_id by an alternative id + + This query attempts to identify a target sdss_id from an + alternative id, which can be a string or integer. It tries + to distinguish between the following formats: + + - a (e)BOSS plate-mjd-fiberid, e.g. "10235-58127-0020" + - a BOSS field-mjd-catalogid, e.g. "101077-59845-27021603187129892" + - an SDSS-IV APOGEE ID, e.g "2M23595980+1528407" + - an SDSS-V catalogid, e.g. 2702160318712989 + - a GAIA DR3 ID, e.g. 4110508934728363520 + + It queries either the boss_drp.boss_spectrum or astra.source + tables for the sdss_id. + + Parameters + ---------- + id : str | int + the input alternative id + idtype : str, optional + the type of integer id, by default None + + Returns + ------- + peewee.ModelSelect + the ORM query + """ + + # cast to str + if isinstance(id, int): + id = str(id) + + # temp for now; maybe we make a single "altid" db column somewhere + ndash = id.count('-') + final = id.rsplit('-', 1)[-1] + if ndash == 2 and len(final) <= 4 and final.isdigit() and int(final) <= 1000: + # boss/eboss plate-mjd-fiberid e.g '10235-58127-0020' + return + elif ndash == 2 and len(final) > 5: + # field-mjd-catalogid, e.g. '101077-59845-27021603187129892' + field, mjd, catalogid = id.split('-') + targ = boss.BossSpectrum.select(boss.BossSpectrum.sdss_id).\ + where(boss.BossSpectrum.catalogid == catalogid, + boss.BossSpectrum.mjd == mjd, boss.BossSpectrum.field == field) + elif ndash == 1: + # apogee south, e.g. '2M17282323-2415476' + targ = astra.Source.select(astra.Source.sdss_id).\ + where(astra.Source.sdss4_apogee_id.in_([id])) + elif ndash == 0 and not id.isdigit(): + # apogee obj id + targ = astra.Source.select(astra.Source.sdss_id).\ + where(astra.Source.sdss4_apogee_id.in_([id])) + elif ndash == 0 and id.isdigit(): + # single integer id + if idtype == 'catalogid': + # catalogid , e.g. 27021603187129892 + field = 'catalogid' + elif idtype == 'gaiaid': + # gaia dr3 id , e.g. 4110508934728363520 + field = 'gaia_dr3_source_id' + else: + field = 'catalogid' + + targ = astra.Source.select(astra.Source.sdss_id).\ + where(getattr(astra.Source, field).in_([id])) + + return targ + + +def get_target_by_altid(id: str | int, idtype: str = None) -> peewee.ModelSelect: + """ Get a target by an alternative id + + This retrieves the target info from vizdb.sdss_id_stacked, + given an alternative id. It first tries to identify the proper + sdss_id for the given altid, then it retrieves the basic target + info. See ``get_sdssid_by_altid`` for details on the altid formats. + + Parameters + ---------- + id : str | int + the input alternative id + idtype : str, optional + the type of integer id, by default None + + Returns + ------- + peewee.ModelSelect + the ORM query + """ + # get the sdss_id + targ = get_sdssid_by_altid(id, idtype=idtype) + res = targ.get_or_none() if targ else None + if not res: + return + + # get the sdss_id metadata info + return get_targets_by_sdss_id(res.sdss_id) + + +def create_temporary_table(query: peewee.ModelSelect, + indices: list[str] | None = None) -> Generator[None, None, peewee.Table]: + """Create a temporary table from a query.""" + + table_name = uuid.uuid4().hex[0:8] + + table = peewee.Table(table_name) + table.bind(vizdb.database) + + query.create_table(table_name, temporary=True) + + if indices: + for index in indices: + vizdb.database.execute_sql(f'CREATE INDEX ON "{table_name}" ({index})') + + vizdb.database.execute_sql(f'ANALYZE "{table_name}"') + + return table diff --git a/python/valis/io/spectra.py b/python/valis/io/spectra.py index f131c84..62123a3 100644 --- a/python/valis/io/spectra.py +++ b/python/valis/io/spectra.py @@ -11,6 +11,7 @@ from astropy.io import fits from astropy.nddata import InverseVariance from astropy.wcs import WCS +import numpy as np try: from specutils import Spectrum1D @@ -108,6 +109,12 @@ def extract_data(product: str, filepath: str, multispec: Union[int, str] = None) else: data[param] = hdulist[extension].data + # set dtype byteorder to the native + for key, val in data.items(): + if key == 'header': + continue + data[key] = val.byteswap().newbyteorder('=') + return data diff --git a/python/valis/routes/query.py b/python/valis/routes/query.py index aa63887..1e3b7bd 100644 --- a/python/valis/routes/query.py +++ b/python/valis/routes/query.py @@ -16,7 +16,8 @@ from valis.db.queries import (cone_search, append_pipes, carton_program_search, carton_program_list, carton_program_map, get_targets_by_sdss_id, get_targets_by_catalog_id, - get_targets_obs, get_paged_target_list_by_mapper) + get_targets_obs, get_paged_target_list_by_mapper, + get_target_by_altid) from sdssdb.peewee.sdss5db import database, catalogdb # convert string floats to proper floats @@ -37,9 +38,12 @@ class SearchModel(BaseModel): radius: Optional[Float] = Field(None, description='Search radius in specified units', example=0.02) units: Optional[SearchCoordUnits] = Field('degree', description='Units of search radius', example='degree') id: Optional[Union[int, str]] = Field(None, description='The SDSS identifier', example=23326) + altid: Optional[Union[int, str]] = Field(None, description='An alternative identifier', example=27021603187129892) + idtype: Optional[str] = Field(None, description='The type of integer id, for ambiguous ids', example="catalogid") program: Optional[str] = Field(None, description='The program name', example='bhm_rm') carton: Optional[str] = Field(None, description='The carton name', example='bhm_rm_core') observed: Optional[bool] = Field(True, description='Flag to only include targets that have been observed', example=True) + limit: Optional[int] = Field(None, description='Limit the number of returned targets', example=100) class MainResponse(SDSSModel): """ Combined model from all individual query models """ @@ -95,14 +99,25 @@ async def main_search(self, body: SearchModel): elif body.id: query = get_targets_by_sdss_id(body.id) + # build the altid query + elif body.altid: + query = get_target_by_altid(body.altid, body.idtype) + # build the program/carton query if body.program or body.carton: query = carton_program_search(body.program or body.carton, 'program' if body.program else 'carton', query=query) + + # DANGER!!! This limit applies *before* the append_pipes call. If the + # append_pipes call includes observed=True we may have limited things in + # such a way that only unobserved or very few targets are returned. + if body.limit: + query = query.limit(body.limit) + # append query to pipes if query: - query = append_pipes(query, observed=body.observed) + query = append_pipes(query, observed=body.observed, release=self.release) # query iterator res = query.dicts().iterator() if query else [] @@ -120,7 +135,7 @@ async def cone_search(self, """ Perform a cone search """ res = cone_search(ra, dec, radius, units=units) - r = append_pipes(res, observed=observed) + r = append_pipes(res, observed=observed, release=self.release) # return sorted by distance # doing this here due to the append_pipes distinct return sorted(r.dicts().iterator(), key=lambda x: x['distance']) @@ -198,12 +213,17 @@ async def carton_program(self, Query(enum=['carton', 'program'], description='Specify search on carton or program', example='carton')] = 'carton', - observed: Annotated[bool, Query(description='Flag to only include targets that have been observed', example=True)] = True): + observed: Annotated[bool, Query(description='Flag to only include targets that have been observed', example=True)] = True, + limit: Annotated[int | None, Query(description='Limit the number of returned targets', example=100)] = None): """ Perform a search on carton or program """ with database.atomic(): - database.execute_sql('SET LOCAL enable_seqscan=false;') - query = carton_program_search(name, name_type) + if limit is False: + # This tweak seems to do more harm than good when limit is passed. + database.execute_sql('SET LOCAL enable_seqscan=false;') + + query = carton_program_search(name, name_type, limit=limit) query = append_pipes(query, observed=observed) + return query.dicts().iterator() @router.get('/obs', summary='Return targets with spectrum at observatory', diff --git a/python/valis/routes/target.py b/python/valis/routes/target.py index c049f00..91745a4 100644 --- a/python/valis/routes/target.py +++ b/python/valis/routes/target.py @@ -15,7 +15,7 @@ from valis.routes.base import Base from valis.db.queries import (get_target_meta, get_a_spectrum, get_catalog_sources, get_parent_catalog_data, get_target_cartons, - get_target_pipeline) + get_target_pipeline, get_target_by_altid, append_pipes) from valis.db.db import get_pw_db from valis.db.models import CatalogResponse, CartonModel, ParentCatalogModel, PipesModel, SDSSModel @@ -167,6 +167,18 @@ async def get_target(self, sdss_id: int = Path(title="The sdss_id of the target """ Return target metadata for a given sdss_id """ return get_target_meta(sdss_id, self.release) or {} + @router.get('/sdssid/{id}', summary='Retrieve a target sdss_id from an alternative id', + dependencies=[Depends(get_pw_db)], + response_model=Union[SDSSModel, dict], + response_model_exclude_unset=True, response_model_exclude_none=True) + async def get_target_altid(self, + id: Annotated[int | str, Path(title="The alternative id of the target to get", example="2M23595980+1528407")], + idtype: Annotated[str, Query(enum=['catalogid', 'gaiaid'], description='For ambiguous integer ids, the type of id, e.g. "catalogid"', example=None)] = None + ): + """ Return target metadata for a given sdss_id """ + query = append_pipes(get_target_by_altid(id, idtype=idtype), observed=False) + return query.dicts().first() or {} + @router.get('/spectra/{sdss_id}', summary='Retrieve a spectrum for a target sdss_id', dependencies=[Depends(get_pw_db)], response_model=List[SpectrumModel])