Skip to content

Commit

Permalink
Merge branch 'main' into havok2063/add-file-upload-endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
havok2063 committed Dec 9, 2024
2 parents fa30480 + dbf2ad5 commit 32f8537
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 20 deletions.
128 changes: 128 additions & 0 deletions locustfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import random
from locust import HttpUser, task, between


# A set of files to run with [locust.io](https://locust.io/) for performance testing the app.
# pip install locust
# then run "locust -f locustfile.py" and open http://localhost:8089/ in your browser


class FastAPIUser(HttpUser):
    """Simulated user that load-tests the query and target endpoints.

    Every ``@task`` method sends a single request and reports it as failed
    when the response status code is not 200.
    """

    # release name sent as the ``release`` query parameter by every task
    release = 'IPL3'
    # sdss_ids that the randomised tasks draw their target from
    sdssids = [23326, 54392544, 57651832, 57832526, 61731453, 85995134, 56055457]
    wait_time = between(1, 5)  # Simulate user think time between requests

    def _request(self, method, url, **kwargs):
        """Send one request and flag any non-200 response as a failure.

        Parameters
        ----------
        method : str
            the HTTP verb, e.g. "GET" or "POST"; also used in the failure message
        url : str
            the endpoint path to request
        **kwargs
            forwarded to the client, e.g. ``params``, ``headers`` or ``json``
        """
        with self.client.request(method, url, catch_response=True, **kwargs) as response:
            if response.status_code != 200:
                response.failure(f"{method} {url} failed: {response.text}")

    @staticmethod
    def _random_cone(max_radius):
        """Build the ra/dec/radius/units fields of a random cone search.

        Parameters
        ----------
        max_radius : float
            upper bound, in degrees, of the randomly drawn search radius

        Returns
        -------
        dict
            the cone search fields, with coordinates drawn over the full sky
        """
        return {
            'ra': random.uniform(0, 360),
            'dec': random.uniform(-90, 90),
            'radius': random.uniform(0.01, max_radius),
            'units': 'degree',
        }

    @task
    def query_main(self):
        """POST one of three randomly chosen payloads to the main query."""
        # candidate request bodies: a plain cone search, a lookup by id, and
        # a cone search restricted to a program and carton
        payloads = [
            {**self._random_cone(0.2), 'observed': True},
            {"id": random.choice(self.sdssids)},
            {
                **self._random_cone(0.2),
                'observed': True,
                'program': 'bhm_rm',
                'carton': 'bhm_rm_core',
            },
        ]
        self._request(
            "POST",
            "/query/main",
            headers={'Content-Type': 'application/json'},
            params={'release': self.release},
            json=random.choice(payloads),
        )

    @task
    def query_cone(self):
        """GET a random cone search, with ``observed`` chosen at random."""
        params = {
            **self._random_cone(0.5),
            'observed': random.choice([True, False]),
            'release': self.release,
        }
        self._request("GET", "/query/cone", params=params)

    @task
    def query_carton(self):
        """GET the observed targets of a fixed carton."""
        params = {
            'name': 'manual_mwm_tess_ob',
            'name_type': 'carton',
            'observed': True,
            'release': self.release,
        }
        self._request("GET", '/query/carton-program', params=params)

    @task
    def get_spectrum(self):
        """GET the specLite BOSS/APO spectrum of a fixed target."""
        sdss_id = 23326
        params = {
            'product': 'specLite',
            'ext': 'BOSS/APO',
            'release': self.release,
        }
        self._request("GET", f"/target/spectra/{sdss_id}", params=params)

    @task
    def get_catalogs(self):
        """GET the catalogs of a randomly chosen target."""
        sdss_id = random.choice(self.sdssids)
        self._request("GET", f"/target/catalogs/{sdss_id}",
                      params={'release': self.release})

    @task
    def get_parents(self):
        """GET the parent catalog entry of a fixed target and catalogid."""
        catalog = 'gaia_dr3_source'
        sdss_id = 129047350
        params = {
            'catalogid': 63050396587194280,
            'release': self.release,
        }
        self._request("GET", f"/target/parents/{catalog}/{sdss_id}", params=params)

    @task
    def get_cartons(self):
        """GET the cartons of a randomly chosen target."""
        sdss_id = random.choice(self.sdssids)
        self._request("GET", f"/target/cartons/{sdss_id}",
                      params={'release': self.release})

    @task
    def get_pipelines(self):
        """GET the pipelines of a randomly chosen target."""
        sdss_id = random.choice(self.sdssids)
        self._request("GET", f"/target/pipelines/{sdss_id}",
                      params={'release': self.release})

# if __name__ == "__main__":
# run_single_user(FastAPIUser)
186 changes: 173 additions & 13 deletions python/valis/db/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

# all reusable queries go here

from contextlib import contextmanager
import itertools
import packaging
import uuid
from typing import Sequence, Union, Generator

import astropy.units as u
import deepmerge
import peewee
from peewee import Case
from astropy.coordinates import SkyCoord
from sdssdb.peewee.sdss5db import apogee_drpdb as apo
from sdssdb.peewee.sdss5db import boss_drp as boss
Expand All @@ -26,7 +29,7 @@


