Skip to content

Commit

Permalink
minor improvements
Browse files — browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Jul 7, 2024
1 parent 65e94e7 commit b79b62b
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 39 deletions.
7 changes: 4 additions & 3 deletions src/qseek/images/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,16 @@ def resample(self, sampling_rate: float, max_normalize: bool = False) -> None:
max_normalize (bool): Normalize by maximum value to keep the scale of the
maximum detection. Defaults to False.
"""
if self.sampling_rate == sampling_rate:
return

downsample = self.sampling_rate > sampling_rate

for tr in self.traces:
if max_normalize:
# We can use maximum here since the PhaseNet output is single-sided
_, max_value = tr.max()
resample(tr, sampling_rate)

if max_normalize and downsample:
_, max_value = tr.max()
tr.ydata /= tr.ydata.max()
tr.ydata *= max_value

Expand Down
7 changes: 4 additions & 3 deletions src/qseek/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from qseek.images.base import ImageFunction
from qseek.images.phase_net import PhaseNet
from qseek.stats import Stats
from qseek.utils import PhaseDescription, datetime_now, human_readable_bytes
from qseek.utils import QUEUE_SIZE, PhaseDescription, datetime_now, human_readable_bytes

if TYPE_CHECKING:
from pyrocko.trace import Trace
Expand Down Expand Up @@ -72,7 +72,9 @@ def _populate_table(self, table: Table) -> None:
class ImageFunctions(RootModel):
root: list[ImageFunctionType] = [PhaseNet()]

_queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr()
_queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr(
asyncio.Queue(maxsize=QUEUE_SIZE)
)
_processed_images: int = PrivateAttr(0)
_stats: ImageFunctionsStats = PrivateAttr(default_factory=ImageFunctionsStats)

Expand All @@ -81,7 +83,6 @@ def model_post_init(self, __context: Any) -> None:
phases = self.get_phases()
if len(set(phases)) != len(phases):
raise ValueError("A phase was provided twice")
self._queue = asyncio.Queue(maxsize=16)
self._stats.set_queue(self._queue)

async def process_traces(self, traces: list[Trace]) -> WaveformImages:
Expand Down
2 changes: 1 addition & 1 deletion src/qseek/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def new_detection(self, detection: EventDetection):
self.max_semblance = max(self.max_semblance, detection.semblance)

def _populate_table(self, table: Table) -> None:
table.add_row("No. Detections", f"[bold]{self.n_detections} :dim_button:")
table.add_row("No. Detections", f"[bold]{self.n_detections} :fire:")
table.add_row("Maximum semblance", f"{self.max_semblance:.4f}")


Expand Down
19 changes: 11 additions & 8 deletions src/qseek/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None:
raise ValueError("no stations available, add waveforms to start detection")

def __iter__(self) -> Iterator[Station]:
    """Yield only stations whose pretty NSL code is not in the blacklist."""

    def _is_allowed(station: Station) -> bool:
        # Evaluated lazily per station, so blacklist edits take effect mid-iteration.
        return station.nsl.pretty not in self.blacklist

    return filter(_is_allowed, self.stations)

@property
Expand All @@ -186,14 +185,18 @@ def select_from_traces(self, traces: Iterable[Trace]) -> Stations:
Returns:
Stations: Containing only selected stations.
"""
available_stations = tuple(self)
available_nsls = tuple(sta.nsl for sta in available_stations)

selected_stations = []
for nsl in ((tr.network, tr.station, tr.location) for tr in traces):
for sta in self:
if sta.nsl == nsl:
selected_stations.append(sta)
break
else:
raise ValueError(f"could not find a station for {'.'.join(nsl)} ")
for nsl in {(tr.network, tr.station, tr.location) for tr in traces}:
try:
sta_idx = available_nsls.index(nsl)
selected_stations.append(available_stations[sta_idx])
except ValueError as exc:
raise ValueError(
f"could not find a station for {'.'.join(nsl)} "
) from exc
return Stations.model_construct(stations=selected_stations)

def get_centroid(self) -> Location:
Expand Down
6 changes: 4 additions & 2 deletions src/qseek/pre_processing/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Lowpass,
)
from qseek.stats import Stats
from qseek.utils import datetime_now, human_readable_bytes
from qseek.utils import QUEUE_SIZE, datetime_now, human_readable_bytes

