Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed May 31, 2024
1 parent f06c178 commit c9639f1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 149 deletions.
15 changes: 8 additions & 7 deletions src/qseek/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

from qseek.corrections.corrections import StationCorrectionType
from qseek.distance_weights import DistanceWeights
from qseek.features import FeatureExtractorType
from qseek.images.images import ImageFunctions, WaveformImages
from qseek.magnitudes import EventMagnitudeCalculatorType
Expand All @@ -34,7 +35,6 @@
from qseek.pre_processing.frequency_filters import Bandpass
from qseek.pre_processing.module import Downsample, PreProcessing
from qseek.signals import Signal
from qseek.spatial_weights import SpatialWeights
from qseek.stats import RuntimeStats, Stats
from qseek.tracers.tracers import RayTracer, RayTracers
from qseek.utils import (
Expand Down Expand Up @@ -238,8 +238,9 @@ class Search(BaseModel):
default=RayTracers(root=[tracer() for tracer in RayTracer.get_subclasses()]),
description="List of ray tracers for travel time calculation.",
)
spatial_weights: SpatialWeights | None = Field(
default=SpatialWeights(),
distance_weights: DistanceWeights | None = Field(
default=DistanceWeights(),
alias="spatial_weights",
description="Spatial weights for distance weighting.",
)
station_corrections: StationCorrectionType | None = Field(
Expand Down Expand Up @@ -463,8 +464,8 @@ async def prepare(self) -> None:
self.data_provider.prepare(self.stations)
await self.pre_processing.prepare()

if self.spatial_weights:
self.spatial_weights.prepare(self.stations, self.octree)
if self.distance_weights:
self.distance_weights.prepare(self.stations, self.octree)

if self.station_corrections:
await self.station_corrections.prepare(
Expand Down Expand Up @@ -722,8 +723,8 @@ async def calculate_semblance(
weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32)
weights[traveltimes_bad] = 0.0

if parent.spatial_weights:
weights *= await parent.spatial_weights.get_weights(octree, image.stations)
if parent.distance_weights:
weights *= await parent.distance_weights.get_weights(octree, image.stations)

with np.errstate(divide="ignore", invalid="ignore"):
weights /= weights.sum(axis=1, keepdims=True)
Expand Down
131 changes: 0 additions & 131 deletions src/qseek/spatial_weights.py

This file was deleted.

23 changes: 12 additions & 11 deletions src/qseek/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
PositiveInt,
PrivateAttr,
computed_field,
constr,
model_validator,
)
from pyrocko.squirrel import Squirrel
Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(

async def prefetch_worker(self) -> None:
logger.info(
"start prefetching data, queue size %d",
"start prefetching waveforms, queue size %d",
self.queue.maxsize,
)

Expand All @@ -70,7 +71,7 @@ async def load_data() -> None | Batch:
await self.queue.put(batch)

await asyncio.create_task(load_data())
logger.debug("loading waveform batches to finish")
logger.debug("done loading waveforms")


class SquirrelStats(Stats):
Expand Down Expand Up @@ -135,11 +136,12 @@ class PyrockoSquirrel(WaveformProvider):
"[ISO8601](https://en.wikipedia.org/wiki/ISO_8601).",
)

channel_selector: str = Field(
default="*",
max_length=3,
description="Channel selector for waveforms, "
"use e.g. `EN?` for selection of all accelerometer data.",
channel_selector: list[constr(to_upper=True, max_length=2, min_length=2)] | None = (
Field(
default=None,
description="Channel selector for waveforms, "
"use e.g. `EN` for selection of all accelerometer data.",
)
)
async_prefetch_batches: PositiveInt = Field(
default=10,
Expand Down Expand Up @@ -220,12 +222,11 @@ async def iter_batches(
tinc=window_increment.total_seconds(),
tpad=window_padding.total_seconds(),
want_incomplete=False,
codes=[
(*nsl, self.channel_selector) for nsl in self._stations.get_all_nsl()
],
codes=[(*nsl, "*") for nsl in self._stations.get_all_nsl()],
)
prefetcher = SquirrelPrefetcher(
iterator, queue_size=self.async_prefetch_batches
iterator,
queue_size=self.async_prefetch_batches,
)
stats.set_queue(prefetcher.queue)

Expand Down

0 comments on commit c9639f1

Please sign in to comment.