def append_pipes(query: peewee.ModelSelect, table: str = 'stacked',
observed: bool = True) -> peewee.ModelSelect:
observed: bool = True, release: str = None) -> peewee.ModelSelect:
""" Joins a query to the SDSSidToPipes table
Joins an existing query to the SDSSidToPipes table and returns
Expand Down Expand Up @@ -57,21 +60,45 @@ def append_pipes(query: peewee.ModelSelect, table: str = 'stacked',
if table not in {'stacked', 'flat'}:
raise ValueError('table must be either "stacked" or "flat"')

model = vizdb.SDSSidStacked if table == 'stacked' else vizdb.SDSSidFlat
qq = query.select_extend(vizdb.SDSSidToPipes.in_boss,
vizdb.SDSSidToPipes.in_apogee,
vizdb.SDSSidToPipes.in_bvs,
vizdb.SDSSidToPipes.in_astra,
vizdb.SDSSidToPipes.has_been_observed,
vizdb.SDSSidToPipes.release,
vizdb.SDSSidToPipes.obs,
vizdb.SDSSidToPipes.mjd).\
join(vizdb.SDSSidToPipes, on=(model.sdss_id == vizdb.SDSSidToPipes.sdss_id),
attr='pipes').distinct(vizdb.SDSSidToPipes.sdss_id)
# Run initial query as a temporary table.
temp = create_temporary_table(query, indices=['sdss_id'])

qq = temp.select(temp.__star__,
vizdb.SDSSidToPipes.in_boss,
vizdb.SDSSidToPipes.in_apogee,
vizdb.SDSSidToPipes.in_bvs,
vizdb.SDSSidToPipes.in_astra,
vizdb.SDSSidToPipes.has_been_observed,
vizdb.SDSSidToPipes.release,
vizdb.SDSSidToPipes.obs,
vizdb.SDSSidToPipes.mjd).\
join(vizdb.SDSSidToPipes, on=(temp.c.sdss_id == vizdb.SDSSidToPipes.sdss_id)).\
distinct(temp.c.sdss_id)

if observed:
qq = qq.where(vizdb.SDSSidToPipes.has_been_observed == observed)

if release:
# get the release
rel = vizdb.Releases.select().where(vizdb.Releases.release==release).first()

# if a release has no cutoff info, then force the cutoff to 0, query will return nothing
# to fix this we want mjd cutoffs by survey for all older releases
if not rel.mjd_cutoff_apo and not rel.mjd_cutoff_lco:
rel.mjd_cutoff_apo = 0
rel.mjd_cutoff_lco = 0

# create the mjd cutoff condition
qq = qq.where(vizdb.SDSSidToPipes.mjd <= Case(
vizdb.SDSSidToPipes.obs,
(
('apo', rel.mjd_cutoff_apo),
('lco', rel.mjd_cutoff_lco)
),
None
)
)

return qq


Expand Down Expand Up @@ -264,7 +291,8 @@ def carton_program_map(key: str = 'program') -> dict:

def carton_program_search(name: str,
name_type: str,
query: peewee.ModelSelect | None = None) -> peewee.ModelSelect:
query: peewee.ModelSelect | None = None,
limit: int | None = None) -> peewee.ModelSelect:
""" Perform a search on either carton or program
Parameters
Expand All @@ -276,6 +304,8 @@ def carton_program_search(name: str,
query : ModelSelect
An initial query to extend. If ``None``, a new query with all the unique
``sdss_id``s is created.
limit : int
Limit the number of results returned.
Returns
-------
Expand All @@ -286,6 +316,13 @@ def carton_program_search(name: str,
if query is None:
query = vizdb.SDSSidStacked.select(vizdb.SDSSidStacked).distinct()

# NOTE: These settings seem to help when querying some cartons or programs, mainly
# those with a small number of targets, and in some cases with these the query
# actually applies the LIMIT more efficiently, but it's not a perfect solution.
vizdb.database.execute_sql('SET enable_gathermerge = off;')
vizdb.database.execute_sql('SET parallel_tuple_cost = 100;')
vizdb.database.execute_sql('SET enable_bitmapscan = off;')

query = (query.join(
vizdb.SDSSidFlat,
on=(vizdb.SDSSidFlat.sdss_id == vizdb.SDSSidStacked.sdss_id))
Expand All @@ -295,6 +332,9 @@ def carton_program_search(name: str,
.join(targetdb.Carton)
.where(getattr(targetdb.Carton, name_type) == name))

if limit:
query = query.limit(limit)

return query

def get_targets_obs(release: str, obs: str, spectrograph: str) -> peewee.ModelSelect:
Expand Down Expand Up @@ -831,3 +871,123 @@ def starfields(model: peewee.ModelSelect) -> peewee.NodeList:
pw_ver = peewee.__version__
oldver = packaging.version.parse(pw_ver) < packaging.version.parse('3.17.1')
return model.star if oldver else model.__star__


def get_sdssid_by_altid(id: str | int, idtype: str = None) -> peewee.ModelSelect | None:
    """ Get an sdss_id by an alternative id

    This query attempts to identify a target sdss_id from an
    alternative id, which can be a string or integer. It tries
    to distinguish between the following formats:

    - a (e)BOSS plate-mjd-fiberid, e.g. "10235-58127-0020"
    - a BOSS field-mjd-catalogid, e.g. "101077-59845-27021603187129892"
    - an SDSS-IV APOGEE ID, e.g "2M23595980+1528407"
    - an SDSS-V catalogid, e.g. 2702160318712989
    - a GAIA DR3 ID, e.g. 4110508934728363520

    It queries either the boss_drp.boss_spectrum or astra.source
    tables for the sdss_id.

    Parameters
    ----------
    id : str | int
        the input alternative id
    idtype : str, optional
        the type of integer id, by default None. Use "gaiaid" for a
        GAIA DR3 id; any other value is treated as a catalogid.

    Returns
    -------
    peewee.ModelSelect | None
        the ORM query, or None when the id is a plate-mjd-fiberid
        (no lookup is implemented for it) or matches none of the
        recognised formats
    """

    # cast to str
    if isinstance(id, int):
        id = str(id)

    # temp for now; maybe we make a single "altid" db column somewhere
    ndash = id.count('-')
    final = id.rsplit('-', 1)[-1]

    if ndash == 2 and len(final) <= 4 and final.isdigit() and int(final) <= 1000:
        # boss/eboss plate-mjd-fiberid e.g '10235-58127-0020'
        # no query is built for this format
        return None

    if ndash == 2 and len(final) > 5:
        # field-mjd-catalogid, e.g. '101077-59845-27021603187129892'
        field, mjd, catalogid = id.split('-')
        return boss.BossSpectrum.select(boss.BossSpectrum.sdss_id).\
            where(boss.BossSpectrum.catalogid == catalogid,
                  boss.BossSpectrum.mjd == mjd, boss.BossSpectrum.field == field)

    if ndash == 1 or (ndash == 0 and not id.isdigit()):
        # apogee obj id, including apogee south ids that contain
        # a single dash, e.g. '2M17282323-2415476'
        return astra.Source.select(astra.Source.sdss_id).\
            where(astra.Source.sdss4_apogee_id.in_([id]))

    if ndash == 0:
        # single integer id: a gaia dr3 id, e.g. 4110508934728363520, when
        # requested, otherwise a catalogid, e.g. 27021603187129892
        field = 'gaia_dr3_source_id' if idtype == 'gaiaid' else 'catalogid'
        return astra.Source.select(astra.Source.sdss_id).\
            where(getattr(astra.Source, field).in_([id]))

    # unrecognised format, e.g. three or more dashes; return None rather
    # than failing on an unbound result variable
    return None


def get_target_by_altid(id: str | int, idtype: str = None) -> peewee.ModelSelect | None:
    """ Get a target by an alternative id

    This retrieves the target info from vizdb.sdss_id_stacked,
    given an alternative id. It first tries to identify the proper
    sdss_id for the given altid, then it retrieves the basic target
    info. See ``get_sdssid_by_altid`` for details on the altid formats.

    Parameters
    ----------
    id : str | int
        the input alternative id
    idtype : str, optional
        the type of integer id, by default None

    Returns
    -------
    peewee.ModelSelect | None
        the ORM query, or None when no sdss_id is found for the id
    """
    # get the sdss_id
    targ = get_sdssid_by_altid(id, idtype=idtype)

    # compare against None explicitly; truth-testing a peewee query goes
    # through its __len__, which runs the whole query just to count rows
    if targ is None:
        return None

    res = targ.get_or_none()
    if res is None:
        return None

    # get the sdss_id metadata info
    return get_targets_by_sdss_id(res.sdss_id)


def create_temporary_table(query: peewee.ModelSelect,
                           indices: list[str] | None = None) -> peewee.Table:
    """Create a temporary table from a query.

    The table gets a random 8-character hex name, is bound to the vizdb
    database, and is analyzed after any requested indices are created.

    Parameters
    ----------
    query : peewee.ModelSelect
        the query whose results populate the temporary table
    indices : list[str], optional
        the columns to create an index on, one index per entry, by default None.
        Entries are inserted into the SQL as given, so they must be trusted names.

    Returns
    -------
    peewee.Table
        the table object bound to the vizdb database
    """

    table_name = uuid.uuid4().hex[0:8]

    table = peewee.Table(table_name)
    table.bind(vizdb.database)

    query.create_table(table_name, temporary=True)

    for index in indices or []:
        vizdb.database.execute_sql(f'CREATE INDEX ON "{table_name}" ({index})')

    # NOTE(review): presumably refreshes planner statistics for the new
    # table before it is joined against - confirm
    vizdb.database.execute_sql(f'ANALYZE "{table_name}"')

    return table
Loading

0 comments on commit 32f8537

Please sign in to comment.