if TYPE_CHECKING:
from rich.table import Table
Expand Down Expand Up @@ -71,7 +71,9 @@ class PreProcessing(RootModel):
"The first module is the first to be applied.",
)

_queue: asyncio.Queue[WaveformBatch | None] = asyncio.Queue(maxsize=12)
_queue: asyncio.Queue[WaveformBatch | None] = PrivateAttr(
asyncio.Queue(maxsize=QUEUE_SIZE)
)
_stats: PreProcessingStats = PrivateAttr(default_factory=PreProcessingStats)

def model_post_init(self, __context: Any) -> None:
Expand Down
12 changes: 6 additions & 6 deletions src/qseek/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,19 +500,15 @@ async def start(self, force_rundir: bool = False) -> None:

await self.prepare()

logger.info("starting search")
stats = self._stats
stats.reset_start_time()

processing_start = datetime_now()

if self._progress.time_progress:
logger.info("continuing search from %s", self._progress.time_progress)
await self._catalog.check(repair=True)
await self._catalog.filter_events_by_time(
start_time=None,
end_time=self._progress.time_progress,
)
else:
logger.info("starting search")

batches = self.data_provider.iter_batches(
window_increment=self.window_length,
Expand All @@ -522,6 +518,10 @@ async def start(self, force_rundir: bool = False) -> None:
)
processed_batches = self.pre_processing.iter_batches(batches)

stats = self._stats
stats.reset_start_time()

processing_start = datetime_now()
console = asyncio.create_task(RuntimeStats.live_view())

async for images, batch in self.image_functions.iter_images(processed_batches):
Expand Down
1 change: 1 addition & 0 deletions src/qseek/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

PhaseDescription = Annotated[str, constr(pattern=r"[a-zA-Z]*:[a-zA-Z]*")]

QUEUE_SIZE = 16
CACHE_DIR = Path.home() / ".cache" / "qseek"
if not CACHE_DIR.exists():
logger.info("creating cache dir %s", CACHE_DIR)
Expand Down
21 changes: 5 additions & 16 deletions src/qseek/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from qseek.models.station import Stations
from qseek.stats import Stats
from qseek.utils import datetime_now, human_readable_bytes, to_datetime
from qseek.utils import QUEUE_SIZE, datetime_now, human_readable_bytes, to_datetime
from qseek.waveforms.base import WaveformBatch, WaveformProvider

if TYPE_CHECKING:
Expand All @@ -40,14 +40,10 @@ class SquirrelPrefetcher:
_fetched_batches: int
_task: asyncio.Task[None]

def __init__(
self,
iterator: Iterator[Batch],
queue_size: int = 8,
) -> None:
def __init__(self, iterator: Iterator[Batch]) -> None:
    """Wrap *iterator* with bounded queues and start the background prefetcher.

    Args:
        iterator: Source of waveform batches to prefetch asynchronously.
    """
    self.iterator = iterator
    self._fetched_batches = 0
    # Two bounded queues (ready batches and pending loads); the maxsize
    # applies back-pressure when the consumer falls behind the prefetcher.
    self.queue, self._load_queue = (
        asyncio.Queue(maxsize=QUEUE_SIZE) for _ in range(2)
    )
    # NOTE: a running event loop is required at construction time.
    self._task = asyncio.create_task(self.prefetch_worker())
Expand Down Expand Up @@ -143,10 +139,6 @@ class PyrockoSquirrel(WaveformProvider):
description="Channel selector for waveforms, " "e.g. `['HH', 'EN']`.",
)
)
async_prefetch_batches: PositiveInt = Field(
default=10,
description="Queue size for asynchronous pre-fetcher.",
)
n_threads: PositiveInt = Field(
default=8,
description="Number of threads for loading waveforms.",
Expand Down Expand Up @@ -227,10 +219,7 @@ async def iter_batches(
codes=[(*nsl, "*") for nsl in self._stations.get_all_nsl()],
channel_priorities=self.channel_selector,
)
prefetcher = SquirrelPrefetcher(
iterator,
queue_size=self.async_prefetch_batches,
)
prefetcher = SquirrelPrefetcher(iterator)
stats.set_queue(prefetcher.queue)

while True:
Expand Down

0 comments on commit b79b62b

Please sign in to comment.