From 92ed89c5d8272f77638ad57b53fae309a37d04b8 Mon Sep 17 00:00:00 2001 From: miili Date: Fri, 15 Mar 2024 21:44:30 +0100 Subject: [PATCH 01/26] search: fixes --- src/qseek/images/images.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py index 74771ded..9b3eb44e 100644 --- a/src/qseek/images/images.py +++ b/src/qseek/images/images.py @@ -112,6 +112,10 @@ async def worker() -> None: "start pre-processing images, queue size %d", self._queue.maxsize ) async for batch in batch_iterator: + if batch.is_empty(): + logger.debug("empty batch, skipping") + continue + start_time = datetime_now() images = await self.process_traces(batch.traces) stats.time_per_batch = datetime_now() - start_time From ea1122340c7630cf8cdd0cbfd7f689ec5513d759 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 18 Mar 2024 10:32:54 +0000 Subject: [PATCH 02/26] bugfixes --- src/qseek/images/phase_net.py | 9 ++++----- src/qseek/models/semblance.py | 9 +++------ src/qseek/models/station.py | 17 ++++++++++++----- src/qseek/octree.py | 21 +++++++++++++++++++++ 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index ef5c2e33..60b03c34 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -113,16 +113,15 @@ def search_phase_arrival( peak_delay = peak_times - event_time.timestamp() # Limit to post-event peaks - post_event_peaks = peak_delay > 0.0 - peak_idx = peak_idx[post_event_peaks] - peak_times = peak_times[post_event_peaks] - peak_delay = peak_delay[post_event_peaks] + after_event_peaks = peak_delay > 0.0 + peak_idx = peak_idx[after_event_peaks] + peak_times = peak_times[after_event_peaks] + peak_delay = peak_delay[after_event_peaks] if not peak_idx.size: return None peak_values = search_trace.get_ydata()[peak_idx] - closest_peak_idx = np.argmin(peak_delay) return ObservedArrival( diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py index 136eeef1..aa26868f 100644 --- a/src/qseek/models/semblance.py +++ b/src/qseek/models/semblance.py @@ -74,13 +74,10 @@ def _populate_table(self, table: Table) -> None: ) table.add_row( "Semblance size", - f"{human_readable_bytes(self.semblance_size_bytes)}" + f"{human_readable_bytes(self.semblance_size_bytes)}/" + f"{human_readable_bytes(self.semblance_allocation_bytes)}" f" ({self.last_nodes_stacked} nodes)", ) - table.add_row( - "Memory allocated", - f"{human_readable_bytes(self.semblance_allocation_bytes)}", - ) class SemblanceCache(dict[bytes, np.ndarray]): @@ -240,7 +237,7 @@ async def apply_cache(self, cache: SemblanceCache) -> None: self.semblance_unpadded, data, mask, - nthreads=1, + nthreads=8, ) def maximum_node_semblance(self) -> np.ndarray: diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index 10742598..ded665d2 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Iterator import numpy as np -from pydantic import BaseModel, Field, FilePath, constr +from pydantic import BaseModel, DirectoryPath, Field, FilePath, constr from pyrocko.io.stationxml import load_xml from pyrocko.model import Station as PyrockoStation from pyrocko.model import dump_stations_yaml, load_stations @@ -76,7 +76,7 @@ class Stations(BaseModel): description="List of [Pyrocko station YAML]" "(https://pyrocko.org/docs/current/formats/yaml.html) files.", ) - station_xmls: list[FilePath] = Field( + station_xmls: 
list[FilePath | DirectoryPath] = Field( default=[], description="List of StationXML files.", ) @@ -93,9 +93,16 @@ def model_post_init(self, __context: Any) -> None: for file in self.pyrocko_station_yamls: loaded_stations += load_stations(filename=str(file.expanduser())) - for file in self.station_xmls: - station_xml = load_xml(filename=str(file.expanduser())) - loaded_stations += station_xml.get_pyrocko_stations() + for path in self.station_xmls: + if path.is_dir(): + station_xmls = path.glob("*.xml") + elif path.is_file(): + station_xmls = [path] + else: + continue + for file in station_xmls: + station_xml = load_xml(filename=str(file.expanduser())) + loaded_stations += station_xml.get_pyrocko_stations() for sta in loaded_stations: sta = Station.from_pyrocko_station(sta) diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 1030165a..291c1c4a 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -709,6 +709,27 @@ def save_pickle(self, filename: Path) -> None: with filename.open("wb") as f: pickle.dump(self, f) + def get_corners(self) -> list[Location]: + """Get the corners of the octree. + + Returns: + list[Location]: List of locations. + """ + reference = self.location + return [ + Location( + lat=reference.lat, + lon=reference.lon, + elevation=reference.elevation, + east_shift=reference.east_shift + east, + north_shift=reference.north_shift + north, + depth=reference.depth + depth, + ) + for east in (self.east_bounds.min, self.east_bounds.max) + for north in (self.north_bounds.min, self.north_bounds.max) + for depth in (self.depth_bounds.min, self.depth_bounds.max) + ] + def __hash__(self) -> int: return hash( ( From 72f70da3e070833438a00ec036bd0cbf7eaca07a Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Thu, 21 Mar 2024 11:48:18 +0000 Subject: [PATCH 03/26] bugfixes --- pyproject.toml | 16 +- src/qseek/apps/qseek.py | 67 ++- src/qseek/corrections/base.py | 9 +- src/qseek/images/base.py | 8 +- src/qseek/images/images.py | 3 +- src/qseek/images/phase_net.py | 2 +- src/qseek/magnitudes/base.py | 17 +- src/qseek/magnitudes/local_magnitude.py | 8 +- src/qseek/magnitudes/local_magnitude_model.py | 2 +- src/qseek/magnitudes/moment_magnitude.py | 61 ++- .../magnitudes/moment_magnitude_store.py | 440 ++++++++++++------ src/qseek/models/catalog.py | 102 +++- src/qseek/models/detection.py | 180 ++++--- src/qseek/models/detection_uncertainty.py | 9 +- src/qseek/models/location.py | 7 +- src/qseek/models/semblance.py | 16 +- src/qseek/models/station.py | 5 +- src/qseek/octree.py | 27 +- src/qseek/pre_processing/base.py | 18 +- src/qseek/pre_processing/module.py | 2 +- src/qseek/search.py | 64 ++- src/qseek/tracers/cake.py | 7 +- src/qseek/utils.py | 55 +-- test/test_moment_magnitude_store.py | 47 +- 24 files changed, 773 insertions(+), 399 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 403fb1e3..a43ccfb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,13 +106,27 @@ extend-select = [ 'I', 'RUF', 'T20', + 'D', ] -ignore = ["RUF012", "RUF009"] +ignore = [ + "RUF012", + "RUF009", + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D107", +] [tool.ruff] target-version = 'py311' +[tool.ruff.lint.pydocstyle] +convention = "google" + [tool.pytest.ini_options] markers = ["plot: plot figures in tests"] diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index 3d6ddb70..72343179 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -11,6 +11,9 @@ import nest_asyncio from pkg_resources import get_distribution +from 
qseek.models.detection import EventDetection +from qseek.utils import get_cpu_count + nest_asyncio.apply() logger = logging.getLogger(__name__) @@ -128,6 +131,12 @@ type=Path, help="path of existing run", ) +features_extract.add_argument( + "--recalculate", + action="store_true", + default=False, + help="recalculate all magnitudes", +) modules = subparsers.add_parser( "modules", @@ -256,35 +265,55 @@ async def run() -> None: case "feature-extraction": search = Search.load_rundir(args.rundir) search.data_provider.prepare(search.stations) + recalculate_magnitudes = args.recalculate + + tasks = [] + + def console_status(task: asyncio.Task[EventDetection]): + detection = task.result() + if detection.magnitudes: + console.print( + f"Event {str(detection.time).split('.')[0]}:", + ", ".join( + f"[bold]{m.magnitude}[/bold] {m.average:.2f}±{m.error:.2f}" + for m in detection.magnitudes + ), + ) + else: + console.print(f"Event {detection.time}: No magnitudes") - async def extract() -> None: + async def worker() -> None: for magnitude in search.magnitudes: await magnitude.prepare(search.octree, search.stations) - - iterator = asyncio.as_completed( - tuple( - search.add_magnitude_and_features(detection) - for detection in search._catalog + await search.catalog.check(repair=True) + + sem = asyncio.Semaphore(get_cpu_count()) + for detection in track( + search.catalog, + description="Calculating magnitudes", + total=search.catalog.n_events, + console=console, + ): + await sem.acquire() + task = asyncio.create_task( + search.add_magnitude_and_features( + detection, + recalculate=recalculate_magnitudes, + ) ) - ) + tasks.append(task) + task.add_done_callback(lambda _: sem.release()) + task.add_done_callback(tasks.remove) + task.add_done_callback(console_status) - for result in track( - iterator, - description="Extracting features", - total=search._catalog.n_events, - ): - event = await result - if event.magnitudes: - for mag in event.magnitudes: - print(f"{mag.magnitude} {mag.average:.2f}±{mag.error:.2f}") # noqa: T201 - print("--") # noqa: T201 + await asyncio.gather(*tasks) await search._catalog.save() await search._catalog.export_detections( jitter_location=search.octree.smallest_node_size() ) - asyncio.run(extract(), debug=loop_debug) + asyncio.run(worker(), debug=loop_debug) case "corrections": import json @@ -391,7 +420,7 @@ def is_insight(module: type) -> bool: raise EnvironmentError(f"folder {args.folder} does not exist") file = args.folder / "search.schema.json" - print(f"writing JSON schemas to {args.folder}") # noqa: T201 + console.print(f"writing JSON schemas to {args.folder}") file.write_text(json.dumps(Search.model_json_schema(), indent=2)) file = args.folder / "detections.schema.json" diff --git a/src/qseek/corrections/base.py b/src/qseek/corrections/base.py index 10b614e5..7c5543e3 100644 --- a/src/qseek/corrections/base.py +++ b/src/qseek/corrections/base.py @@ -80,10 +80,11 @@ async def prepare( """Prepare the station for the corrections. Args: - station: The station to prepare. - octree: The octree to use for the preparation. - phases: The phases to prepare the station for. - rundir: The rundir to use for the delay. Defaults to None. + stations (Stations): The station to prepare. + octree (Octree): The octree to use for the preparation. + phases (Iterable[PhaseDescription]): The phases to prepare the station for. + rundir (Path | None, optional): The rundir to use for the delay. + Defaults to None. """ ... 
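# A minimal, self-contained sketch of the semaphore-throttled task pattern
# used by the reworked feature-extraction command above: acquire a slot
# before spawning each task and release it in a done-callback, so at most
# max_concurrent magnitude calculations run at once. process() and the
# integer items are illustrative stand-ins, not qseek API:
import asyncio

async def process(item: int) -> int:
    await asyncio.sleep(0.01)  # placeholder for real work
    return item

async def run_bounded(items: list[int], max_concurrent: int = 8) -> list[int]:
    sem = asyncio.Semaphore(max_concurrent)
    tasks: list[asyncio.Task[int]] = []
    for item in items:
        await sem.acquire()  # wait until a worker slot frees up
        task = asyncio.create_task(process(item))
        task.add_done_callback(lambda _: sem.release())  # free the slot
        tasks.append(task)
    return await asyncio.gather(*tasks)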
diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py index c95ff365..a19621d5 100644 --- a/src/qseek/images/base.py +++ b/src/qseek/images/base.py @@ -33,8 +33,7 @@ def name(self) -> str: return self.__class__.__name__ def get_blinding(self, sampling_rate: float) -> timedelta: - """ - Blinding duration for the image function. Added to padded waveforms. + """Blinding duration for the image function. Added to padded waveforms. Args: sampling_rate (float): The sampling rate of the waveform. @@ -73,6 +72,7 @@ def set_stations(self, stations: Stations) -> None: def resample(self, sampling_rate: float, max_normalize: bool = False) -> None: """Resample traces in-place. + Args: sampling_rate (float): Desired sampling rate in Hz. max_normalize (bool): Normalize by maximum value to keep the scale of the @@ -137,7 +137,7 @@ def search_phase_arrival( trace_idx (int): Index of the trace. event_time (datetime): Time of the event. modelled_arrival (datetime): Time to search around. - search_length_seconds (float, optional): Total search length in seconds + search_window_seconds (float, optional): Total search length in seconds around modelled arrival time. Defaults to 5. threshold (float, optional): Threshold for detection. Defaults to 0.1. @@ -158,7 +158,7 @@ def search_phase_arrivals( Args: event_time (datetime): Time of the event. modelled_arrivals (list[datetime]): Time to search around. - search_length_seconds (float, optional): Total search length in seconds + search_window_seconds (float, optional): Total search length in seconds around modelled arrival time. Defaults to 5. threshold (float, optional): Threshold for detection. Defaults to 0.1. diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py index 9b3eb44e..a9f3a77f 100644 --- a/src/qseek/images/images.py +++ b/src/qseek/images/images.py @@ -99,12 +99,11 @@ async def iter_images( """Iterate over images from batches. Args: - batches (AsyncIterator[Batch]): Async iterator over batches. + batch_iterator (AsyncIterator[Batch]): Async iterator over batches. Yields: AsyncIterator[WaveformImages]: Async iterator over images. """ - stats = self._stats async def worker() -> None: diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index 60b03c34..925552b8 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -59,7 +59,7 @@ def search_phase_arrival( trace_idx (int): Index of the trace. event_time (datetime): Time of the event. modelled_arrival (datetime): Time to search around. - search_length_seconds (float, optional): Total search length in seconds + search_window_seconds (float, optional): Total search length in seconds around modelled arrival time. Defaults to 5. threshold (float, optional): Threshold for detection. Defaults to 0.1. detection_blinding_seconds (float, optional): Blinding time in seconds for diff --git a/src/qseek/magnitudes/base.py b/src/qseek/magnitudes/base.py index 71567758..81ff900b 100644 --- a/src/qseek/magnitudes/base.py +++ b/src/qseek/magnitudes/base.py @@ -110,13 +110,23 @@ def get_subclasses(cls) -> tuple[type[EventMagnitudeCalculator], ...]: """ return tuple(cls.__subclasses__()) + def has_magnitude(self, event: EventDetection) -> bool: + """Check if the given event has a magnitude. + + Args: + event (EventDetection): The event to check. + + Returns: + bool: True if the event has a magnitude, False otherwise. 
+ """ + raise NotImplementedError + async def add_magnitude( self, squirrel: Squirrel, event: EventDetection, ) -> None: - """ - Adds a magnitude to the squirrel for the given event. + """Adds a magnitude to the squirrel for the given event. Args: squirrel (Squirrel): The squirrel object to add the magnitude to. @@ -132,8 +142,7 @@ async def prepare( octree: Octree, stations: Stations, ) -> None: - """ - Prepare the magnitudes calculation by initializing necessary data structures. + """Prepare the magnitudes calculation by initializing necessary data structures. Args: octree (Octree): The octree containing seismic event data. diff --git a/src/qseek/magnitudes/local_magnitude.py b/src/qseek/magnitudes/local_magnitude.py index 2721cb72..3b3fff3f 100644 --- a/src/qseek/magnitudes/local_magnitude.py +++ b/src/qseek/magnitudes/local_magnitude.py @@ -164,6 +164,12 @@ def validate_model(self) -> Self: self._model = LocalMagnitudeModel.get_subclass_by_name(self.model)() return self + def has_magnitude(self, event: EventDetection) -> bool: + for mag in event.magnitudes: + if type(mag) is LocalMagnitude and mag.model == self.model: + return True + return False + async def add_magnitude(self, squirrel: Squirrel, event: EventDetection) -> None: model = self._model @@ -180,7 +186,7 @@ async def add_magnitude(self, squirrel: Squirrel, event: EventDetection) -> None cut_off_fade=cut_off_fade, quantity=model.restitution_quantity, phase=None, - remove_clipped=True, + filter_clipped=True, ) if not traces: logger.warning("No restituted traces found for event %s", event.time) diff --git a/src/qseek/magnitudes/local_magnitude_model.py b/src/qseek/magnitudes/local_magnitude_model.py index 9ae9b04e..db30b10b 100644 --- a/src/qseek/magnitudes/local_magnitude_model.py +++ b/src/qseek/magnitudes/local_magnitude_model.py @@ -159,7 +159,7 @@ def get_station_magnitude( try: traces = _COMPONENT_MAP[self.component](traces) except KeyError: - logger.warning("Could not get channels for %s", receiver.nsl.pretty) + logger.debug("Could not get channels for %s", receiver.nsl.pretty) return None if not traces: return None diff --git a/src/qseek/magnitudes/moment_magnitude.py b/src/qseek/magnitudes/moment_magnitude.py index ff15276e..edbd7af0 100644 --- a/src/qseek/magnitudes/moment_magnitude.py +++ b/src/qseek/magnitudes/moment_magnitude.py @@ -51,8 +51,7 @@ def norm_traces(traces: list[Trace]) -> np.ndarray: - """ - Normalizes the traces to their maximum absolute value. + """Normalizes the traces to their maximum absolute value. Args: traces (list[Trace]): The traces to normalize. @@ -79,13 +78,12 @@ class PeakAmplitudeDefinition(PeakAmplitudesBase): description="The epicentral distance range of the stations.", ) frequency_range: Range = Field( - default=Range(min=2.0, max=6.0), + default=Range(min=2.0, max=20.0), description="The frequency range in Hz to filter the traces.", ) def filter_receivers_by_nsl(self, receivers: Iterable[Receiver]) -> set[Receiver]: - """ - Filters the list of receivers based on the NSL ID. + """Filters the list of receivers based on the NSL ID. Args: receivers (list[Receiver]): The list of receivers to filter. @@ -108,8 +106,7 @@ def filter_receivers_by_range( receivers: Iterable[Receiver], event: EventDetection, ) -> set[Receiver]: - """ - Filters the list of receivers based on the distance range. + """Filters the list of receivers based on the distance range. Args: receivers (Iterable[Receiver]): The list of receivers to filter. 
@@ -127,11 +124,9 @@ def filter_receivers_by_range( class StationMomentMagnitude(NamedTuple): - # quantity: MeasurementUnit distance_epi: float magnitude: float error: float - peak: float @@ -153,11 +148,15 @@ def m0(self) -> float: @property def n_stations(self) -> int: - """ - Number of stations used for calculating the moment magnitude. - """ + """Number of stations used for calculating the moment magnitude.""" return len(self.stations_magnitudes) + def csv_row(self) -> dict[str, float]: + return { + "Mw": self.average, + "Mw-error": self.error, + } + async def add_traces( self, store: PeakAmplitudesStore, @@ -190,7 +189,7 @@ async def add_traces( continue try: - model = await store.get_amplitude( + model = await store.get_amplitude_model( source_depth=event.effective_depth, distance=station.distance_epi, n_amplitudes=25, @@ -201,9 +200,13 @@ async def add_traces( logger.warning("No modelled amplitude for receiver %s", receiver.nsl) continue - magnitude = model.get_magnitude(station.peak) - error_upper = model.get_magnitude(station.peak + station.noise) - magnitude - error_lower = model.get_magnitude(station.peak - station.noise) - magnitude + magnitude = model.estimate_magnitude(station.peak) + error_upper = ( + model.estimate_magnitude(station.peak + station.noise) - magnitude + ) + error_lower = ( + model.estimate_magnitude(station.peak - station.noise) - magnitude + ) if not np.isfinite(error_lower): error_lower = error_upper @@ -278,6 +281,11 @@ async def prepare(self, octree: Octree, stations: Stations) -> None: depth_delta=definition.source_depth_delta, ) + def has_magnitude(self, event: EventDetection) -> bool: + if not event.magnitudes: + return False + return any(type(mag) is MomentMagnitude for mag in event.magnitudes) + async def add_magnitude( self, squirrel: Squirrel, @@ -298,7 +306,11 @@ async def add_magnitude( logger.info("No receivers in range for peak amplitude") continue if not store.source_depth_range.inside(event.effective_depth): - logger.info("Event depth outside of store depth range.") + logger.info( + "Event depth %.1f outside of magnitude store range (%.1f - %.1f).", + event.effective_depth, + *store.source_depth_range, + ) continue traces = await event.receivers.get_waveforms_restituted( @@ -310,7 +322,7 @@ async def add_magnitude( demean=True, seconds_fade=self.padding_seconds, cut_off_fade=False, - remove_clipped=True, + filter_clipped=True, ) if not traces: continue @@ -318,14 +330,23 @@ async def add_magnitude( for tr in traces: if store.frequency_range.min != 0.0: await asyncio.to_thread( - tr.highpass, 4, store.frequency_range.min, demean=True + tr.highpass, + 4, + store.frequency_range.min, + demean=False, ) await asyncio.to_thread( - tr.lowpass, 4, store.frequency_range.max, demean=True + tr.lowpass, + 4, + store.frequency_range.max, + demean=False, ) tr.chop(tr.tmin + self.padding_seconds, tr.tmax - self.padding_seconds) if self.processed_mseed_export is not None: + logger.debug( + "saving processed mseed traces to %s", self.processed_mseed_export + ) io.save(traces, str(self.processed_mseed_export), append=True) grouped_traces = [] diff --git a/src/qseek/magnitudes/moment_magnitude_store.py b/src/qseek/magnitudes/moment_magnitude_store.py index 9adbd10f..527b1ac7 100644 --- a/src/qseek/magnitudes/moment_magnitude_store.py +++ b/src/qseek/magnitudes/moment_magnitude_store.py @@ -5,6 +5,7 @@ import itertools import logging import struct +from collections import defaultdict from functools import cached_property from pathlib import Path from typing 
import ( @@ -32,9 +33,9 @@ from pyrocko import gf from pyrocko.guts import Float from pyrocko.trace import FrequencyResponse -from rich.progress import track from typing_extensions import Self +from qseek.stats import PROGRESS from qseek.utils import ( ChannelSelector, ChannelSelectors, @@ -72,11 +73,11 @@ class MTSourceCircularCrack(gf.MTSource): duration = Float.T() stress_drop = Float.T() radius = Float.T() + magnitude = Float.T() def _get_target(targets: list[gf.Target], nsl: tuple[str, str, str]) -> gf.Target: - """ - Get the target from the list of targets based on the given NSL codes. + """Get the target from the list of targets based on the given NSL codes. Args: targets (list[gf.Target]): List of targets to search from. @@ -95,12 +96,11 @@ def _get_target(targets: list[gf.Target], nsl: tuple[str, str, str]) -> gf.Targe def trace_amplitude(traces: list[Trace], channel_selector: ChannelSelector) -> float: - """ - Normalize traces channels. + """Normalize traces channels. Args: traces (list[Trace]): A list of traces to normalize. - components (str): The components to normalize. + channel_selector (ChannelSelector): The channel selector to use. Returns: Trace: The normalized trace. @@ -141,7 +141,7 @@ class PeakAmplitudesBase(BaseModel): default=1.0, ge=-1.0, le=8.0, - description="Reference magnitude in Mw.", + description="Reference moment magnitude in Mw.", ) rupture_velocities: Range = Field( default=Range(0.8, 0.9), @@ -158,14 +158,21 @@ class PeakAmplitudesBase(BaseModel): class SiteAmplitude(NamedTuple): + magnitude: float distance_epi: float peak_horizontal: float peak_vertical: float peak_absolute: float @classmethod - def from_traces(cls, receiver: gf.Receiver, traces: list[Trace]) -> Self: + def from_traces( + cls, + receiver: gf.Receiver, + traces: list[Trace], + magnitude: float, + ) -> Self: return cls( + magnitude=magnitude, distance_epi=np.sqrt(receiver.north_shift**2 + receiver.east_shift**2), peak_horizontal=trace_amplitude(traces, ChannelSelectors.Horizontal), peak_vertical=trace_amplitude(traces, ChannelSelectors.Vertical), @@ -174,7 +181,7 @@ def from_traces(cls, receiver: gf.Receiver, traces: list[Trace]) -> Self: class ModelledAmplitude(NamedTuple): - reference_magnitude: float + magnitude: float quantity: MeasurementUnit peak_amplitude: PeakAmplitude distance_epi: float @@ -188,11 +195,10 @@ def combine( other: ModelledAmplitude, weight: float = 1.0, ) -> ModelledAmplitude: - """ - Combines with another ModelledAmplitude using a weighted average. + """Combines with another ModelledAmplitude using a weighted average. Args: - amplitude (ModelledAmplitude): The ModelledAmplitude to be combined with. + other (ModelledAmplitude): The ModelledAmplitude to be combined with. weight (float, optional): The weight of the amplitude being combined. Defaults to 1.0. 
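# The combine() body below is a linear blend between two depth layers: with
# weight w in [0, 1], each statistic becomes (1 - w) * a + w * b, so w = 0
# keeps this amplitude and w = 1 takes the other. A sketch on plain floats
# (illustrative only):
def blend(a: float, b: float, weight: float) -> float:
    rcp_weight = 1.0 - weight  # share kept from the first amplitude
    return a * rcp_weight + b * weight

assert blend(2.0, 4.0, 0.5) == 3.0  # halfway between two source depths
assert blend(2.0, 4.0, 1.0) == 4.0  # full weight on the second amplitude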
@@ -210,13 +216,13 @@ def combine( raise ValueError("Cannot add amplitudes with different distances") if self.quantity != other.quantity: raise ValueError("Cannot add amplitudes with different quantities ") - if self.reference_magnitude != other.reference_magnitude: + if self.magnitude != other.magnitude: raise ValueError("Cannot add amplitudes with different reference magnitude") if self.peak_amplitude != other.peak_amplitude: raise ValueError("Cannot add amplitudes with different peak amplitudes ") rcp_weight = 1.0 - weight return ModelledAmplitude( - reference_magnitude=self.reference_magnitude, + magnitude=self.magnitude, peak_amplitude=self.peak_amplitude, quantity=self.quantity, distance_epi=self.distance_epi, @@ -226,9 +232,8 @@ def combine( mad=self.mad * rcp_weight + other.mad * weight, ) - def get_magnitude(self, observed_amplitude: float) -> float: - """ - Get the moment magnitude for the given observed amplitude. + def estimate_magnitude(self, observed_amplitude: float) -> float: + """Get the moment magnitude for the given observed amplitude. Args: observed_amplitude (float): The observed amplitude. @@ -236,13 +241,13 @@ def get_magnitude(self, observed_amplitude: float) -> float: Returns: float: The moment magnitude. """ - return self.reference_magnitude + np.log10(observed_amplitude / self.median) + with np.errstate(divide="ignore", invalid="ignore"): + return self.magnitude + np.log10(observed_amplitude / self.average) class SiteAmplitudesCollection(BaseModel): source_depth: float quantity: MeasurementUnit - reference_magnitude: float rupture_velocities: Range stress_drop: Range gf_store_id: str @@ -258,32 +263,34 @@ def wrapped(self) -> np.ndarray: return wrapped _distances = cached_property[np.ndarray](_get_numpy_array("distance_epi")) + _magnitudes = cached_property[np.ndarray](_get_numpy_array("magnitude")) _vertical = cached_property[np.ndarray](_get_numpy_array("peak_vertical")) _absolute = cached_property[np.ndarray](_get_numpy_array("peak_absolute")) _horizontal = cached_property[np.ndarray](_get_numpy_array("peak_horizontal")) def _clear_cache(self) -> None: - self.__dict__.pop("_distances", None) - self.__dict__.pop("_horizontal", None) - self.__dict__.pop("_vertical", None) - self.__dict__.pop("_absolute", None) + keys = {"_distances", "_horizontal", "_vertical", "_absolute", "_magnitudes"} + for key in keys: + self.__dict__.pop(key, None) - def get_amplitude( + def get_amplitude_model( self, distance: float, n_amplitudes: int, - max_distance: float = 0.0, + distance_cutoff: float = 0.0, + reference_magnitude: float = 1.0, peak_amplitude: PeakAmplitude = "absolute", ) -> ModelledAmplitude: - """ - Get the amplitudes for a given distance. + """Get the amplitudes for a given distance. Args: distance (float): The epicentral distance to retrieve the amplitudes for. n_amplitudes (int): The number of amplitudes to retrieve. - max_distance (float): The maximum distance allowed for + distance_cutoff (float): The maximum distance allowed for the retrieved amplitudes. If 0.0, no maximum distance is applied and the number of amplitudes will be exactly n_amplitudes. Defaults to 0.0. + reference_magnitude (float, optional): The reference magnitude to retrieve + the amplitudes for. Defaults to 1.0. peak_amplitude (PeakAmplitude, optional): The type of peak amplitude to retrieve. Defaults to "absolute". @@ -294,28 +301,35 @@ def get_amplitude( ValueError: If there are not enough amplitudes in the specified range. ValueError: If the peak amplitude type is unknown. 
""" - site_distances = np.abs(self._distances - distance) + magnitude_idx = np.where(self._magnitudes == reference_magnitude)[0] + if not magnitude_idx.size: + raise ValueError(f"No amplitudes for magnitude {reference_magnitude}.") + + site_distances = np.abs(self._distances[magnitude_idx] - distance) distance_idx = np.argsort(site_distances) + idx = distance_idx[:n_amplitudes] + distances = site_distances[idx] - if max_distance and distances.max() > max_distance: + if distance_cutoff and distances.max() > distance_cutoff: raise ValueError( - f"Not enough amplitudes at distance {distance} and range {max_distance}" + f"Not enough amplitudes at distance {distance}" + f" at cutoff {distance_cutoff}" ) match peak_amplitude: case "horizontal": - amplitudes = self._horizontal[idx] + amplitudes = self._horizontal[magnitude_idx][idx] case "vertical": - amplitudes = self._vertical[idx] + amplitudes = self._vertical[magnitude_idx][idx] case "absolute": - amplitudes = self._absolute[idx] + amplitudes = self._absolute[magnitude_idx][idx] case _: raise ValueError(f"Unknown peak amplitude type {peak_amplitude}.") median = float(np.median(amplitudes)) return ModelledAmplitude( - reference_magnitude=self.reference_magnitude, + magnitude=reference_magnitude, peak_amplitude=peak_amplitude, quantity=self.quantity, distance_epi=distance, @@ -325,14 +339,22 @@ def get_amplitude( mad=float(np.median(np.abs(amplitudes - median))), ) - def fill(self, receivers: list[gf.Receiver], traces: list[list[Trace]]) -> None: - for receiver, rcv_traces in zip(receivers, traces, strict=True): - self.site_amplitudes.append(SiteAmplitude.from_traces(receiver, rcv_traces)) + def fill( + self, + receivers: list[gf.Receiver], + traces: list[list[Trace]], + magnitudes: list[float], + ) -> None: + for receiver, rcv_traces, magnitude in zip( + receivers, traces, magnitudes, strict=True + ): + self.site_amplitudes.append( + SiteAmplitude.from_traces(receiver, rcv_traces, magnitude) + ) self._clear_cache() def distance_range(self) -> Range: - """ - Get the distance range of the site amplitudes. + """Get the distance range of the site amplitudes. Returns: Range: The distance range. @@ -341,8 +363,7 @@ def distance_range(self) -> Range: @property def n_amplitudes(self) -> int: - """ - Get the number of amplitudes in the collection. + """Get the number of amplitudes in the collection. Returns: int: The number of amplitudes. 
@@ -352,6 +373,7 @@ def n_amplitudes(self) -> int: def plot( self, axes: Axes | None = None, + reference_magnitude: float = 1.0, peak_amplitude: PeakAmplitude = "absolute", ) -> None: from matplotlib.ticker import FuncFormatter @@ -371,10 +393,11 @@ def plot( interp_amplitudes: list[ModelledAmplitude] = [] for distance in np.arange(*self.distance_range(), 250.0): interp_amplitudes.append( - self.get_amplitude( + self.get_amplitude_model( distance=distance, n_amplitudes=50, peak_amplitude=peak_amplitude, + reference_magnitude=reference_magnitude, ) ) @@ -417,7 +440,7 @@ def plot( 0.025, 0.025, f"""n={self.n_amplitudes} -$M_w^r$={self.reference_magnitude} +$M_w^{{ref}}$={reference_magnitude} $z$={self.source_depth / KM} km $v_r$=[{self.rupture_velocities.min}, {self.rupture_velocities.max}]$\\cdot v_s$ $\\Delta\\sigma$=[{self.stress_drop.min / 1e6}, {self.stress_drop.max / 1e6}] MPa @@ -448,7 +471,7 @@ class PeakAmplitudesStore(PeakAmplitudesBase): default_factory=uuid4, description="Unique ID of the amplitude store.", ) - site_amplitudes: list[SiteAmplitudesCollection] = Field( + amplitude_collections: list[SiteAmplitudesCollection] = Field( default_factory=list, description="Site amplitudes per source depth.", ) @@ -460,8 +483,15 @@ class PeakAmplitudesStore(PeakAmplitudesBase): default="", description="Hash of the GF store configuration.", ) + magnitude_range: Range = Field( + default=Range(0.0, 6.0), + description="Range of moment magnitudes for the seismic sources.", + ) _rng: np.random.Generator = PrivateAttr(default_factory=np.random.default_rng) + _access_locks: dict[int, asyncio.Lock] = PrivateAttr( + default_factory=lambda: defaultdict(asyncio.Lock) + ) _engine: ClassVar[gf.LocalEngine | None] = None _cache_dir: ClassVar[Path | None] = None @@ -469,8 +499,7 @@ class PeakAmplitudesStore(PeakAmplitudesBase): @classmethod def set_engine(cls, engine: gf.LocalEngine) -> None: - """ - Set the GF engine for the store. + """Set the GF engine for the store. Args: engine (gf.LocalEngine): The engine to use. @@ -479,8 +508,7 @@ def set_engine(cls, engine: gf.LocalEngine) -> None: @classmethod def set_cache_dir(cls, cache_dir: Path) -> None: - """ - Set the cache directory for the store. + """Set the cache directory for the store. Args: cache_dir (Path): The cache directory to use. @@ -489,8 +517,7 @@ def set_cache_dir(cls, cache_dir: Path) -> None: @classmethod def from_selector(cls, selector: PeakAmplitudesBase) -> Self: - """ - Create a new PeakAmplitudesStore from the given selector. + """Create a new PeakAmplitudesStore from the given selector. Args: selector (PeakAmplitudesSelector): The selector to use. @@ -498,7 +525,6 @@ def from_selector(cls, selector: PeakAmplitudesBase) -> Self: Returns: PeakAmplitudesStore: The newly created store. """ - if cls._engine is None: raise EnvironmentError( "No GF engine available to determine frequency range." @@ -525,12 +551,11 @@ def from_selector(cls, selector: PeakAmplitudesBase) -> Self: @property def source_depth_range(self) -> Range: - return Range.from_list([sa.source_depth for sa in self.site_amplitudes]) + return Range.from_list([sa.source_depth for sa in self.amplitude_collections]) @property def gf_store_depth_range(self) -> Range: - """ - Get the depth range of the GF store. + """Get the depth range of the GF store. Returns: Range: The depth range. @@ -540,8 +565,7 @@ def gf_store_depth_range(self) -> Range: @property def gf_store_distance_range(self) -> Range: - """ - Returns the distance range for the ground motion store. 
+ """Returns the distance range for the ground motion store. The distance range is determined by the minimum and maximum distances specified in the store's configuration. If the maximum distance exceeds @@ -557,10 +581,20 @@ def gf_store_distance_range(self) -> Range: max=min(store.config.distance_max, self.max_distance), ) - def get_store(self) -> gf.Store: - """ - Load the GF store for the given store ID. + def get_lock(self, source_depth: float, reference_magnitude: float) -> asyncio.Lock: + """Get the lock for the given source depth and reference magnitude. + + Args: + source_depth (float): The source depth. + reference_magnitude (float): The reference magnitude. + + Returns: + asyncio.Lock: The lock for the given source depth and reference magnitude. """ + return self._access_locks[hash((source_depth, reference_magnitude))] + + def get_store(self) -> gf.Store: + """Load the GF store for the given store ID.""" if self._engine is None: raise EnvironmentError("No GF engine available.") @@ -581,13 +615,18 @@ def get_store(self) -> gf.Store: return store def _get_random_source( - self, depth: float, stf: Type[gf.STF] | None = None + self, + depth: float, + reference_magnitude: float, + stf: Type[gf.STF] | None = None, ) -> MTSourceCircularCrack: - """ - Generates a random seismic source with the given depth. + """Generates a random seismic source with the given depth. Args: depth (float): The depth of the seismic source. + reference_magnitude (float): The reference moment magnitude. + stf (Type[gf.STF], optional): The source time function to use. + Defaults to None. Returns: gf.MTSource: A random moment tensor source. @@ -601,17 +640,18 @@ def _get_random_source( rupture_velocity = rng.uniform(*self.rupture_velocities) * vs radius = ( - pmt.magnitude_to_moment(self.reference_magnitude) * (7 / 16) / stress_drop + pmt.magnitude_to_moment(reference_magnitude) * (7 / 16) / stress_drop ) ** (1 / 3) duration = 1.5 * radius / rupture_velocity - moment_tensor = pmt.MomentTensor.random_dc(magnitude=self.reference_magnitude) + moment_tensor = pmt.MomentTensor.random_dc(magnitude=reference_magnitude) return MTSourceCircularCrack( + magnitude=reference_magnitude, + stress_drop=stress_drop, + radius=radius, m6=moment_tensor.m6(), depth=depth, duration=duration, - stress_drop=stress_drop, - radius=radius, - stf=stf(duration=duration) if stf else None, + stf=stf(effective_duration=duration) if stf else None, ) def _get_random_targets( @@ -619,18 +659,22 @@ def _get_random_targets( distance_range: Range, n_receivers: int, ) -> list[gf.Target]: - """ - Generate a list of receivers with random angles and distances. + """Generate a list of receivers with random angles and distances. Args: + distance_range (Range): The range of distances to generate the + receivers for. n_receivers (int): The number of receivers to generate. Returns: list[gf.Receiver]: A list of receivers with random angles and distances. 
""" rng = self._rng + _distance_range = np.array(distance_range) + _distance_range[_distance_range <= 0.0] = 1.0 # Add an epsilon + angles = rng.uniform(0.0, 360.0, size=n_receivers) - distances = np.exp(rng.uniform(*np.log(distance_range), size=n_receivers)) + distances = np.exp(rng.uniform(*np.log(_distance_range), size=n_receivers)) targets: list[gf.Receiver] = [] for i_receiver, (angle, distance) in enumerate( @@ -651,20 +695,22 @@ def _get_random_targets( targets.append(target) return targets # type: ignore - async def fill_source_depth( + async def compute_site_amplitudes( self, source_depth: float, + reference_magnitude: float, n_sources: int = 200, n_targets_per_source: int = 20, ) -> SiteAmplitudesCollection: - """ - Fills the moment magnitude store with amplitudes calculated - for a specific source depth. + """Fills the moment magnitude store. + + Calculates the amplitudes for a given source depth and reference magnitude. Args: source_depth (float): The depth of the seismic source. - n_targets (int, optional): The number of target locations to calculate - amplitudes for. Defaults to 20. + reference_magnitude (float): The reference moment magnitude. + n_targets_per_source (int, optional): The number of target locations to + calculate amplitudes for. Defaults to 20. n_sources (int, optional): The number of source locations to generate random sources from. Defaults to 100. """ @@ -677,20 +723,23 @@ async def fill_source_depth( target_distances = self.gf_store_distance_range logger.info( - "calculating %d amplitudes for depth %f", + "calculating %d %s amplitudes for Mw %.1f at depth %.1f", n_sources * n_targets_per_source, + self.quantity, + reference_magnitude, source_depth, ) receivers = [] receiver_traces = [] - for _ in track( - range(n_sources), + magnitudes = [] + status = PROGRESS.add_task( + f"Calculating Mw {reference_magnitude} amplitudes for depth {source_depth}", total=n_sources, - description=f"calculating amplitudes for depth {source_depth}", - ): + ) + for _ in range(n_sources): targets = self._get_random_targets(target_distances, n_targets_per_source) - source = self._get_random_source(source_depth) + source = self._get_random_source(source_depth, reference_magnitude) response = await asyncio.to_thread(engine.process, source, targets) traces: list[Trace] = response.pyrocko_traces() @@ -707,12 +756,15 @@ async def fill_source_depth( ): receivers.append(_get_target(targets, nsl)) receiver_traces.append(list(grp_traces)) + magnitudes.append(reference_magnitude) + PROGRESS.update(status, advance=1) + PROGRESS.remove_task(status) try: collection = self.get_collection(source_depth) except KeyError: collection = self.new_collection(source_depth) - collection.fill(receivers, receiver_traces) + collection.fill(receivers, receiver_traces, magnitudes) self.save() return collection @@ -724,8 +776,7 @@ async def fill_source_depth_range( n_sources: int = 400, n_targets_per_source: int = 20, ) -> None: - """ - Fills the source depth range with seismic data. + """Fills the source depth range with seismic data. Args: depth_min (float): The minimum depth of the source in meters. 
@@ -756,7 +807,7 @@ async def fill_source_depth_range( depths = np.arange(gf_depth_min, gf_depth_max, depth_delta) calculate_depths = depths[(depths >= depth_min) & (depths <= depth_max)] - stored_depths = [sa.source_depth for sa in self.site_amplitudes] + stored_depths = [sa.source_depth for sa in self.amplitude_collections] logger.debug("filling source depths %s", calculate_depths) for depth in calculate_depths: if depth in stored_depths: @@ -764,95 +815,108 @@ async def fill_source_depth_range( self.remove_collection(depth) else: continue - await self.fill_source_depth( - source_depth=depth, - n_sources=n_sources, - n_targets_per_source=n_targets_per_source, - ) + async with self.get_lock(depth, self.reference_magnitude): + await self.compute_site_amplitudes( + reference_magnitude=self.reference_magnitude, + source_depth=depth, + n_sources=n_sources, + n_targets_per_source=n_targets_per_source, + ) def get_collection(self, source_depth: float) -> SiteAmplitudesCollection: - """ - Get the site amplitudes collection for the given source depth. + """Get the site amplitudes collection for the given source depth. Args: - depth (float): The source depth. + source_depth (float): The source depth. Returns: SiteAmplitudesCollection: The site amplitudes collection. """ - for site_amplitudes in self.site_amplitudes: + for site_amplitudes in self.amplitude_collections: if site_amplitudes.source_depth == source_depth: return site_amplitudes raise KeyError(f"No site amplitudes for depth {source_depth}.") - def new_collection(self, depth: float) -> SiteAmplitudesCollection: - """ - Creates a new SiteAmplitudesCollection object for the given depth and - adds it to the list of site amplitudes. + def new_collection(self, source_depth: float) -> SiteAmplitudesCollection: + """Creates a new SiteAmplitudesCollection object. + + For the given depth and add it to the list of site amplitudes. Args: - depth (float): The depth for which the site amplitudes collection is + source_depth (float): The depth for which the site amplitudes collection is created. Returns: SiteAmplitudesCollection: The newly created SiteAmplitudesCollection object. """ - logger.debug("creating new site amplitudes for depth %f", depth) - self.remove_collection(depth) + logger.debug("creating new site amplitudes for depth %f", source_depth) + self.remove_collection(source_depth) collection = SiteAmplitudesCollection( - source_depth=depth, + source_depth=source_depth, **self.model_dump(exclude={"site_amplitudes"}), ) - self.site_amplitudes.append(collection) + self.amplitude_collections.append(collection) return collection def remove_collection(self, depth: float) -> None: - """ - Removes the site amplitudes collection for the given depth. + """Removes the site amplitudes collection for the given depth. Args: depth (float): The depth for which the site amplitudes collection is removed. 
""" - logger.debug("removing site amplitudes for depth %f", depth) try: collection = self.get_collection(depth) - self.site_amplitudes.remove(collection) + self.amplitude_collections.remove(collection) + logger.debug("removed site amplitudes for depth %f", depth) except KeyError: pass - async def get_amplitude( + async def get_amplitude_model( self, source_depth: float, distance: float, n_amplitudes: int = 25, - max_distance: float = 0.0, + distance_cutoff: float = 0.0, + reference_magnitude: float | None = None, peak_amplitude: PeakAmplitude = "absolute", auto_fill: bool = True, interpolation: Literal["nearest", "linear"] = "linear", ) -> ModelledAmplitude: - """ - Retrieves the amplitude for a given depth and distance. + """Retrieves the amplitude for a given depth and distance. Args: - depth (float): The depth of the event. + source_depth (float): The depth of the event. distance (float): The epicentral distance from the event. n_amplitudes (int, optional): The number of amplitudes to retrieve. Defaults to 10. - max_distance (float, optional): The maximum distance to consider in [m]. - Defaults to 1000.0. + distance_cutoff (float, optional): The maximum distance allowed for + the retrieved amplitudes. If 0.0, no maximum distance is applied and the + number of amplitudes will be exactly n_amplitudes. Defaults to 0.0. + reference_magnitude (float, optional): The reference moment magnitude + for the amplitudes. Defaults to 1.0. peak_amplitude (PeakAmplitude, optional): The type of peak amplitude to retrieve. Defaults to "absolute". - auto_fill (bool, optional): If True, the site amplitudes are calculated + auto_fill (bool, optional): If True, the site amplitudes for + depth-reference magnitude combinations are calculated if they are not available. Defaults to True. + interpolation (Literal["nearest", "linear"], optional): The depth + interpolation method to use. Defaults to "linear". Returns: - ModelledAmplitude: The modelled amplitude for the given depth and distance. + ModelledAmplitude: The modelled amplitude for the given depth, distance and + reference magnitude. 
""" if not self.source_depth_range.inside(source_depth): raise ValueError(f"Source depth {source_depth} outside range.") - source_depths = np.array([sa.source_depth for sa in self.site_amplitudes]) + source_depths = np.array([sa.source_depth for sa in self.amplitude_collections]) + reference_magnitude = ( + self.reference_magnitude + if reference_magnitude is None + else reference_magnitude + ) + match interpolation: case "nearest": idx = [np.abs(source_depths - source_depth).argmin()] @@ -861,32 +925,43 @@ async def get_amplitude( case _: raise ValueError(f"Unknown interpolation method {interpolation}.") - collections = [self.site_amplitudes[i] for i in idx] + collections = [self.amplitude_collections[i] for i in idx] amplitudes: list[ModelledAmplitude] = [] + for collection in collections: + lock = self.get_lock(collection.source_depth, reference_magnitude) try: - amplitude = collection.get_amplitude( + await lock.acquire() + amplitude = collection.get_amplitude_model( distance=distance, n_amplitudes=n_amplitudes, - max_distance=max_distance, + distance_cutoff=distance_cutoff, peak_amplitude=peak_amplitude, + reference_magnitude=reference_magnitude, ) amplitudes.append(amplitude) - except ValueError: + except ValueError as e: + logger.exception(e) if auto_fill: - await self.fill_source_depth(source_depth) - logger.info("auto-filling amplitudes for depth %f", source_depth) - return await self.get_amplitude( + await self.compute_site_amplitudes( + source_depth=collection.source_depth, + reference_magnitude=reference_magnitude, + ) + lock.release() + return await self.get_amplitude_model( source_depth=source_depth, distance=distance, n_amplitudes=n_amplitudes, - max_distance=max_distance, + reference_magnitude=reference_magnitude, + distance_cutoff=distance_cutoff, peak_amplitude=peak_amplitude, interpolation=interpolation, auto_fill=True, ) + lock.release() raise + lock.release() if not amplitudes: raise ValueError(f"No site amplitudes for depth {source_depth}.") @@ -896,11 +971,8 @@ async def get_amplitude( amplitude = amplitudes[0] case "linear": - if len(amplitudes) != 2: - raise ValueError( - f"Cannot interpolate amplitudes with {len(amplitudes)} " - f" source depths." - ) + if len(amplitudes) == 1: + return amplitudes[0] depths = source_depths[idx] weight = abs((source_depth - depths[0]) / abs(depths[1] - depths[0])) amplitude = amplitudes[0].combine(amplitudes[1], weight=weight) @@ -910,9 +982,85 @@ async def get_amplitude( raise ValueError(f"Median amplitude is zero for depth {source_depth}.") return amplitude - def hash(self) -> str: + async def find_moment_magnitude( + self, + source_depth: float, + distance: float, + observed_amplitude: float, + n_amplitudes: int = 25, + distance_cutoff: float = 0.0, + initial_reference_magnitude: float = 1.0, + peak_amplitude: PeakAmplitude = "absolute", + interpolation: Literal["nearest", "linear"] = "linear", + ) -> tuple[float, ModelledAmplitude]: + """Get the moment magnitude for the given observed amplitude. + + Args: + source_depth (float): The depth of the event. + distance (float): The epicentral distance from the event. + observed_amplitude (float): The observed amplitude. + n_amplitudes (int, optional): The number of amplitudes to retrieve. + Defaults to 10. + initial_reference_magnitude (float, optional): The initial reference + moment magnitude to use. Defaults to 1.0. + distance_cutoff (float, optional): The maximum distance allowed for + the retrieved amplitudes. 
If 0.0, no maximum distance is applied and the + number of amplitudes will be exactly n_amplitudes. Defaults to 0.0. + peak_amplitude (PeakAmplitude, optional): The type of peak amplitude to + retrieve. Defaults to "absolute". + interpolation (Literal["nearest", "linear"], optional): The depth + interpolation method to use. Defaults to "linear". + + Returns: + float: The moment magnitude. """ - Calculate the hash of the store from store parameters. + cache: list[tuple[float, float, ModelledAmplitude]] = [] + + def get_cache(reference_magnitude: float) -> tuple[float, ModelledAmplitude]: + for mag, est, model in cache: + if mag == reference_magnitude: + return est, model + raise KeyError(f"No estimate for magnitude {reference_magnitude}.") + + async def estimate_magnitude( + reference_magnitude: float, + ) -> tuple[float, ModelledAmplitude]: + try: + return get_cache(reference_magnitude) + except KeyError: + model = await self.get_amplitude_model( + reference_magnitude=reference_magnitude, + source_depth=source_depth, + distance=distance, + n_amplitudes=n_amplitudes, + distance_cutoff=distance_cutoff, + peak_amplitude=peak_amplitude, + interpolation=interpolation, + ) + est_magnitude = model.estimate_magnitude(observed_amplitude) + cache.append((reference_magnitude, est_magnitude, model)) + return est_magnitude, model + + reference_mag = initial_reference_magnitude + for _ in range(3): + est_magnitude, _ = await estimate_magnitude(reference_mag) + rounded_mag = np.round(est_magnitude, 0) + explore_mags = np.array([rounded_mag - 1, rounded_mag, rounded_mag + 1]) + + predictions = [await estimate_magnitude(mag) for mag in explore_mags] + predicted_mags = np.array([mag for mag, _ in predictions]) + models = [model for _, model in predictions] + + magnitude_differences = np.abs(predicted_mags - explore_mags) + min_diff = np.argmin(magnitude_differences) + + if min_diff == 1: + return predicted_mags[1], models[1] + reference_mag = explore_mags[min_diff] + return predicted_mags[min_diff], models[min_diff] + + def hash(self) -> str: + """Calculate the hash of the store from store parameters. Returns: str: The hash of the store. @@ -931,8 +1079,7 @@ def hash(self) -> str: return hashlib.sha1(data).hexdigest() def is_suited(self, selector: PeakAmplitudesBase) -> bool: - """ - Check if the given selector is suited for this store. + """Check if the given selector is suited for this store. Args: selector (PeakAmpliutdesSelector): The selector to check. @@ -957,8 +1104,7 @@ def __hash__(self) -> int: return hash(self.hash()) def save(self, path: Path | None = None) -> None: - """ - Save the site amplitudes to a JSON file. + """Save the site amplitudes to a JSON file. The site amplitudes are saved in a directory called 'site_amplitudes' within the cache directory. The file name is generated based on the store ID and @@ -995,8 +1141,7 @@ def __init__(self, cache_dir: Path, engine: gf.LocalEngine | None = None) -> Non PeakAmplitudesStore.set_cache_dir(cache_dir) def clear_cache(self): - """ - Clear the cache directory. + """Clear the cache directory. This method deletes all files in the cache directory. """ @@ -1005,8 +1150,7 @@ def clear_cache(self): file.unlink() def clean_cache(self, keep_files: int = 100) -> None: - """ - Clean the cache directory. + """Clean the cache directory. 
Args: keep_files (int, optional): The number of most recent files to keep in the @@ -1020,8 +1164,7 @@ def clean_cache(self, keep_files: int = 100) -> None: file.unlink() def cache_stats(self) -> CacheStats: - """ - Get the cache statistics. + """Get the cache statistics. Returns: CacheStats: The cache statistics. @@ -1036,8 +1179,7 @@ def cache_stats(self) -> CacheStats: def get_cached_stores( self, store_id: str, quantity: MeasurementUnit ) -> list[PeakAmplitudesStore]: - """ - Get the cached peak amplitude stores for the given store ID and quantity. + """Get the cached peak amplitude stores for the given store ID and quantity. Args: store_id (str): The store ID. @@ -1065,9 +1207,9 @@ def get_cached_stores( return stores def get_store(self, selector: PeakAmplitudesBase) -> PeakAmplitudesStore: - """ - Get a peak amplitude store for the given selector, either from the cache - or by creating a new store. + """Get a peak amplitude store for the given selector. + + Either from the cache or by creating a new store. Args: selector (PeakAmplitudesSelector): The selector to use. diff --git a/src/qseek/models/catalog.py b/src/qseek/models/catalog.py index 7d27439a..e25c7f64 100644 --- a/src/qseek/models/catalog.py +++ b/src/qseek/models/catalog.py @@ -1,14 +1,16 @@ from __future__ import annotations +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any, Iterator import aiofiles -from pydantic import BaseModel, PrivateAttr +from pydantic import BaseModel, PrivateAttr, computed_field from pyrocko import io from pyrocko.gui import marker from pyrocko.model import Event, dump_events from pyrocko.trace import Trace +from rich.progress import track from rich.table import Table from qseek.console import console @@ -31,6 +33,29 @@ class EventCatalogStats(Stats): max_semblance: float = 0.0 _position: int = 2 + _catalog: EventCatalog = PrivateAttr() + + def set_catalog(self, catalog: EventCatalog) -> None: + self._catalog = catalog + self.n_detections = catalog.n_events + + @property + def magnitudes(self) -> list[float]: + return [det.magnitude.average for det in self._catalog if det.magnitude] + + @computed_field + def mean_semblance(self) -> float: + return ( + sum(detection.semblance for detection in self._catalog) / self.n_detections + ) + + @computed_field + def magnitude_min(self) -> float: + return min(self.magnitudes) if self.magnitudes else 0.0 + + @computed_field + def magnitude_max(self) -> float: + return max(self.magnitudes) if self.magnitudes else 0.0 def new_detection(self, detection: EventDetection): self.n_detections += 1 @@ -51,7 +76,7 @@ def model_post_init(self, __context: Any) -> None: @property def n_events(self) -> int: - """Number of detections""" + """Number of detections.""" return len(self.events) @property @@ -66,6 +91,29 @@ def csv_dir(self) -> Path: dir.mkdir(exist_ok=True) return dir + async def filter_events_by_time( + self, + start_time: datetime | None, + end_time: datetime | None, + ) -> None: + """Filter the detections based on the given time range. + + Args: + start_time (datetime | None): Start time of the time range. + end_time (datetime | None): End time of the time range. 
+ """ + events = [] + if start_time is not None and min(det.time for det in self.events) < start_time: + logger.info("filtering detections after start time %s", start_time) + events = [det for det in self.events if det.time >= start_time] + if end_time is not None and max(det.time for det in self.events) > end_time: + logger.info("filtering detections before end time %s", end_time) + events = [det for det in self.events if det.time <= end_time] + if events: + self.events = events + self._stats.n_detections = len(self.events) + await self.save() + async def add(self, detection: EventDetection) -> None: detection.set_index(self.n_events) @@ -126,10 +174,48 @@ def load_rundir(cls, rundir: Path) -> EventCatalog: stats = catalog._stats stats.n_detections = catalog.n_events - if catalog: + if catalog and catalog.n_events: stats.max_semblance = max(detection.semblance for detection in catalog) return catalog + async def check(self, repair: bool = True) -> None: + """Check the catalog for errors and inconsistencies. + + Args: + repair (bool, optional): If True, attempt to repair the catalog. + Defaults to True. + """ + logger.info("checking catalog...") + found_bad = 0 + found_duplicates = 0 + event_uids = set() + for detection in track( + self.events.copy(), + description=f"checking {self.n_events} events...", + ): + try: + _ = detection.receivers + except ValueError: + found_bad += 1 + if repair: + self.events.remove(detection) + + if detection.uid in event_uids: + found_duplicates += 1 + if repair: + self.events.remove(detection) + + event_uids.add(detection.uid) + + if found_bad or found_duplicates: + logger.info("found %d detections with invalid receivers", found_bad) + logger.info("found %d duplicate detections", found_duplicates) + if repair: + logger.info("repairing catalog") + await self.save() + else: + logger.info("all detections are ok") + async def save(self) -> None: """Save catalog to current rundir.""" logger.debug("saving %d detections", self.n_events) @@ -148,8 +234,7 @@ async def save(self) -> None: await f.writelines(lines_recv) async def export_detections(self, jitter_location: float = 0.0) -> None: - """ - Export detections to CSV and Pyrocko event lists in the current rundir. + """Export detections to CSV and Pyrocko event lists in the current rundir. Args: jitter_location (float): The amount of jitter in [m] to apply @@ -178,6 +263,7 @@ async def export_csv(self, file: Path, jitter_location: float = 0.0) -> None: jitter_location (float, optional): Randomize the location of each detection by this many meters. Defaults to 0.0. """ + logger.info("saving event CSV to %s", file) header = [] if jitter_location: @@ -225,10 +311,12 @@ def get_pyrocko_markers(self) -> list[EventMarker | PhaseMarker]: def export_pyrocko_events( self, filename: Path, jitter_location: float = 0.0 ) -> None: - """Export Pyrocko events for all detections to a file + """Export Pyrocko events for all detections to a file. Args: filename (Path): output filename + jitter_location (float, optional): Randomize the location of each detection + by this many meters. Defaults to 0.0. """ logger.info("saving Pyrocko events to %s", filename) detections = self.events @@ -241,7 +329,7 @@ def export_pyrocko_events( ) def export_pyrocko_markers(self, filename: Path) -> None: - """Export Pyrocko markers for all detections to a file + """Export Pyrocko markers for all detections to a file. 
Args: filename (Path): output filename diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py index f0996235..bc24b70e 100644 --- a/src/qseek/models/detection.py +++ b/src/qseek/models/detection.py @@ -51,6 +51,7 @@ FILENAME_RECEIVERS = "detections_receivers.json" UPDATE_LOCK = asyncio.Lock() +SQUIRREL_SEM = asyncio.Semaphore(64) class ReceiverCache: @@ -70,11 +71,40 @@ def load(self) -> None: self.lines = self.file.read_text().splitlines() self.mtime = self.file.stat().st_mtime - def get_row(self, row_index: int) -> str: + def _check_mtime(self) -> None: if self.mtime is None or self.mtime != self.file.stat().st_mtime: self.load() + + def get_line(self, row_index: int) -> str: + """Retrieves the line at the specified row index. + + Args: + row_index (int): The index of the row to retrieve. + + Returns: + str: The line at the specified row index. + """ + self._check_mtime() return self.lines[row_index] + def find_uid(self, uid: UUID) -> tuple[int, str]: + """Find the given UID in the lines and return its index and value. + + get_line should be prefered over this method. + + Args: + uid (UUID): The UID to search for. + + Returns: + tuple[int, str]: A tuple containing the index and value of the found UID. + """ + self._check_mtime() + find_uid = str(uid) + for iline, line in enumerate(self.lines): + if find_uid in line: + return iline, line + raise KeyError + class PhaseDetection(BaseModel): phase: PhaseDescription @@ -114,8 +144,7 @@ def _get_csv_dict(self) -> dict[str, Any]: return csv_dict def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: - """ - Convert the observed and modeled arrivals to a list of Pyrocko PhaseMarkers. + """Convert the observed and modeled arrivals to a list of Pyrocko PhaseMarkers. Returns: list[marker.PhaseMarker]: List of Pyrocko PhaseMarker objects representing @@ -151,8 +180,7 @@ def add_phase_detection(self, arrival: PhaseDetection) -> None: self.phase_arrivals[arrival.phase] = arrival def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: - """ - Convert the phase arrivals to Pyrocko markers. + """Convert the phase arrivals to Pyrocko markers. Returns: A list of Pyrocko PhaseMarker objects. @@ -168,8 +196,7 @@ def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: def get_arrivals_time_window( self, phase: PhaseDescription | None = None ) -> tuple[datetime, datetime]: - """ - Get the time window for phase arrivals. + """Get the time window for phase arrivals. Args: phase (PhaseDescription | None): Optional phase description. @@ -198,18 +225,18 @@ class EventReceivers(BaseModel): @property def n_receivers(self) -> int: - """Number of receivers in the receiver set""" + """Number of receivers in the receiver set.""" return len(self.receivers) def n_observations(self, phase: PhaseDescription) -> int: - """Number of observations for a given phase""" + """Number of observations for a given phase.""" n_observations = 0 for receiver in self: if (arrival := receiver.phase_arrivals.get(phase)) and arrival.observed: n_observations += 1 return n_observations - def get_waveforms( + async def get_waveforms( self, squirrel: Squirrel, seconds_before: float = 3.0, @@ -217,8 +244,7 @@ def get_waveforms( phase: PhaseDescription | None = None, receivers: Iterable[Receiver] | None = None, ) -> list[Trace]: - """ - Retrieves and restitutes waveforms for a given squirrel. + """Retrieves and restitutes waveforms for a given squirrel. Args: squirrel (Squirrel): The squirrel waveform organizer. 
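A sketch of how the two `ReceiverCache` lookup paths above are meant to compose: `get_line` stays the fast, index-based path and `find_uid` is the linear fallback for stale indices. `cache`, `detection_idx` and `detection_uid` are assumed to exist in the caller:

    try:
        line = cache.get_line(detection_idx)
        receivers = EventReceivers.model_validate_json(line)
    except IndexError:
        # Index is out of date; fall back to scanning for the UID.
        detection_idx, line = cache.find_uid(detection_uid)
        receivers = EventReceivers.model_validate_json(line)
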
@@ -247,13 +273,15 @@ def get_waveforms( tmin = min(times).timestamp() - seconds_before tmax = max(times).timestamp() + seconds_after nslc_ids = [(*receiver.nsl, "*") for receiver in receivers] - traces = squirrel.get_waveforms( - codes=nslc_ids, - tmin=tmin, - tmax=tmax, - accessor_id=accessor_id, - want_incomplete=False, - ) + async with SQUIRREL_SEM: + traces = await asyncio.to_thread( + squirrel.get_waveforms, + codes=nslc_ids, + tmin=tmin, + tmax=tmax, + accessor_id=accessor_id, + want_incomplete=False, + ) squirrel.advance_accessor(accessor_id, cache_id="waveform") for tr in traces: @@ -274,12 +302,11 @@ async def get_waveforms_restituted( phase: PhaseDescription | None = None, quantity: MeasurementUnit = "velocity", demean: bool = True, - remove_clipped: bool = False, + filter_clipped: bool = False, freqlimits: tuple[float, float, float, float] = (0.01, 0.1, 25.0, 35.0), receivers: Iterable[Receiver] | None = None, ) -> list[Trace]: - """ - Retrieves and restitutes waveforms for a given squirrel. + """Retrieves and restitutes waveforms for a given squirrel. Args: squirrel (Squirrel): The squirrel waveform organizer. @@ -302,20 +329,20 @@ async def get_waveforms_restituted( The frequency limits. Defaults to (0.01, 0.1, 25.0, 35.0). receivers (list[Receiver] | None, optional): The receivers to retrieve waveforms for. If None, all receivers are retrieved. Defaults to None. + filter_clipped (bool, optional): Whether to filter clipped traces. + Defaults to False. Returns: list[Trace]: The restituted waveforms. """ - traces = await asyncio.to_thread( - self.get_waveforms, + traces = await self.get_waveforms( squirrel, phase=phase, seconds_after=seconds_after + seconds_fade, seconds_before=seconds_before + seconds_fade, receivers=receivers, ) - traces = filter_clipped_traces(traces) if remove_clipped else traces - + traces = filter_clipped_traces(traces) if filter_clipped else traces if not traces: return [] @@ -356,8 +383,7 @@ def get_response(tr: Trace) -> Any: return restituted_traces def get_receiver(self, nsl: NSL) -> Receiver: - """ - Get the receiver object based on given NSL tuple. + """Get the receiver object based on given NSL tuple. Args: nsl (tuple[str, str, str]): The network, station, and location tuple. @@ -378,7 +404,7 @@ def add( stations: Stations, phase_arrivals: list[PhaseDetection | None], ) -> None: - """Add receivers to the receiver set + """Add receivers to the receiver set. Args: stations: List of stations @@ -398,8 +424,7 @@ def add( receiver.add_phase_detection(arrival) def get_by_nsl(self, nsl: NSL) -> Receiver: - """ - Retrieves a receiver object by its NSL (network, station, location) tuple. + """Retrieves a receiver object by its NSL (network, station, location) tuple. Args: nsl (NSL): The NSL tuple representing @@ -417,8 +442,7 @@ def get_by_nsl(self, nsl: NSL) -> Receiver: raise KeyError(f"cannot find station {nsl.pretty}") def get_pyrocko_markers(self) -> list[marker.PhaseMarker]: - """ - Get a list of Pyrocko phase markers from all receivers. + """Get a list of Pyrocko phase markers from all receivers. Returns: A list of Pyrocko phase markers. @@ -486,8 +510,7 @@ def migrate_features(cls, v: Any) -> list[EventFeaturesType]: @classmethod def set_rundir(cls, rundir: Path) -> None: - """ - Set the rundir for the detection model. + """Set the rundir for the detection model. Args: rundir (Path): The path to the rundir. 
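Assumed call pattern for the restitution helper above, using the renamed `filter_clipped` flag; `squirrel` is a configured Squirrel instance and `event` an `EventDetection`:

    traces = await event.receivers.get_waveforms_restituted(
        squirrel,
        quantity="velocity",   # restitute to ground velocity
        filter_clipped=True,   # drop clipped traces before restitution
        demean=True,
    )
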
@@ -497,22 +520,21 @@ def set_rundir(cls, rundir: Path) -> None: @property def magnitude(self) -> EventMagnitude | None: - """ - Returns the magnitude of the event. + """Returns the magnitude of the event. If there are no magnitudes available, returns None. """ return self.magnitudes[0] if self.magnitudes else None async def save(self, file: Path | None = None, update: bool = False) -> None: - """ - Dump the detection data to a file. + """Dump the detection data to a file. After the detection is dumped, the receivers are dumped to a separate file and the receivers cache is cleared. Args: - directory (Path): The directory where the file will be saved. + file (Path|None): The file to dump the detection to. + If None, the rundir is used. Defaults to None. update (bool): Whether to update an existing detection or append a new one. Raises: @@ -539,31 +561,35 @@ async def save(self, file: Path | None = None, update: bool = False) -> None: await asyncio.shield(f.writelines(lines)) else: logger.debug("appending detection %d", self._detection_idx) - async with aiofiles.open(file, "a") as f: - await f.write(f"{json_data}\n") + async with UPDATE_LOCK: + async with aiofiles.open(file, "a") as f: + await f.write(f"{json_data}\n") - receiver_file = self._rundir / FILENAME_RECEIVERS - async with aiofiles.open(receiver_file, "a") as f: - await asyncio.shield(f.write(f"{self.receivers.model_dump_json()}\n")) + receiver_file = self._rundir / FILENAME_RECEIVERS + async with aiofiles.open(receiver_file, "a") as f: + await asyncio.shield( + f.write(f"{self.receivers.model_dump_json()}\n") + ) self._receivers = None # Free the memory - def set_index(self, index: int) -> None: - """ - Set the index of the detection. + def set_index(self, index: int, force: bool = False) -> None: + """Set the index of the detection. Args: index (int): The index to set. + force (bool, optional): Whether to force the index to be set. + Defaults to False. Returns: None """ - if self._detection_idx is not None: + if not force and self._detection_idx is not None: raise ValueError("cannot set index twice") self._detection_idx = index def set_uncertainty(self, uncertainty: DetectionUncertainty) -> None: - """Set detection uncertainty + """Set detection uncertainty. Args: uncertainty (DetectionUncertainty): detection uncertainty @@ -571,11 +597,14 @@ def set_uncertainty(self, uncertainty: DetectionUncertainty) -> None: self.uncertainty = uncertainty def add_magnitude(self, magnitude: EventMagnitude) -> None: - """Add magnitude to detection + """Add magnitude to detection. Args: magnitude (EventMagnitudeType): magnitude """ + for mag in self.magnitudes.copy(): + if type(magnitude) is type(mag): + self.magnitudes.remove(mag) self.magnitudes.append(magnitude) def add_feature(self, feature: EventFeature) -> None: @@ -589,8 +618,7 @@ def add_feature(self, feature: EventFeature) -> None: @computed_field @property def receivers(self) -> EventReceivers: - """ - Retrieves the event receivers associated with the detection. + """Retrieves the event receivers associated with the detection. Returns: EventReceivers: The event receivers associated with the detection. 
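The deduplication added to `add_magnitude` above means repeated estimates of the same magnitude type replace each other instead of accumulating. A sketch, with `LocalMagnitude` standing in for any concrete `EventMagnitude` subclass and its field name illustrative only:

    detection.add_magnitude(LocalMagnitude(average=1.2))
    detection.add_magnitude(LocalMagnitude(average=1.4))  # replaces the 1.2
    # detection.magnitudes now holds a single LocalMagnitude instance
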
@@ -607,20 +635,30 @@ def receivers(self) -> EventReceivers: elif self._rundir and self._detection_idx is not None: if self._receiver_cache is None: raise ValueError("cannot fetch receivers without set rundir") - logger.debug("fetching receiver information from cache") - row = self._receiver_cache.get_row(self._detection_idx) - receivers = EventReceivers.model_validate_json(row) - if receivers.event_uid != self.uid: - raise ValueError(f"uid mismatch: {receivers.event_uid} != {self.uid}") + try: + line = self._receiver_cache.get_line(self._detection_idx) + receivers = EventReceivers.model_validate_json(line) + except IndexError: + receivers = None + + if not receivers or receivers.event_uid != self.uid: + logger.warning("event %s uid mismatch, using brute search", self.time) + try: + idx, line = self._receiver_cache.find_uid(self.uid) + receivers = EventReceivers.model_validate_json(line) + self.set_index(idx, force=True) + except KeyError: + raise ValueError(f"uid mismatch for event {self.time}") from None + self._receivers = receivers else: raise ValueError("cannot fetch receivers without set rundir and index") return self._receivers def as_pyrocko_event(self) -> Event: - """Get detection as Pyrocko event + """Get detection as Pyrocko event. Returns: Event: Pyrocko event @@ -640,7 +678,7 @@ def as_pyrocko_event(self) -> Event: ) def get_csv_dict(self) -> dict[str, Any]: - """Get detection as CSV line + """Get detection as CSV line. Returns: dict[str, Any]: CSV line @@ -653,7 +691,6 @@ def get_csv_dict(self) -> dict[str, Any]: "east_shift": round(self.east_shift, 2), "north_shift": round(self.north_shift, 2), "distance_border": round(self.distance_border, 2), - "in_bounds": self.in_bounds, "semblance": self.semblance, } for magnitude in self.magnitudes: @@ -661,7 +698,7 @@ def get_csv_dict(self) -> dict[str, Any]: return csv_line def get_pyrocko_markers(self) -> list[marker.EventMarker | marker.PhaseMarker]: - """Get detections as Pyrocko markers + """Get detections as Pyrocko markers. Returns: list[marker.EventMarker | marker.PhaseMarker]: Pyrocko markers @@ -677,7 +714,7 @@ def get_pyrocko_markers(self) -> list[marker.EventMarker | marker.PhaseMarker]: return pyrocko_markers def export_pyrocko_markers(self, filename: Path) -> None: - """Save detection's Pyrocko markers to file + """Save detection's Pyrocko markers to file. Args: filename (Path): path to marker file @@ -686,7 +723,7 @@ def export_pyrocko_markers(self, filename: Path) -> None: marker.save_markers(self.get_pyrocko_markers(), str(filename)) def jitter_location(self, meters: float) -> Self: - """Randomize detection location + """Randomize detection location. Args: meters (float): maximum randomization in meters @@ -702,8 +739,12 @@ def jitter_location(self, meters: float) -> Self: detection._cached_lat_lon = None return detection - def snuffle(self, squirrel: Squirrel, restituted: bool = False) -> None: - """Open snuffler for detection + def snuffle( + self, + squirrel: Squirrel, + restituted: bool | MeasurementUnit = False, + ) -> None: + """Open snuffler for detection. 
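With the widened `restituted` argument above, the boolean form keeps its old meaning while a measurement unit selects the restitution quantity; the two calls below are assumed equivalents:

    detection.snuffle(squirrel, restituted=True)        # velocity (default)
    detection.snuffle(squirrel, restituted="velocity")  # explicit unit
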
Args: squirrel (Squirrel): The squirrel, holding the data @@ -711,10 +752,13 @@ def snuffle(self, squirrel: Squirrel, restituted: bool = False) -> None: """ from pyrocko.trace import snuffle + restitute_unit = "velocity" if restituted is True else restituted traces = ( self.receivers.get_waveforms(squirrel) - if not restituted - else self.receivers.get_waveforms_restituted(squirrel) + if not restitute_unit + else self.receivers.get_waveforms_restituted( + squirrel, quantity=restitute_unit + ) ) snuffle( traces, diff --git a/src/qseek/models/detection_uncertainty.py b/src/qseek/models/detection_uncertainty.py index fc173fab..5b2b386c 100644 --- a/src/qseek/models/detection_uncertainty.py +++ b/src/qseek/models/detection_uncertainty.py @@ -32,13 +32,12 @@ class DetectionUncertainty(BaseModel): def from_event( cls, source_node: Node, octree: Octree, percentile: float = PERCENTILE ) -> Self: - """ - Calculate the uncertainty of an event detection. + """Calculate the uncertainty of an event detection. Args: - event: The event detection to calculate the uncertainty for. - octree: The octree to use for the calculation. - percentile: The percentile to use for the calculation. + source_node (Node): The source node of the event. + octree (Octree): The octree to use for the calculation. + percentile (float): The percentile to use for the calculation. Defaults to 0.02 (2%). Returns: diff --git a/src/qseek/models/location.py b/src/qseek/models/location.py index e1f5e4a8..5f367df2 100644 --- a/src/qseek/models/location.py +++ b/src/qseek/models/location.py @@ -91,7 +91,6 @@ def surface_distance_to(self, other: Location) -> float: Returns: float: The surface distance in [m]. """ - if self._same_origin(other): return math.sqrt( (self.north_shift - other.north_shift) ** 2 @@ -129,7 +128,7 @@ def distance_to(self, other: Location) -> float: return math.sqrt((sx - ox) ** 2 + (sy - oy) ** 2 + (sz - oz) ** 2) def offset_from(self, other: Location) -> tuple[float, float, float]: - """Return offset vector (east, north, depth) from other location in [m] + """Return offset vector (east, north, depth) from other location in [m]. Args: other (Location): The other location. @@ -185,9 +184,7 @@ def shift(self, east: float, north: float, elevation: float) -> Self: return shifted def origin(self) -> Location: - """ - Returns the origin location based on the latitude, longitude, - and effective elevation. + """Get the origin location. Returns: Location: The origin location. diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py index aa26868f..8ac988fc 100644 --- a/src/qseek/models/semblance.py +++ b/src/qseek/models/semblance.py @@ -200,8 +200,7 @@ def get_time_from_index(self, index: int) -> datetime: return self.start_time + timedelta(seconds=index / self.sampling_rate) def get_semblance(self, time_idx: int) -> np.ndarray: - """ - Get the semblance values at a specific time index. + """Get the semblance values at a specific time index. Parameters: time_idx (int): The index of the desired time. @@ -212,8 +211,7 @@ def get_semblance(self, time_idx: int) -> np.ndarray: return self.semblance[:, time_idx] async def apply_cache(self, cache: SemblanceCache) -> None: - """ - Applies the cached data to the `semblance_unpadded` array. + """Applies the cached data to the `semblance_unpadded` array. Args: cache (SemblanceCache): The cache containing the cached data. @@ -255,7 +253,7 @@ async def maxima_semblance( Args: trim_padding (bool, optional): Trim padded data in post-processing. 
- nparallel (int, optional): Number of threads for calculation. + nthreads (int, optional): Number of threads for calculation. Defaults to 12. Returns: @@ -282,7 +280,9 @@ async def maxima_node_idx( """Indices of maximum semblance at any time step. Args: - nparallel (int, optional): Number of threads for calculation. + trim_padding (bool, optional): Trim padded data in post-processing. + Defaults to True. + nthreads (int, optional): Number of threads for calculation. Defaults to 12. Returns: @@ -330,6 +330,8 @@ async def find_peaks( distance (float): Minium distance of a peak to other peaks. trim_padding (bool, optional): Trim padded data in post-processing. Defaults to True. + nthreads (int, optional): Number of threads for calculation. + Defaults to 12. Returns: tuple[np.ndarray, np.ndarray]: Indices of peaks and peak values. @@ -427,6 +429,8 @@ def normalize( Args: factor (int | float): Normalization factor. + semblance_cache (SemblanceCache | None, optional): Cache of the semblance. + Defaults to None. """ if factor == 1.0: return diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index ded665d2..1bfd7145 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -179,7 +179,7 @@ def select_from_traces(self, traces: Iterable[Trace]) -> Stations: """Select stations by NSL code. Args: - selection (Iterable[Trace]): Iterable of Pyrocko Traces + traces (Iterable[Trace]): Iterable of Pyrocko Traces Returns: Stations: Containing only selected stations. @@ -216,8 +216,7 @@ def get_coordinates(self, system: CoordSystem = "geographic") -> np.ndarray: ) def as_pyrocko_stations(self) -> list[PyrockoStation]: - """ - Convert the stations to PyrockoStation objects. + """Convert the stations to PyrockoStation objects. Returns: A list of PyrockoStation objects. diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 291c1c4a..f2a89e6b 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -84,7 +84,7 @@ class Node: _location: Location | None = None def split(self) -> tuple[Node, ...]: - """Split the node into 8 children""" + """Split the node into 8 children.""" if not self.tree: raise EnvironmentError("Parent tree is not set.") @@ -149,8 +149,8 @@ def is_inside_border(self, with_surface: bool = False) -> bool: """Check if the node is within the root node border. Args: - trough (bool, optional): If True, the node is considered inside the - trough (open top). Defaults to False. + with_surface (bool, optional): If True, the surface is considered + as a border. Defaults to False. Returns: bool: True if the node is inside the root tree's border. @@ -200,8 +200,7 @@ def distance_to_location(self, location: Location) -> float: return location.distance_to(self.as_location()) def semblance_density(self) -> float: - """ - Calculate the semblance density of the octree. + """Calculate the semblance density of the octree. Returns: The semblance density of the octree. @@ -355,7 +354,7 @@ def check_limits(self) -> Octree: return self def model_post_init(self, __context: Any) -> None: - """Initialize octree. This method is called by the pydantic model""" + """Initialize octree. 
This method is called by the pydantic model.""" self._root_nodes = self.get_root_nodes(self.root_node_size) logger.info( @@ -394,12 +393,12 @@ def get_root_nodes(self, length: float) -> list[Node]: @cached_property def n_nodes(self) -> int: - """Number of nodes in the octree""" + """Number of nodes in the octree.""" return sum(1 for _ in self) @property def volume(self) -> float: - """Volume of the octree in cubic meters""" + """Volume of the octree in cubic meters.""" return reduce(mul, self.extent()) def iter_nodes(self, level: int | None = None) -> Iterator[Node]: @@ -433,7 +432,7 @@ def _clear_cache(self) -> None: del self.n_nodes def reset(self) -> Self: - """Reset the octree to its initial state""" + """Reset the octree to its initial state.""" logger.debug("resetting tree") self._clear_cache() self._root_nodes = self.get_root_nodes(self.root_node_size) @@ -459,11 +458,14 @@ def reduce_axis( self, surface: Literal["NE", "ED", "ND"] = "NE", max_level: int = -1, - accumulator: Callable = np.max, + accumulator: Callable[np.ndarray] = np.max, ) -> np.ndarray: - """Reduce the octree's nodes to the surface + """Reduce the octree's nodes to the surface. Args: + surface (Literal["NE", "ED", "ND"], optional): Surface to reduce to. + Defaults to "NE". + max_level (int, optional): Maximum level to reduce to. Defaults to -1. accumulator (Callable, optional): Accumulator function. Defaults to np.max. Returns: @@ -553,8 +555,7 @@ def distances_stations_surface(self, stations: Stations) -> np.ndarray: ).reshape(-1, stations.n_stations) def get_nodes(self, indices: Iterable[int]) -> list[Node]: - """ - Retrieves a list of nodes from the octree based on the given indices. + """Retrieves a list of nodes from the octree based on the given indices. Args: indices (Iterable[int]): The indices of the nodes to retrieve. diff --git a/src/qseek/pre_processing/base.py b/src/qseek/pre_processing/base.py index 8f4a3ac3..9b04fb68 100644 --- a/src/qseek/pre_processing/base.py +++ b/src/qseek/pre_processing/base.py @@ -31,17 +31,14 @@ def validate_stations(cls, v) -> set[NSL]: @classmethod def get_subclasses(cls) -> tuple[type[BatchPreProcessing], ...]: - """ - Returns a tuple of all the subclasses of BasePreProcessing. - """ + """Returns a tuple of all the subclasses of BasePreProcessing.""" return tuple(cls.__subclasses__()) def select_traces(self, batch: WaveformBatch) -> list[Trace]: - """ - Selects traces from the given list based on the stations specified. + """Selects traces from the given list based on the stations specified. Args: - traces (list[Trace]): The list of traces to select from. + batch (WaveformBatch): The batch of traces to select from. Returns: list[Trace]: The selected traces. @@ -57,17 +54,14 @@ def select_traces(self, batch: WaveformBatch) -> list[Trace]: return traces async def prepare(self) -> None: - """ - Prepare the pre-processing module. - """ + """Prepare the pre-processing module.""" pass async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: - """ - Process a list of traces. + """Process a list of traces. Args: - traces (list[Trace]): The list of traces to be processed. + batch (WaveformBatch): The batch of traces to process. Returns: list[Trace]: The processed list of traces. 
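A minimal sketch of a custom module against the `BatchPreProcessing` interface above; the demeaning logic is illustrative and only `process_batch` needs to be overridden (the `WaveformBatch` import is omitted, its location is not shown in this patch):

    from qseek.pre_processing.base import BatchPreProcessing

    class Demean(BatchPreProcessing):
        """Remove the mean from every selected trace."""

        async def process_batch(self, batch):  # batch: WaveformBatch
            for trace in self.select_traces(batch):
                trace.set_ydata(trace.get_ydata() - trace.get_ydata().mean())
            return batch
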
diff --git a/src/qseek/pre_processing/module.py b/src/qseek/pre_processing/module.py index a6a2cf5b..689f2f9e 100644 --- a/src/qseek/pre_processing/module.py +++ b/src/qseek/pre_processing/module.py @@ -103,11 +103,11 @@ async def worker() -> None: start_time = datetime_now() for process in self: batch = await process.process_batch(batch) - await self._queue.put(batch) stats.time_per_batch = datetime_now() - start_time stats.bytes_per_second = ( batch.cumulative_bytes / stats.time_per_batch.total_seconds() ) + await self._queue.put(batch) await self._queue.put(None) diff --git a/src/qseek/search.py b/src/qseek/search.py index bf28db6e..3d0f852b 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -38,6 +38,7 @@ PhaseDescription, alog_call, datetime_now, + get_cpu_count, human_readable_bytes, time_to_path, ) @@ -90,8 +91,7 @@ def time_remaining(self) -> timedelta: @computed_field @property def processing_rate(self) -> float: - """ - Calculate the processing rate of the search. + """Calculate the processing rate of the search. Returns: float: The processing rate in bytes per second. @@ -110,8 +110,7 @@ def processing_speed(self) -> timedelta: @computed_field @property def processed_percent(self) -> float: - """ - Calculate the percentage of processed batches. + """Calculate the percentage of processed batches. Returns: float: The percentage of processed batches. @@ -129,7 +128,7 @@ def add_processed_batch( duration: timedelta, show_log: bool = False, ) -> None: - self.batch_count = batch.i_batch + self.batch_count = batch.i_batch + 1 self.batch_count_total = batch.n_batches self.batch_time = batch.end_time self.processed_bytes += batch.cumulative_bytes @@ -166,7 +165,7 @@ def tts(duration: timedelta) -> str: table.add_row( "Progress ", f"[bold]{self.processed_percent:.1f}%[/bold]" - f" ([bold]{self.batch_count+1}[/bold]/{self.batch_count_total or '?'}," + f" ([bold]{self.batch_count}[/bold]/{self.batch_count_total or '?'}," f' {self.batch_time.strftime("%Y-%m-%d %H:%M:%S")})', ) table.add_row( @@ -300,7 +299,9 @@ class Search(BaseModel): _config_stem: str = PrivateAttr("") _rundir: Path = PrivateAttr() - _feature_semaphore: asyncio.Semaphore = PrivateAttr(asyncio.Semaphore(16)) + _compute_semaphore: asyncio.Semaphore = PrivateAttr( + asyncio.Semaphore(max(1, get_cpu_count() - 4)) + ) # Signals _new_detection: Signal[EventDetection] = PrivateAttr(Signal()) @@ -415,8 +416,7 @@ async def init_boundaries(self) -> None: ) async def prepare(self) -> None: - """ - Prepares the search by initializing necessary components and data. + """Prepares the search by initializing necessary components and data. This method prepares the search by performing the following steps: 1. Prepares the data provider with the given stations. @@ -466,6 +466,10 @@ async def start(self, force_rundir: bool = False) -> None: if self._progress.time_progress: logger.info("continuing search from %s", self._progress.time_progress) + await self._catalog.filter_events_by_time( + start_time=None, + end_time=self._progress.time_progress, + ) batches = self.data_provider.iter_batches( window_increment=self.window_length, @@ -509,38 +513,45 @@ async def start(self, force_rundir: bool = False) -> None: ) console.cancel() logger.info("finished search in %s", datetime_now() - processing_start) - logger.info("found %d detections", self._catalog.n_events) + logger.info("detected %d events", self._catalog.n_events) async def new_detections(self, detections: list[EventDetection]) -> None: - """ - Process new detections. 
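A worked example of the off-by-one fix in `add_processed_batch` above: `i_batch` is zero-based, so counting `i_batch + 1` processed batches lets the final batch report exactly 100 %:

    i_batch, n_batches = 99, 100
    batch_count = i_batch + 1
    processed_percent = batch_count / n_batches * 100.0  # 100.0, not 99.0
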
+ """Process new detections. Args: detections (list[EventDetection]): List of new event detections. """ + catalog = self.catalog await asyncio.gather( *(self.add_magnitude_and_features(det) for det in detections) ) for detection in detections: - await self._catalog.add(detection) + await catalog.add(detection) await self._new_detection.emit(detection) - if ( - self._catalog.n_events - and self._catalog.n_events - self._last_detection_export > 100 - ): - await self._catalog.export_detections( + if not catalog.n_events: + return + + threshold = np.floor(np.log10(catalog.n_events)) - 1 + new_threshold = max(10, 10**threshold) + if catalog.n_events - self._last_detection_export > new_threshold: + await catalog.export_detections( jitter_location=self.octree.smallest_node_size() ) - self._last_detection_export = self._catalog.n_events + self._last_detection_export = catalog.n_events - async def add_magnitude_and_features(self, event: EventDetection) -> EventDetection: - """ - Adds magnitude and features to the given event. + async def add_magnitude_and_features( + self, + event: EventDetection, + recalculate: bool = True, + ) -> EventDetection: + """Adds magnitude and features to the given event. Args: event (EventDetection): The event to add magnitude and features to. + recalculate (bool, optional): Whether to overwrite existing magnitudes and + features. Defaults to True. """ if not event.in_bounds: return event @@ -550,8 +561,10 @@ async def add_magnitude_and_features(self, event: EventDetection) -> EventDetect except NotImplementedError: return event - async with self._feature_semaphore: + async with self._compute_semaphore: for mag_calculator in self.magnitudes: + if not recalculate and mag_calculator.has_magnitude(event): + continue logger.debug("adding magnitude from %s", mag_calculator.magnitude) await mag_calculator.add_magnitude(squirrel, event) @@ -688,8 +701,7 @@ async def calculate_semblance( ) async def get_images(self, sampling_rate: float | None = None) -> WaveformImages: - """ - Retrieves waveform images for the specified sampling rate. + """Retrieves waveform images for the specified sampling rate. Args: sampling_rate (float | None, optional): The desired sampling rate in Hz. @@ -719,6 +731,8 @@ async def search( Args: octree (Octree | None, optional): The octree to use for the search. Defaults to None. + semblance_cache (SemblanceCache | None, optional): The semblance cache to + use for the search. Defaults to None. Returns: tuple[list[EventDetection], Trace]: The event detections and the diff --git a/src/qseek/tracers/cake.py b/src/qseek/tracers/cake.py index b8d74ad8..f916ec55 100644 --- a/src/qseek/tracers/cake.py +++ b/src/qseek/tracers/cake.py @@ -151,8 +151,7 @@ def get_profile_vs(self) -> np.ndarray: return self.layered_model.profile("vs") def save_plot(self, filename: Path) -> None: - """ - Plot the layered model and save the figure to a file. + """Plot the layered model and save the figure to a file. Args: filename (Path): The path to save the figure. @@ -312,7 +311,7 @@ def save(self, path: Path) -> Path: """Save the model and traveltimes to an .sptree archive. Args: - folder (Path): Folder or file to save tree into. If path is a folder a + path (Path): Folder or file to save tree into. 
If path is a folder a native name from the model's hash is used Returns: @@ -398,7 +397,7 @@ async def init_lut(self, octree: Octree, stations: Stations) -> None: self._node_lut[node.hash()] = traveltimes.astype(np.float32) def lut_fill_level(self) -> float: - """Return the fill level of the LUT as a float between 0.0 and 1.0""" + """Return the fill level of the LUT as a float between 0.0 and 1.0.""" return len(self._node_lut) / self._node_lut.get_size() async def fill_lut(self, nodes: Sequence[Node]) -> None: diff --git a/src/qseek/utils.py b/src/qseek/utils.py index afc8ac61..ec7f9173 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -101,8 +101,7 @@ def pretty(self) -> str: return ".".join(self) def match(self, other: NSL) -> bool: - """ - Check if the current NSL object matches another NSL object. + """Check if the current NSL object matches another NSL object. Args: other (NSL): The NSL object to compare with. @@ -118,8 +117,7 @@ def match(self, other: NSL) -> bool: @classmethod def parse(cls, nsl: str) -> NSL: - """ - Parse the given NSL string and return an NSL object. + """Parse the given NSL string and return an NSL object. Args: nsl (str): The NSL string to parse. @@ -148,8 +146,7 @@ class _Range(NamedTuple): max: float def inside(self, value: float) -> bool: - """ - Check if a value is inside the range. + """Check if a value is inside the range. Args: value (float): The value to check. @@ -161,8 +158,7 @@ def inside(self, value: float) -> bool: @classmethod def from_list(cls, array: np.ndarray | list[float]) -> _Range: - """ - Create a Range object from a numpy array. + """Create a Range object from a numpy array. Parameters: - array: numpy.ndarray @@ -184,8 +180,7 @@ def _range_validator(v: _Range) -> _Range: def time_to_path(datetime: datetime) -> str: - """ - Converts a datetime object to a string representation of a file path. + """Converts a datetime object to a string representation of a file path. Args: datetime (datetime): The datetime object to convert. @@ -197,8 +192,7 @@ def time_to_path(datetime: datetime) -> str: def as_array(iterable: Iterable[float], dtype: np.dtype = float) -> np.ndarray: - """ - Convert an iterable of floats into a NumPy array. + """Convert an iterable of floats into a NumPy array. Parameters: iterable (Iterable[float]): An iterable containing float values. @@ -210,8 +204,7 @@ def as_array(iterable: Iterable[float], dtype: np.dtype = float) -> np.ndarray: def weighted_median(data: np.ndarray, weights: np.ndarray | None = None) -> float: - """ - Calculate the weighted median of an array/list using numpy. + """Calculate the weighted median of an array/list using numpy. Parameters: data (np.ndarray): The input array/list. @@ -254,8 +247,7 @@ def weighted_median(data: np.ndarray, weights: np.ndarray | None = None) -> floa async def async_weighted_median( data: np.ndarray, weights: np.ndarray | None = None ) -> float: - """ - Asynchronously calculate the weighted median of an array/list using numpy. + """Asynchronously calculate the weighted median of an array/list using numpy. Parameters: data (np.ndarray): The input array/list. @@ -296,8 +288,7 @@ async def async_weighted_median( def to_datetime(time: float) -> datetime: - """ - Convert a UNIX timestamp to a datetime object in UTC timezone. + """Convert a UNIX timestamp to a datetime object in UTC timezone. Args: time (float): The UNIX timestamp to convert. 
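For reference, a compact weighted median under the cumulative-weight convention, independent of the implementation documented above; tie and interpolation behaviour may differ from qseek's version:

    import numpy as np

    def weighted_median_ref(data: np.ndarray, weights: np.ndarray) -> float:
        order = np.argsort(data)
        data, weights = data[order], weights[order]
        cum = np.cumsum(weights)
        # First sample whose cumulative weight reaches half the total.
        return float(data[np.searchsorted(cum, 0.5 * cum[-1])])

    weighted_median_ref(np.array([1.0, 2.0, 10.0]), np.array([1.0, 1.0, 4.0]))
    # -> 10.0, the median is pulled onto the heavily weighted sample
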
@@ -309,8 +300,7 @@ def to_datetime(time: float) -> datetime: def resample(trace: Trace, sampling_rate: float) -> None: - """ - Downsamples the given trace to the specified sampling rate in-place. + """Downsamples the given trace to the specified sampling rate in-place. Args: trace (Trace): The trace to be downsampled. @@ -361,8 +351,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def human_readable_bytes(size: int | float) -> str: - """ - Convert a size in bytes to a human-readable string representation. + """Convert a size in bytes to a human-readable string representation. Args: size (int | float): The size in bytes. @@ -375,8 +364,7 @@ def human_readable_bytes(size: int | float) -> str: def datetime_now() -> datetime: - """ - Get the current datetime in UTC timezone. + """Get the current datetime in UTC timezone. Returns: datetime: The current datetime in UTC timezone. @@ -385,8 +373,7 @@ def datetime_now() -> datetime: def get_cpu_count() -> int: - """ - Get the number of CPUs available for the current job/task. + """Get the number of CPUs available for the current job/task. The function first checks if the environment variable SLURM_CPUS_PER_TASK is set. If it is set, the value is returned as the number of CPUs. @@ -417,8 +404,7 @@ def filter_clipped_traces( counts_threshold: int = 20, max_bits: tuple[int, ...] = (24, 32), ) -> list[Trace]: - """ - Filters out clipped traces from the given list of traces. + """Filters out clipped traces from the given list of traces. Args: traces (list[Trace]): The list of traces to filter. @@ -455,8 +441,7 @@ def filter_clipped_traces( def camel_case_to_snake_case(name: str) -> str: - """ - Converts a camel case string to snake case. + """Converts a camel case string to snake case. Args: name (str): The camel case string to be converted. @@ -472,8 +457,7 @@ def camel_case_to_snake_case(name: str) -> str: def load_insights() -> None: - """ - Imports the qseek.insights package if available. + """Imports the qseek.insights package if available. This function attempts to import the qseek.insights package and logs a debug message indicating whether the package was successfully imported or not. @@ -503,11 +487,10 @@ class ChannelSelector: normalize: bool = False def get_traces(self, traces_flt: list[Trace]) -> list[Trace]: - """ - Filter and normalize a list of traces based on the specified channels. + """Filter and normalize a list of traces based on the specified channels. Args: - traces (list[Trace]): The list of traces to filter. + traces_flt (list[Trace]): The list of traces to filter. Returns: list[Trace]: The filtered and normalized list of traces. 
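A sketch of the clipping heuristic that the `filter_clipped_traces` docstring above describes, under its documented defaults; the exact full-scale test in qseek may differ:

    import numpy as np

    def looks_clipped(
        ydata: np.ndarray,
        counts_threshold: int = 20,
        max_bits: tuple[int, ...] = (24, 32),
    ) -> bool:
        # Flag a trace when more than `counts_threshold` samples sit at the
        # full-scale count of a 24- or 32-bit digitizer.
        return any(
            np.sum(np.abs(ydata) >= 2 ** (bits - 1) - 1) > counts_threshold
            for bits in max_bits
        )
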
@@ -562,7 +545,7 @@ class ChannelSelectors: def generate_docs(model: BaseModel, exclude: dict | set | None = None) -> str: - """Takes model and dumps markdown for documentation""" + """Takes model and dumps markdown for documentation.""" def generate_submodel(model: BaseModel) -> list[str]: lines = [] diff --git a/test/test_moment_magnitude_store.py b/test/test_moment_magnitude_store.py index 5076ee2e..f1216dbd 100644 --- a/test/test_moment_magnitude_store.py +++ b/test/test_moment_magnitude_store.py @@ -45,17 +45,35 @@ async def test_peak_amplitudes(engine: gf.LocalEngine) -> None: ) PeakAmplitudesStore.set_engine(engine) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - await store.fill_source_depth(source_depth=2 * KM) - await store.get_amplitude( + await store.compute_site_amplitudes(source_depth=2 * KM, reference_magnitude=1.0) + await store.get_amplitude_model( source_depth=2 * KM, distance=10 * KM, n_amplitudes=10, - max_distance=1 * KM, + distance_cutoff=1 * KM, auto_fill=False, interpolation="nearest", ) +@pytest.mark.asyncio +async def test_peak_amplitude_estimation(engine: gf.LocalEngine) -> None: + store_id = "reykjanes_qseis" + peak_amplitudes = PeakAmplitudesBase( + gf_store_id=store_id, + quantity="displacement", + ) + PeakAmplitudesStore.set_engine(engine) + store = PeakAmplitudesStore.from_selector(peak_amplitudes) + await store.compute_site_amplitudes(source_depth=2 * KM, reference_magnitude=1.0) + + await store.find_moment_magnitude( + source_depth=2 * KM, + distance=10 * KM, + observed_amplitude=0.0001, + ) + + @pytest.mark.plot @pytest.mark.asyncio async def test_peak_amplitude_plot(engine: gf.LocalEngine) -> None: @@ -69,17 +87,27 @@ async def test_peak_amplitude_plot(engine: gf.LocalEngine) -> None: PeakAmplitudesStore.set_engine(engine) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - collection = await store.fill_source_depth(source_depth=2 * KM) + collection = await store.compute_site_amplitudes( + source_depth=2 * KM, reference_magnitude=1.0 + ) collection.plot(peak_amplitude=plot_amplitude) + await store.find_moment_magnitude( + source_depth=2 * KM, + distance=10 * KM, + observed_amplitude=0.01, + ) + peak_amplitudes = PeakAmplitudesBase( gf_store_id=store_id, quantity="velocity", ) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - collection = await store.fill_source_depth(source_depth=2 * KM) - collection.plot(peak_amplitude=plot_amplitude) + collection = await store.compute_site_amplitudes( + source_depth=2 * KM, reference_magnitude=2.0 + ) + collection.plot(peak_amplitude=plot_amplitude, reference_magnitude=2.0) peak_amplitudes = PeakAmplitudesBase( gf_store_id=store_id, @@ -87,7 +115,9 @@ async def test_peak_amplitude_plot(engine: gf.LocalEngine) -> None: ) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - collection = await store.fill_source_depth(source_depth=2 * KM) + collection = await store.compute_site_amplitudes( + source_depth=2 * KM, reference_magnitude=1.0 + ) collection.plot(peak_amplitude=plot_amplitude) @@ -116,10 +146,11 @@ async def test_peak_amplitude_surface(engine: gf.LocalEngine) -> None: amplitudes: list[ModelledAmplitude] = [] for dist in distances: amplitudes.append( - await store.get_amplitude( + await store.get_amplitude_model( source_depth=depth, distance=dist, n_amplitudes=25, + reference_magnitude=1.0, peak_amplitude=plot_amplitude, auto_fill=False, ) From 7a6e17574960c1ffcbdcab3cf5aa98466e1211a2 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Fri, 22 Mar 2024 16:04:16 +0000 Subject: 
[PATCH 04/26] update --- src/qseek/apps/qseek.py | 28 +++++++++++++++++++--------- src/qseek/models/detection.py | 10 ++++++++-- src/qseek/pre_processing/module.py | 1 + 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index 72343179..1cc66462 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -12,7 +12,6 @@ from pkg_resources import get_distribution from qseek.models.detection import EventDetection -from qseek.utils import get_cpu_count nest_asyncio.apply() @@ -137,6 +136,12 @@ default=False, help="recalculate all magnitudes", ) +features_extract.add_argument( + "--nparallel", + type=int, + default=32, + help="number of parallel tasks for feature extraction", +) modules = subparsers.add_parser( "modules", @@ -203,7 +208,7 @@ def main() -> None: load_insights() from rich import box - from rich.progress import track + from rich.progress import Progress from rich.prompt import IntPrompt from rich.table import Table @@ -282,18 +287,20 @@ def console_status(task: asyncio.Task[EventDetection]): else: console.print(f"Event {detection.time}: No magnitudes") + progress = Progress() + tracker = progress.add_task( + "Calculating magnitudes", + total=search.catalog.n_events, + console=console, + ) + async def worker() -> None: for magnitude in search.magnitudes: await magnitude.prepare(search.octree, search.stations) await search.catalog.check(repair=True) - sem = asyncio.Semaphore(get_cpu_count()) - for detection in track( - search.catalog, - description="Calculating magnitudes", - total=search.catalog.n_events, - console=console, - ): + sem = asyncio.Semaphore(args.nparallel) + for detection in search.catalog: await sem.acquire() task = asyncio.create_task( search.add_magnitude_and_features( @@ -305,6 +312,9 @@ async def worker() -> None: task.add_done_callback(lambda _: sem.release()) task.add_done_callback(tasks.remove) task.add_done_callback(console_status) + task.add_done_callback( + lambda _: progress.update(tracker, advance=1) + ) await asyncio.gather(*tasks) diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py index bc24b70e..e1d4a5bd 100644 --- a/src/qseek/models/detection.py +++ b/src/qseek/models/detection.py @@ -274,8 +274,7 @@ async def get_waveforms( tmax = max(times).timestamp() + seconds_after nslc_ids = [(*receiver.nsl, "*") for receiver in receivers] async with SQUIRREL_SEM: - traces = await asyncio.to_thread( - squirrel.get_waveforms, + traces = await squirrel.get_waveforms_async( codes=nslc_ids, tmin=tmin, tmax=tmax, @@ -526,6 +525,13 @@ def magnitude(self) -> EventMagnitude | None: """ return self.magnitudes[0] if self.magnitudes else None + async def update(self) -> None: + """Update detection in database. + + Doing this often requires a lot of I/O. + """ + await self.save(update=True) + async def save(self, file: Path | None = None, update: bool = False) -> None: """Dump the detection data to a file. 
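The scheduling pattern used by the magnitude-extraction worker above, reduced to its core: a semaphore caps the number of in-flight tasks and is released from a done-callback, so task submission throttles itself:

    import asyncio

    async def run_bounded(coros, limit: int = 32) -> None:
        sem = asyncio.Semaphore(limit)
        tasks: set[asyncio.Task] = set()
        for coro in coros:
            await sem.acquire()
            task = asyncio.create_task(coro)
            tasks.add(task)
            task.add_done_callback(lambda _task: sem.release())
            task.add_done_callback(tasks.discard)
        await asyncio.gather(*tasks)
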
diff --git a/src/qseek/pre_processing/module.py b/src/qseek/pre_processing/module.py index 689f2f9e..ec1f504e 100644 --- a/src/qseek/pre_processing/module.py +++ b/src/qseek/pre_processing/module.py @@ -103,6 +103,7 @@ async def worker() -> None: start_time = datetime_now() for process in self: batch = await process.process_batch(batch) + await asyncio.sleep(0.0) stats.time_per_batch = datetime_now() - start_time stats.bytes_per_second = ( batch.cumulative_bytes / stats.time_per_batch.total_seconds() From 221396546ce3df2bbaf930ff8129cbd1dd995200 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 25 Mar 2024 13:35:59 +0000 Subject: [PATCH 05/26] bugfixes --- src/qseek/testing.py | 165 ++++++++++++++++++++++++++++ src/qseek/utils.py | 13 ++- test/conftest.py | 165 +--------------------------- test/test_moment_magnitude_store.py | 4 + test/test_utils.py | 25 +++++ 5 files changed, 205 insertions(+), 167 deletions(-) create mode 100644 src/qseek/testing.py create mode 100644 test/test_utils.py diff --git a/src/qseek/testing.py b/src/qseek/testing.py new file mode 100644 index 00000000..b623e254 --- /dev/null +++ b/src/qseek/testing.py @@ -0,0 +1,165 @@ +import asyncio +import random +from datetime import timedelta +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Generator + +import aiohttp +import numpy as np +import pytest +from rich.progress import Progress + +from qseek.models.catalog import EventCatalog +from qseek.models.detection import EventDetection +from qseek.models.location import Location +from qseek.models.station import Station, Stations +from qseek.octree import Octree +from qseek.tracers.cake import EarthModel, Timing, TravelTimeTree +from qseek.utils import Range, datetime_now + +DATA_DIR = Path(__file__).parent / "data" + +DATA_URL = "https://data.pyrocko.org/testing/lassie-v2/" +DATA_FILES = { + "FORGE_3D_5_large.P.mod.hdr", + "FORGE_3D_5_large.P.mod.buf", + "FORGE_3D_5_large.S.mod.hdr", + "FORGE_3D_5_large.S.mod.buf", +} + +KM = 1e3 + + +async def download_test_data() -> None: + request_files = [ + DATA_DIR / filename + for filename in DATA_FILES + if not (DATA_DIR / filename).exists() + ] + + if not request_files: + return + + async with aiohttp.ClientSession() as session: + for file in request_files: + url = DATA_URL + file.name + with Progress() as progress: + async with session.get(url) as response: + task = progress.add_task( + f"Downloading {url}", + total=response.content_length, + ) + with file.open("wb") as f: + while True: + chunk = await response.content.read(1024) + if not chunk: + break + f.write(chunk) + progress.advance(task, len(chunk)) + + +def pytest_addoption(parser) -> None: + parser.addoption("--plot", action="store_true", default=False) + + +@pytest.fixture(scope="session") +def plot(pytestconfig) -> bool: + return pytestconfig.getoption("plot") + + +@pytest.fixture(scope="session") +def travel_time_tree() -> TravelTimeTree: + return TravelTimeTree.new( + earthmodel=EarthModel(), + distance_bounds=(0 * KM, 15 * KM), + receiver_depth_bounds=(0 * KM, 0 * KM), + source_depth_bounds=(0 * KM, 10 * KM), + spatial_tolerance=100, + time_tolerance=0.05, + timing=Timing(definition="P,p"), + ) + + +@pytest.fixture(scope="session") +def data_dir() -> Path: + if not DATA_DIR.exists(): + DATA_DIR.mkdir() + + asyncio.run(download_test_data()) + return DATA_DIR + + +@pytest.fixture(scope="session") +def octree() -> Octree: + return Octree( + location=Location( + lat=10.0, + lon=10.0, + elevation=1.0 * KM, + ), + 
root_node_size=2 * KM, + n_levels=3, + east_bounds=Range(-10 * KM, 10 * KM), + north_bounds=Range(-10 * KM, 10 * KM), + depth_bounds=Range(0 * KM, 10 * KM), + absorbing_boundary=1 * KM, + ) + + +@pytest.fixture(scope="session") +def stations() -> Stations: + n_stations = 20 + stations: list[Station] = [] + for i_sta in range(n_stations): + station = Station( + network="XX", + station="STA%02d" % i_sta, + lat=10.0, + lon=10.0, + elevation=random.uniform(0, 0.8) * KM, + depth=random.uniform(0, 0.2) * KM, + north_shift=random.uniform(-10, 10) * KM, + east_shift=random.uniform(-10, 10) * KM, + ) + stations.append(station) + return Stations(stations=stations) + + +@pytest.fixture(scope="session") +def fixed_stations() -> Stations: + n_stations = 20 + rng = np.random.RandomState(0) + stations: list[Station] = [] + for i_sta in range(n_stations): + station = Station( + network="FX", + station="STA%02d" % i_sta, + lat=10.0, + lon=10.0, + elevation=rng.uniform(0, 1) * KM, + north_shift=rng.uniform(-10, 10) * KM, + east_shift=rng.uniform(-10, 10) * KM, + ) + stations.append(station) + return Stations(stations=stations) + + +@pytest.fixture(scope="session") +def detections() -> Generator[EventCatalog, None, None]: + n_detections = 2000 + detections: list[EventDetection] = [] + for _ in range(n_detections): + time = datetime_now() - timedelta(days=random.uniform(0, 365)) + detection = EventDetection( + lat=10.0, + lon=10.0, + east_shift=random.uniform(-10, 10) * KM, + north_shift=random.uniform(-10, 10) * KM, + distance_border=1000.0, + semblance=random.uniform(0, 1), + time=time, + ) + detections.append(detection) + with TemporaryDirectory() as tmpdir: + yield EventCatalog(rundir=Path(tmpdir), events=detections) diff --git a/src/qseek/utils.py b/src/qseek/utils.py index ec7f9173..8176f1e3 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -26,7 +26,7 @@ ) import numpy as np -from pydantic import AfterValidator, BaseModel, ByteSize, constr +from pydantic import AfterValidator, BaseModel, BeforeValidator, ByteSize, constr from pyrocko.util import UnavailableDecimation from rich.logging import RichHandler @@ -91,7 +91,7 @@ async def wait_all(cls) -> None: await asyncio.gather(*cls.tasks) -class NSL(NamedTuple): +class _NSL(NamedTuple): network: str station: str location: str @@ -116,7 +116,7 @@ def match(self, other: NSL) -> bool: return self.network == other.network @classmethod - def parse(cls, nsl: str) -> NSL: + def parse(cls, nsl: str | NSL) -> NSL: """Parse the given NSL string and return an NSL object. 
Args: @@ -130,6 +130,10 @@ def parse(cls, nsl: str) -> NSL: """ if not nsl: raise ValueError("invalid empty NSL") + if type(nsl) is _NSL: + return nsl + if not isinstance(nsl, str): + raise ValueError(f"invalid NSL {nsl}") parts = nsl.split(".") n_parts = len(parts) if n_parts >= 3: @@ -141,6 +145,9 @@ def parse(cls, nsl: str) -> NSL: raise ValueError(f"invalid NSL {nsl}") +NSL = Annotated[_NSL, BeforeValidator(_NSL.parse)] + + class _Range(NamedTuple): min: float max: float diff --git a/test/conftest.py b/test/conftest.py index d198370d..3fc91ea7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,164 +1 @@ -import asyncio -import random -from datetime import timedelta -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Generator - -import aiohttp -import numpy as np -import pytest -from qseek.models.catalog import EventCatalog -from qseek.models.detection import EventDetection -from qseek.models.location import Location -from qseek.models.station import Station, Stations -from qseek.octree import Octree -from qseek.tracers.cake import EarthModel, Timing, TravelTimeTree -from qseek.utils import Range, datetime_now -from rich.progress import Progress - -DATA_DIR = Path(__file__).parent / "data" - -DATA_URL = "https://data.pyrocko.org/testing/lassie-v2/" -DATA_FILES = { - "FORGE_3D_5_large.P.mod.hdr", - "FORGE_3D_5_large.P.mod.buf", - "FORGE_3D_5_large.S.mod.hdr", - "FORGE_3D_5_large.S.mod.buf", -} - -KM = 1e3 - - -async def download_test_data() -> None: - request_files = [ - DATA_DIR / filename - for filename in DATA_FILES - if not (DATA_DIR / filename).exists() - ] - - if not request_files: - return - - async with aiohttp.ClientSession() as session: - for file in request_files: - url = DATA_URL + file.name - with Progress() as progress: - async with session.get(url) as response: - task = progress.add_task( - f"Downloading {url}", - total=response.content_length, - ) - with file.open("wb") as f: - while True: - chunk = await response.content.read(1024) - if not chunk: - break - f.write(chunk) - progress.advance(task, len(chunk)) - - -def pytest_addoption(parser) -> None: - parser.addoption("--plot", action="store_true", default=False) - - -@pytest.fixture(scope="session") -def plot(pytestconfig) -> bool: - return pytestconfig.getoption("plot") - - -@pytest.fixture(scope="session") -def travel_time_tree() -> TravelTimeTree: - return TravelTimeTree.new( - earthmodel=EarthModel(), - distance_bounds=(0 * KM, 15 * KM), - receiver_depth_bounds=(0 * KM, 0 * KM), - source_depth_bounds=(0 * KM, 10 * KM), - spatial_tolerance=100, - time_tolerance=0.05, - timing=Timing(definition="P,p"), - ) - - -@pytest.fixture(scope="session") -def data_dir() -> Path: - if not DATA_DIR.exists(): - DATA_DIR.mkdir() - - asyncio.run(download_test_data()) - return DATA_DIR - - -@pytest.fixture(scope="session") -def octree() -> Octree: - return Octree( - location=Location( - lat=10.0, - lon=10.0, - elevation=1.0 * KM, - ), - root_node_size=2 * KM, - n_levels=3, - east_bounds=Range(-10 * KM, 10 * KM), - north_bounds=Range(-10 * KM, 10 * KM), - depth_bounds=Range(0 * KM, 10 * KM), - absorbing_boundary=1 * KM, - ) - - -@pytest.fixture(scope="session") -def stations() -> Stations: - n_stations = 20 - stations: list[Station] = [] - for i_sta in range(n_stations): - station = Station( - network="XX", - station="STA%02d" % i_sta, - lat=10.0, - lon=10.0, - elevation=random.uniform(0, 0.8) * KM, - depth=random.uniform(0, 0.2) * KM, - north_shift=random.uniform(-10, 10) * KM, - 
east_shift=random.uniform(-10, 10) * KM, - ) - stations.append(station) - return Stations(stations=stations) - - -@pytest.fixture(scope="session") -def fixed_stations() -> Stations: - n_stations = 20 - rng = np.random.RandomState(0) - stations: list[Station] = [] - for i_sta in range(n_stations): - station = Station( - network="FX", - station="STA%02d" % i_sta, - lat=10.0, - lon=10.0, - elevation=rng.uniform(0, 1) * KM, - north_shift=rng.uniform(-10, 10) * KM, - east_shift=rng.uniform(-10, 10) * KM, - ) - stations.append(station) - return Stations(stations=stations) - - -@pytest.fixture(scope="session") -def detections() -> Generator[EventCatalog, None, None]: - n_detections = 2000 - detections: list[EventDetection] = [] - for _ in range(n_detections): - time = datetime_now() - timedelta(days=random.uniform(0, 365)) - detection = EventDetection( - lat=10.0, - lon=10.0, - east_shift=random.uniform(-10, 10) * KM, - north_shift=random.uniform(-10, 10) * KM, - distance_border=1000.0, - semblance=random.uniform(0, 1), - time=time, - ) - detections.append(detection) - with TemporaryDirectory() as tmpdir: - yield EventCatalog(rundir=Path(tmpdir), events=detections) +pytest_plugins = ["qseek.testing"] diff --git a/test/test_moment_magnitude_store.py b/test/test_moment_magnitude_store.py index f1216dbd..7903c580 100644 --- a/test/test_moment_magnitude_store.py +++ b/test/test_moment_magnitude_store.py @@ -56,6 +56,10 @@ async def test_peak_amplitudes(engine: gf.LocalEngine) -> None: ) +@pytest.mark.skipif( + not has_store("reykjanes_qseis"), + reason="reykjanes_qseis not available", +) @pytest.mark.asyncio async def test_peak_amplitude_estimation(engine: gf.LocalEngine) -> None: store_id = "reykjanes_qseis" diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..30d55623 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel +from qseek.utils import NSL + + +def test_nsl(): + nsl_id = "6E.TE234." 
+ nsl = NSL(*nsl_id.split(".")) + + assert nsl.network == "6E" + assert nsl.station == "TE234" + assert nsl.location == "" + + class Model(BaseModel): + nsl: NSL + nsl_list: list[NSL] + + Model(nsl=nsl, nsl_list=[nsl, nsl, nsl]) + + json = """ + { + "nsl": "6E.TE234.", + "nsl_list": ["6E.TE234.", "6E.TE234.", "6E.TE234."] + } + """ + Model.model_validate_json(json) From b49a00a537b1bb50b7c98d0f4cd0ca7e9e0ba8d8 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 1 Apr 2024 19:27:08 +0000 Subject: [PATCH 06/26] fixes --- src/qseek/search.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/qseek/search.py b/src/qseek/search.py index 3d0f852b..f834710d 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -466,6 +466,7 @@ async def start(self, force_rundir: bool = False) -> None: if self._progress.time_progress: logger.info("continuing search from %s", self._progress.time_progress) + await self._catalog.check(repair=True) await self._catalog.filter_events_by_time( start_time=None, end_time=self._progress.time_progress, From eb2402db4b44c4b8fd043c542179da52d6ad4de3 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Tue, 7 May 2024 20:26:12 +0000 Subject: [PATCH 07/26] adding stations weights --- pyproject.toml | 9 +-- src/qseek/images/phase_net.py | 12 ++-- src/qseek/models/detection.py | 3 +- src/qseek/models/station.py | 5 +- src/qseek/octree.py | 4 +- src/qseek/search.py | 56 +++++++++++++--- src/qseek/station_weights.py | 119 ++++++++++++++++++++++++++++++++++ src/qseek/tracers/cake.py | 10 +-- 8 files changed, 188 insertions(+), 30 deletions(-) create mode 100644 src/qseek/station_weights.py diff --git a/pyproject.toml b/pyproject.toml index a43ccfb3..2a2491b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "rich>=13.4", "nest_asyncio>=1.5", "pyevtk>=1.6", + "psutil>=5.9", "aiofiles>=23.0", ] @@ -62,13 +63,7 @@ classifiers = [ ] [project.optional-dependencies] -dev = [ - "pre-commit>=3.4", - "black>=23.7", - "ruff>=0.1.14", - "pytest>=7.4", - "pytest-asyncio>=0.21", -] +dev = ["pre-commit>=3.4", "ruff>=0.3.0", "pytest>=7.4", "pytest-asyncio>=0.21"] docs = [ "mkdocs-material>=9.5.13", diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index 925552b8..46ccc4c9 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -167,7 +167,7 @@ class PhaseNet(ImageFunction): description="Method to stack the overlaping blocks internally. " "Choose from `avg` and `max`.", ) - upscale_input: PositiveFloat = Field( + rescale_input: PositiveFloat = Field( default=1.0, description="Upscale input by factor. " "This augments the input data from e.g. 100 Hz to 50 Hz (factor: `2`). 
Can be" @@ -219,15 +219,15 @@ def _prepare(self) -> None: def get_blinding(self, sampling_rate: float) -> timedelta: blinding_samples = ( - max(self.phase_net.default_args["blinding"]) / self.upscale_input + max(self.phase_net.default_args["blinding"]) / self.rescale_input ) return timedelta(seconds=blinding_samples / sampling_rate) @alog_call async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: stream = Stream(tr.to_obspy_trace() for tr in traces) - if self.upscale_input > 1: - scale = self.upscale_input + if self.rescale_input > 1: + scale = self.rescale_input for tr in stream: tr.stats.sampling_rate = tr.stats.sampling_rate / scale @@ -239,8 +239,8 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: copy=False, ) - if self.upscale_input > 1: - scale = self.upscale_input + if self.rescale_input > 1: + scale = self.rescale_input for tr in annotations: tr.stats.sampling_rate = tr.stats.sampling_rate * scale blinding_samples = self.phase_net.default_args["blinding"][0] diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py index e1d4a5bd..89526cf4 100644 --- a/src/qseek/models/detection.py +++ b/src/qseek/models/detection.py @@ -274,7 +274,8 @@ async def get_waveforms( tmax = max(times).timestamp() + seconds_after nslc_ids = [(*receiver.nsl, "*") for receiver in receivers] async with SQUIRREL_SEM: - traces = await squirrel.get_waveforms_async( + traces = await asyncio.to_thread( + squirrel.get_waveforms, codes=nslc_ids, tmin=tmin, tmax=tmax, diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index 1bfd7145..f2d8e3e3 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -78,7 +78,8 @@ class Stations(BaseModel): ) station_xmls: list[FilePath | DirectoryPath] = Field( default=[], - description="List of StationXML files.", + description="List of StationXML files or " + "directories containing StationXML (.xml) files.", ) blacklist: set[constr(pattern=NSL_RE)] = Field( @@ -211,6 +212,8 @@ def get_centroid(self) -> Location: ) def get_coordinates(self, system: CoordSystem = "geographic") -> np.ndarray: + if system != "geographic": + raise NotImplementedError("only geographic coordinates are implemented.") return np.array( [(*sta.effective_lat_lon, sta.effective_elevation) for sta in self] ) diff --git a/src/qseek/octree.py b/src/qseek/octree.py index f2a89e6b..7a55c97b 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -10,7 +10,7 @@ from hashlib import sha1 from operator import mul from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Literal, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Literal import numpy as np import scipy.interpolate @@ -39,7 +39,7 @@ def get_node_coordinates( - nodes: Sequence[Node], + nodes: Iterable[Node], system: CoordSystem = "geographic", ) -> np.ndarray: if system == "geographic": diff --git a/src/qseek/search.py b/src/qseek/search.py index f834710d..a6c98086 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -9,8 +9,10 @@ from typing import TYPE_CHECKING, Deque, Literal import numpy as np +import psutil from pydantic import ( BaseModel, + ByteSize, ConfigDict, Field, PositiveFloat, @@ -31,6 +33,7 @@ from qseek.octree import NodeSplitError, Octree from qseek.pre_processing.module import Downsample, PreProcessing from qseek.signals import Signal +from qseek.station_weights import StationWeights from qseek.stats import RuntimeStats, Stats from 
qseek.tracers.tracers import RayTracer, RayTracers from qseek.utils import ( @@ -69,11 +72,16 @@ class SearchStats(Stats): latest_processing_rate: float = 0.0 latest_processing_speed: timedelta = timedelta(seconds=0.0) + memory_total: ByteSize = Field( + default_factory=lambda: ByteSize(psutil.virtual_memory().total) + ) + _search_start: datetime = PrivateAttr(default_factory=datetime_now) _batch_processing_times: Deque[timedelta] = PrivateAttr( default_factory=lambda: deque(maxlen=25) ) _position: int = PrivateAttr(0) + _process: psutil.Process = PrivateAttr(default_factory=psutil.Process) @computed_field @property @@ -119,6 +127,16 @@ def processed_percent(self) -> float: return 0.0 return self.batch_count / self.batch_count_total * 100.0 + @computed_field + @property + def memory_used(self) -> int: + return self._process.memory_info().rss + + @computed_field + @property + def cpu_percent(self) -> float: + return self._process.cpu_percent(interval=None) + def reset_start_time(self) -> None: self._search_start = datetime_now() @@ -162,6 +180,12 @@ def tts(duration: timedelta) -> str: "Project", f"[bold]{self.project_name}[/bold]", ) + table.add_row( + "Resources", + f"CPU {self.cpu_percent:.1f}%, " + f"RAM {human_readable_bytes(self.memory_used)}" + f"/{self.memory_total.human_readable()}", + ) table.add_row( "Progress ", f"[bold]{self.processed_percent:.1f}%[/bold]" @@ -213,6 +237,10 @@ class Search(BaseModel): default=RayTracers(root=[tracer() for tracer in RayTracer.get_subclasses()]), description="List of ray tracers for travel time calculation.", ) + station_weights: StationWeights | None = Field( + default=StationWeights(), + description="Station weights for spatial weighting.", + ) station_corrections: StationCorrectionType | None = Field( default=None, description="Apply station corrections extracted from a previous run.", @@ -434,6 +462,9 @@ async def prepare(self) -> None: self.data_provider.prepare(self.stations) await self.pre_processing.prepare() + if self.station_weights: + self.station_weights.prepare(self.stations, self.octree) + if self.station_corrections: await self.station_corrections.prepare( self.stations, @@ -679,16 +710,26 @@ async def calculate_semblance( traveltimes_bad = np.isnan(traveltimes) traveltimes[traveltimes_bad] = 0.0 - station_contribution = (~traveltimes_bad).sum(axis=1, dtype=np.float32) + # station_contribution = (~traveltimes_bad).sum(axis=1, dtype=np.float32) - shifts = -np.round(traveltimes / image.delta_t).astype(np.int32) - weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) + shifts = np.round(-traveltimes / image.delta_t).astype(np.int32) + # weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) # Normalize by number of station contribution - with np.errstate(divide="ignore", invalid="ignore"): - weights /= station_contribution[:, np.newaxis] - weights /= self.images.cumulative_weight() + # with np.errstate(divide="ignore", invalid="ignore"): + # weights /= station_contribution[:, np.newaxis] + + weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) weights[traveltimes_bad] = 0.0 + if parent.station_weights: + weights *= await parent.station_weights.get_weights(octree, image.stations) + + with np.errstate(divide="ignore", invalid="ignore"): + weights /= weights.sum(axis=1, keepdims=True) + + # applying waterlevel + weights[weights < 1e-3] = 0.0 + if semblance_cache: cache_mask = semblance_cache.get_mask(semblance.node_hashes) weights[cache_mask] = 0.0 @@ -768,8 +809,7 @@ async def 
search( ) # Applying the generalized mean to the semblance - # semblance.normalize( - # images.cumulative_weight(), semblance_cache=semblance_cache) + # semblance.normalize(images.cumulative_weight(), semblance_cache=semblance_cache) if semblance_cache: await semblance.apply_cache(semblance_cache) diff --git a/src/qseek/station_weights.py b/src/qseek/station_weights.py new file mode 100644 index 00000000..3500a252 --- /dev/null +++ b/src/qseek/station_weights.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Iterable, Sequence + +import numpy as np +import pyrocko.orthodrome as od +from lru import LRU +from pydantic import BaseModel, ByteSize, Field, PositiveFloat, PrivateAttr + +from qseek.octree import get_node_coordinates + +if TYPE_CHECKING: + from qseek.models.station import Station, Stations + from qseek.octree import Node, Octree + +MB = 1024**2 + +logger = logging.getLogger(__name__) + + +class StationWeights(BaseModel): + exponent: float = Field( + default=0.5, + description="Exponent of the exponential decay function. Default is 1.5.", + ge=0.0, + le=3.0, + ) + radius_meters: PositiveFloat = Field( + default=8000.0, + description="Radius in meters for the exponential decay function. " + "Default is 8000.", + ) + lut_cache_size: ByteSize = Field( + default=200 * MB, + description="Size of the LRU cache in bytes. Default is 1e9.", + ) + + _node_lut: dict[bytes, np.ndarray] = PrivateAttr() + _cached_stations_indices: dict[str, int] = PrivateAttr() + _station_coords_ecef: np.ndarray = PrivateAttr() + + def get_distances(self, nodes: Iterable[Node]) -> np.ndarray: + node_coords = get_node_coordinates(nodes, system="geographic") + node_coords = np.array(od.geodetic_to_ecef(*node_coords.T)).T + return np.linalg.norm( + self._station_coords_ecef - node_coords[:, np.newaxis], axis=2 + ) + + def calc_weights(self, distances: np.ndarray) -> np.ndarray: + exp = self.exponent + # radius = distances.min(axis=1)[:, np.newaxis] + radius = self.radius_meters + return np.exp(-(distances**exp) / (radius**exp)) + + def prepare(self, stations: Stations, octree: Octree) -> None: + logger.info("preparing station weights") + + bytes_per_node = stations.n_stations * np.float32().itemsize + lru_cache_size = int(self.lut_cache_size / bytes_per_node) + self._node_lut = LRU(size=lru_cache_size) + + sta_coords = stations.get_coordinates(system="geographic") + self._station_coords_ecef = np.array(od.geodetic_to_ecef(*sta_coords.T)).T + self._cached_stations_indices = { + sta.nsl.pretty: idx for idx, sta in enumerate(stations) + } + self.fill_lut(nodes=list(octree)) + + def fill_lut(self, nodes: Sequence[Node]) -> None: + logger.debug("filling weight lut for %d nodes", len(nodes)) + distances = self.get_distances(nodes) + for node, sta_distances in zip(nodes, distances, strict=True): + sta_distances = sta_distances.astype(np.float32) + sta_distances.setflags(write=False) + self._node_lut[node.hash()] = sta_distances + + def get_node_weights(self, node: Node, stations: list[Station]) -> np.ndarray: + try: + distances = self._node_lut[node.hash()] + except KeyError: + self.fill_lut([node]) + return self.get_node_weights(node, stations) + return self.calc_weights(distances) + + def lut_fill_level(self) -> float: + """Return the fill level of the LUT as a float between 0.0 and 1.0.""" + return len(self._node_lut) / self._node_lut.get_size() + + async def get_weights(self, octree: Octree, stations: Stations) -> np.ndarray: + station_indices = np.fromiter( + 
(self._cached_stations_indices[sta.nsl.pretty] for sta in stations), + dtype=int, + ) + distances = np.zeros( + shape=(octree.n_nodes, stations.n_stations), dtype=np.float32 + ) + + fill_nodes = [] + for idx, node in enumerate(octree): + try: + distances[idx] = self._node_lut[node.hash()][station_indices] + except KeyError: + cache_hits, cache_misses = self._node_lut.get_stats() + total_hits = cache_hits + cache_misses + cache_hit_rate = cache_hits / (total_hits or 1) + logger.debug( + "node LUT cache fill level %.1f%%, cache hit rate %.1f%%", + self.lut_fill_level() * 100, + cache_hit_rate * 100, + ) + fill_nodes.append(node) + continue + + if fill_nodes: + self.fill_lut(fill_nodes) + return await self.get_weights(octree, stations) + + return self.calc_weights(distances) diff --git a/src/qseek/tracers/cake.py b/src/qseek/tracers/cake.py index f916ec55..f6e0ee3f 100644 --- a/src/qseek/tracers/cake.py +++ b/src/qseek/tracers/cake.py @@ -11,7 +11,7 @@ from io import BytesIO from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import TYPE_CHECKING, Literal, Sequence +from typing import TYPE_CHECKING, Iterator, Literal, Sequence import matplotlib.pyplot as plt import numpy as np @@ -187,7 +187,7 @@ def id(self) -> str: return re.sub(r"[\,\s\;]", "", self.definition) -def surface_distances(nodes: Sequence[Node], stations: Stations) -> np.ndarray: +def surface_distances(nodes: Iterator[Node], stations: Stations) -> np.ndarray: """Returns the surface distance from all nodes to all stations. Args: @@ -224,7 +224,7 @@ class TravelTimeTree(BaseModel): _file: Path | None = PrivateAttr(None) _cached_stations: Stations = PrivateAttr() - _cached_station_indeces: dict[str, int] = PrivateAttr({}) + _cached_station_indices: dict[str, int] = PrivateAttr({}) _node_lut: dict[bytes, np.ndarray] = PrivateAttr( default_factory=lambda: LRU(LRU_CACHE_SIZE) ) @@ -388,7 +388,7 @@ async def init_lut(self, octree: Octree, stations: Stations) -> None: octree.n_nodes, ) self._cached_stations = stations - self._cached_station_indeces = { + self._cached_station_indices = { sta.nsl.pretty: idx for idx, sta in enumerate(stations) } station_traveltimes = await self.interpolate_travel_times(octree, stations) @@ -417,7 +417,7 @@ async def fill_lut(self, nodes: Sequence[Node]) -> None: async def get_travel_times(self, octree: Octree, stations: Stations) -> np.ndarray: try: station_indices = np.fromiter( - (self._cached_station_indeces[sta.nsl.pretty] for sta in stations), + (self._cached_station_indices[sta.nsl.pretty] for sta in stations), dtype=int, ) except KeyError as exc: From 71be4c37c8cd30b7114ac6ed2c8a34b579e581fa Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Tue, 14 May 2024 10:03:31 +0000 Subject: [PATCH 08/26] bugfixes --- src/qseek/images/base.py | 4 ++++ src/qseek/search.py | 2 ++ src/qseek/utils.py | 4 +++- test/test_utils.py | 2 +- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py index a19621d5..6db88af7 100644 --- a/src/qseek/images/base.py +++ b/src/qseek/images/base.py @@ -58,6 +58,10 @@ class WaveformImage: def sampling_rate(self) -> float: return 1.0 / self.delta_t + @property + def has_traces(self) -> bool: + return bool(self.traces) + @property def delta_t(self) -> float: return self.traces[0].deltat diff --git a/src/qseek/search.py b/src/qseek/search.py index a6c98086..4381d23d 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -800,6 +800,8 @@ async def search( ) for image in 
images: + if not image.has_traces: + continue await self.calculate_semblance( octree=octree, image=image, diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 8176f1e3..59502da1 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -116,7 +116,7 @@ def match(self, other: NSL) -> bool: return self.network == other.network @classmethod - def parse(cls, nsl: str | NSL) -> NSL: + def parse(cls, nsl: str | NSL | list[str]) -> NSL: """Parse the given NSL string and return an NSL object. Args: @@ -132,6 +132,8 @@ def parse(cls, nsl: str | NSL) -> NSL: raise ValueError("invalid empty NSL") if type(nsl) is _NSL: return nsl + if isinstance(nsl, list): + return cls(*nsl) if not isinstance(nsl, str): raise ValueError(f"invalid NSL {nsl}") parts = nsl.split(".") diff --git a/test/test_utils.py b/test/test_utils.py index 30d55623..1adff0ec 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,7 +19,7 @@ class Model(BaseModel): json = """ { "nsl": "6E.TE234.", - "nsl_list": ["6E.TE234.", "6E.TE234.", "6E.TE234."] + "nsl_list": ["6E.TE234.", "6E.TE234.", "6E.TE234.", ["6E", "TY123", ""]] } """ Model.model_validate_json(json) From 5ff007aa106522108e1710ba29a413f207711b7d Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Sun, 19 May 2024 09:44:55 +0000 Subject: [PATCH 09/26] utils: better nsl handling --- .pre-commit-config.yaml | 4 ++-- README.md | 9 ++++++--- src/qseek/utils.py | 39 +++++++++++++++++++++++++++++++++------ test/test_utils.py | 22 ++++++++++++++++++++++ 4 files changed, 63 insertions(+), 11 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 883406b7..9a64928e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -10,7 +10,7 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.2 + rev: v0.4.4 hooks: - id: ruff - id: ruff-format diff --git a/README.md b/README.md index 63abac6b..d9c7ae19 100644 --- a/README.md +++ b/README.md @@ -20,10 +20,13 @@ Key features are of the earthquake detection and localisation framework are: * 1D Layered velocity model * 3D fast-marching velocity model (NonLinLoc compatible) * Extraction of earthquake event features: - * Local magnitudes - * Ground motion attributes + * Moment Magnitudes (MW) based on modelled peak ground motions + * Local magnitudes (ML), different models + * Ground motion attributes (e.g. PGA, PGV, ...) * Automatic extraction of modelled and picked travel times -* Calculation and application of station corrections / station delay times +* Station Corrections + * station specific corrections (SST) + * source specific station corrections (SSST) Qseek is built on top of [Pyrocko](https://pyrocko.org). diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 59502da1..b5f25642 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -116,7 +116,7 @@ def match(self, other: NSL) -> bool: return self.network == other.network @classmethod - def parse(cls, nsl: str | NSL | list[str]) -> NSL: + def parse(cls, nsl: str | NSL | list[str] | tuple[str, str, str]) -> NSL: """Parse the given NSL string and return an NSL object. 
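        Accepted inputs are dotted strings as well as 3-element lists or
        tuples. Illustrative calls, mirroring the codes used in the tests::

            NSL.parse("6E.TE234.00")        # network 6E, station TE234, location 00
            NSL.parse("6E.TE234")           # missing location defaults to ""
            NSL.parse(("6E", "TE234", ""))  # sequences are passed through
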
        Args:
@@ -132,22 +132,49 @@ def parse(cls, nsl: str | NSL | list[str]) -> NSL:
             raise ValueError("invalid empty NSL")
         if type(nsl) is _NSL:
             return nsl
-        if isinstance(nsl, list):
+        if isinstance(nsl, (list, tuple)):
             return cls(*nsl)
         if not isinstance(nsl, str):
             raise ValueError(f"invalid NSL {nsl}")
+
         parts = nsl.split(".")
         n_parts = len(parts)
         if n_parts >= 3:
             return cls(*parts[:3])
         if n_parts == 2:
             return cls(parts[0], parts[1], "")
-        if n_parts == 1:
-            return cls(parts[0], "", "")
-        raise ValueError(f"invalid NSL {nsl}")
+        raise ValueError(
+            f"invalid NSL `{nsl}`, expecting `<net>.<sta>.<loc>`, "
+            "e.g. `6A.STA130.00`, `6A.STA130` or `.STA130`"
+        )
+
+    def _check(self) -> None:
+        """Validate the network, station and location codes.
+
+        Raises:
+            ValueError: If the network code is longer than 2
+                characters, the station code is not 1 to 5
+                characters long, or the location code is longer
+                than 2 characters.
+        """
+        if len(self.network) > 2:
+            raise ValueError(
+                f"invalid network {self.network} for {self.pretty},"
+                " expected 0-2 characters for network code"
+            )
+        if len(self.station) > 5 or len(self.station) < 1:
+            raise ValueError(
+                f"invalid station {self.station} for {self.pretty},"
+                " expected 1-5 characters for station code"
+            )
+        if len(self.location) > 2:
+            raise ValueError(
+                f"invalid location {self.location} for {self.pretty},"
+                " expected 0-2 characters for location code"
+            )
 
 
-NSL = Annotated[_NSL, BeforeValidator(_NSL.parse)]
+NSL = Annotated[_NSL, BeforeValidator(_NSL.parse), AfterValidator(_NSL._check)]
diff --git a/test/test_utils.py b/test/test_utils.py
index 1adff0ec..130bbbfd 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,3 +1,4 @@
+import pytest
 from pydantic import BaseModel
 
 from qseek.utils import NSL
@@ -23,3 +24,24 @@ class Model(BaseModel):
     }
     """
     Model.model_validate_json(json)
+
+    json = """
+    {
+        "nsl": "6E.TE234.",
+        "nsl_list": [".TE232"]
+    }
+    """
+    Model.model_validate_json(json)
+
+    json_tpl = """
+    {{
+        "nsl": "{code}",
+        "nsl_list": []
+    }}
+    """
+
+    invalid_codes = ["6E", "6E5.", "6E.", "6E.TE123112"]
+
+    for code in invalid_codes:
+        with pytest.raises(ValueError):
+            Model.model_validate_json(json_tpl.format(code=code))

From 3ba1cd6103ba3318004f773c85840d4d474a4195 Mon Sep 17 00:00:00 2001
From: Marius Isken
Date: Fri, 24 May 2024 13:55:53 +0000
Subject: [PATCH 10/26] fixes to squirrel persistence // adding local magnitude Argentina

---
 src/qseek/magnitudes/local_magnitude_model.py | 11 +++++++
 src/qseek/octree.py                           |  2 +-
 src/qseek/pre_processing/downsample.py        |  2 +-
 src/qseek/search.py                           | 11 +++----
 src/qseek/waveforms/squirrel.py               | 29 +++++++++++--------
 5 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/src/qseek/magnitudes/local_magnitude_model.py b/src/qseek/magnitudes/local_magnitude_model.py
index db30b10b..ed368dda 100644
--- a/src/qseek/magnitudes/local_magnitude_model.py
+++ b/src/qseek/magnitudes/local_magnitude_model.py
@@ -378,4 +378,15 @@ def get_amp_attenuation(dist_hypo_km: float, dist_epi_km: float) -> float:
         return 0.89 * np.log10(dist_epi_km / 100) + 0.00256 * (dist_epi_km - 100) + 3
 
 
+class ArgentinaVolcanoes(WoodAnderson, LocalMagnitudeModel):
+    author = "Montenegro et al. 
(2021)" + + epicentral_range = Range(0.0 * KM, 100.0 * KM) # Bounds are not clear + component = "north-east-separate" + + @staticmethod + def get_amp_attenuation(dist_hypo_km: float, dist_epi_km: float) -> float: + return 2.76 * np.log10(dist_epi_km) - 2.48 + + ModelName = Literal[LocalMagnitudeModel.model_names()] diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 7a55c97b..2347bcc0 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -306,7 +306,7 @@ class Octree(BaseModel): description="The reference location of the octree.", ) root_node_size: PositiveFloat = Field( - default=2 * KM, + default=1 * KM, description="Initial size of the root octree node at level 0 in meters.", ) n_levels: int = Field( diff --git a/src/qseek/pre_processing/downsample.py b/src/qseek/pre_processing/downsample.py index b9454dec..14ced59b 100644 --- a/src/qseek/pre_processing/downsample.py +++ b/src/qseek/pre_processing/downsample.py @@ -16,7 +16,7 @@ class Downsample(BatchPreProcessing): process: Literal["downsample"] = "downsample" sampling_frequency: PositiveFloat = Field( - 100.0, + default=100.0, description="The new sampling frequency in Hz.", ) diff --git a/src/qseek/search.py b/src/qseek/search.py index 4381d23d..1cef0d03 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -31,6 +31,7 @@ from qseek.models.detection_uncertainty import DetectionUncertainty from qseek.models.semblance import Semblance, SemblanceCache from qseek.octree import NodeSplitError, Octree +from qseek.pre_processing.frequency_filters import Bandpass from qseek.pre_processing.module import Downsample, PreProcessing from qseek.signals import Signal from qseek.station_weights import StationWeights @@ -215,11 +216,11 @@ class Search(BaseModel): description="Station inventory from StationXML or Pyrocko Station YAML.", ) data_provider: WaveformProviderType = Field( - default=PyrockoSquirrel(), + default_factory=PyrockoSquirrel.model_construct, description="Data provider for waveform data.", ) pre_processing: PreProcessing = Field( - default=PreProcessing(root=[Downsample(sampling_frequency=100.0)]), + default=PreProcessing(root=[Downsample(), Bandpass()]), description="Pre-processing steps for waveform data.", ) @@ -379,7 +380,7 @@ def write_config(self, path: Path | None = None) -> None: logger.debug("writing search config to %s", path) path.write_text(self.model_dump_json(indent=2, exclude_unset=True)) - logger.debug("dumping stations...") + logger.debug("dumping stations") self.stations.export_pyrocko_stations(rundir / "pyrocko_stations.yaml") csv_dir = rundir / "csv" @@ -458,7 +459,7 @@ async def prepare(self) -> None: Returns: None """ - logger.info("preparing search...") + logger.info("preparing search components") self.data_provider.prepare(self.stations) await self.pre_processing.prepare() @@ -489,7 +490,7 @@ async def start(self, force_rundir: bool = False) -> None: await self.prepare() - logger.info("starting search...") + logger.info("starting search") stats = self._stats stats.reset_start_time() diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py index ebda8a7b..73ff248a 100644 --- a/src/qseek/waveforms/squirrel.py +++ b/src/qseek/waveforms/squirrel.py @@ -14,7 +14,6 @@ PositiveInt, PrivateAttr, computed_field, - field_validator, model_validator, ) from pyrocko.squirrel import Squirrel @@ -113,10 +112,14 @@ class PyrockoSquirrel(WaveformProvider): provider: Literal["PyrockoSquirrel"] = "PyrockoSquirrel" - environment: DirectoryPath = Field( - default=Path("."), + 
environment: DirectoryPath | None = Field( + default=None, description="Path to a Squirrel environment.", ) + persistent: str | None = Field( + default=None, + description="Name of the persistent collection for faster loading.", + ) waveform_dirs: list[Path] = Field( default=[], description="List of directories holding the waveform files.", @@ -151,18 +154,20 @@ class PyrockoSquirrel(WaveformProvider): def _validate_model(self) -> Self: if self.start_time and self.end_time and self.start_time > self.end_time: raise ValueError("start_time must be before end_time") + if not self.waveform_dirs and not self.persistent: + raise ValueError("no waveform directories or persistent selection provided") return self - @field_validator("waveform_dirs") - def check_dirs(cls, dirs: list[Path]) -> list[Path]: # noqa: N805 - if not dirs: - raise ValueError("no waveform directories provided!") - return dirs - def get_squirrel(self) -> Squirrel: if not self._squirrel: - logger.info("initializing squirrel waveform provider") - squirrel = Squirrel(str(self.environment.expanduser())) + logger.info( + "initializing squirrel waveform provider in environment %s", + self.environment, + ) + squirrel = Squirrel( + env=str(self.environment.expanduser()) if self.environment else None, + persistent=self.persistent, + ) paths = [] for path in self.waveform_dirs: if "**" in str(path): @@ -173,7 +178,7 @@ def get_squirrel(self) -> Squirrel: squirrel.add(paths, check=False) if self._stations: for path in self._stations.station_xmls: - logger.info("loading responses from %s", path) + logger.info("loading StationXML responses from %s", path) squirrel.add(str(path), check=False) self._squirrel = squirrel return self._squirrel From f06c178e41893435c6b8f8399b9c43f70345ea85 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 27 May 2024 15:10:17 +0000 Subject: [PATCH 11/26] upd --- src/qseek/apps/qseek.py | 4 ++- src/qseek/images/phase_net.py | 1 + src/qseek/models/station.py | 1 + src/qseek/search.py | 24 ++++++++------- ...{station_weights.py => spatial_weights.py} | 30 +++++++++++++------ src/qseek/tracers/cake.py | 17 +++++++---- src/qseek/utils.py | 6 ++-- 7 files changed, 54 insertions(+), 29 deletions(-) rename src/qseek/{station_weights.py => spatial_weights.py} (85%) diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index 1cc66462..55561abb 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -7,11 +7,13 @@ import logging import shutil from pathlib import Path +from typing import TYPE_CHECKING import nest_asyncio from pkg_resources import get_distribution -from qseek.models.detection import EventDetection +if TYPE_CHECKING: + from qseek.models.detection import EventDetection nest_asyncio.apply() diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index 46ccc4c9..c658c3b1 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -68,6 +68,7 @@ def search_phase_arrival( Returns: datetime | None: Time of arrival, None is none found. 
""" + # TODO adapt threshold to the seisbench model trace = self.traces[trace_idx] window_length = timedelta(seconds=search_window_seconds) try: diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index f2d8e3e3..2b9dc6e0 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -165,6 +165,7 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: raise ValueError("no stations available, add waveforms to start detection") def __iter__(self) -> Iterator[Station]: + # TODO: this is inefficient return (sta for sta in self.stations if sta.nsl.pretty not in self.blacklist) @property diff --git a/src/qseek/search.py b/src/qseek/search.py index 1cef0d03..2b5e29de 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -34,7 +34,7 @@ from qseek.pre_processing.frequency_filters import Bandpass from qseek.pre_processing.module import Downsample, PreProcessing from qseek.signals import Signal -from qseek.station_weights import StationWeights +from qseek.spatial_weights import SpatialWeights from qseek.stats import RuntimeStats, Stats from qseek.tracers.tracers import RayTracer, RayTracers from qseek.utils import ( @@ -184,8 +184,8 @@ def tts(duration: timedelta) -> str: table.add_row( "Resources", f"CPU {self.cpu_percent:.1f}%, " - f"RAM {human_readable_bytes(self.memory_used)}" - f"/{self.memory_total.human_readable()}", + f"RAM {human_readable_bytes(self.memory_used, decimal=True)}" + f"/{self.memory_total.human_readable(decimal=True)}", ) table.add_row( "Progress ", @@ -238,9 +238,9 @@ class Search(BaseModel): default=RayTracers(root=[tracer() for tracer in RayTracer.get_subclasses()]), description="List of ray tracers for travel time calculation.", ) - station_weights: StationWeights | None = Field( - default=StationWeights(), - description="Station weights for spatial weighting.", + spatial_weights: SpatialWeights | None = Field( + default=SpatialWeights(), + description="Spatial weights for distance weighting.", ) station_corrections: StationCorrectionType | None = Field( default=None, @@ -463,8 +463,8 @@ async def prepare(self) -> None: self.data_provider.prepare(self.stations) await self.pre_processing.prepare() - if self.station_weights: - self.station_weights.prepare(self.stations, self.octree) + if self.spatial_weights: + self.spatial_weights.prepare(self.stations, self.octree) if self.station_corrections: await self.station_corrections.prepare( @@ -722,8 +722,8 @@ async def calculate_semblance( weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) weights[traveltimes_bad] = 0.0 - if parent.station_weights: - weights *= await parent.station_weights.get_weights(octree, image.stations) + if parent.spatial_weights: + weights *= await parent.spatial_weights.get_weights(octree, image.stations) with np.errstate(divide="ignore", invalid="ignore"): weights /= weights.sum(axis=1, keepdims=True) @@ -861,7 +861,9 @@ async def search( except NodeSplitError: continue logger.info( - "energy detected, refined %d nodes, level %d", + "detected %d energy burst%s - refined %d nodes, lowest level %d", + detection_idx.size, + "s" if detection_idx.size > 1 else "", len(refine_nodes), new_level, ) diff --git a/src/qseek/station_weights.py b/src/qseek/spatial_weights.py similarity index 85% rename from src/qseek/station_weights.py rename to src/qseek/spatial_weights.py index 3500a252..bb241230 100644 --- a/src/qseek/station_weights.py +++ b/src/qseek/spatial_weights.py @@ -19,21 +19,26 @@ logger = logging.getLogger(__name__) -class 
StationWeights(BaseModel): +class SpatialWeights(BaseModel): exponent: float = Field( - default=0.5, - description="Exponent of the exponential decay function. Default is 1.5.", + default=3.0, + description="Exponent of the spatial decay function. Default is 3.", ge=0.0, - le=3.0, ) radius_meters: PositiveFloat = Field( default=8000.0, - description="Radius in meters for the exponential decay function. " - "Default is 8000.", + description="Cutoff distance for the spatial decay function in meters." + " Default is 8000.", + ) + waterlevel: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="Waterlevel for the exponential decay function. Default is 0.0.", ) lut_cache_size: ByteSize = Field( default=200 * MB, - description="Size of the LRU cache in bytes. Default is 1e9.", + description="Size of the LRU cache in bytes. Default is 200 MB.", ) _node_lut: dict[bytes, np.ndarray] = PrivateAttr() @@ -47,14 +52,21 @@ def get_distances(self, nodes: Iterable[Node]) -> np.ndarray: self._station_coords_ecef - node_coords[:, np.newaxis], axis=2 ) - def calc_weights(self, distances: np.ndarray) -> np.ndarray: + def calc_weights_exp(self, distances: np.ndarray) -> np.ndarray: exp = self.exponent # radius = distances.min(axis=1)[:, np.newaxis] radius = self.radius_meters return np.exp(-(distances**exp) / (radius**exp)) + def calc_weights(self, distances: np.ndarray) -> np.ndarray: + exp = self.exponent + radius = self.radius_meters + return (1 - self.waterlevel) / ( + 1 + (distances / radius) ** exp + ) + self.waterlevel + def prepare(self, stations: Stations, octree: Octree) -> None: - logger.info("preparing station weights") + logger.info("preparing spatial weights") bytes_per_node = stations.n_stations * np.float32().itemsize lru_cache_size = int(self.lut_cache_size / bytes_per_node) diff --git a/src/qseek/tracers/cake.py b/src/qseek/tracers/cake.py index f6e0ee3f..1c7594aa 100644 --- a/src/qseek/tracers/cake.py +++ b/src/qseek/tracers/cake.py @@ -648,9 +648,10 @@ def get_travel_time_location( source: Location, receiver: Location, ) -> float: - if phase not in self.phases: - raise ValueError(f"Phase {phase} is not defined.") - tree = self._get_sptree_model(phase) + try: + tree = self._get_sptree_model(phase) + except KeyError as exc: + raise ValueError(f"Phase {phase} is not defined.") from exc return tree.get_travel_time(source, receiver) async def get_travel_times( @@ -659,9 +660,13 @@ async def get_travel_times( octree: Octree, stations: Stations, ) -> np.ndarray: - if phase not in self.phases: - raise ValueError(f"Phase {phase} is not defined.") - return await self._get_sptree_model(phase).get_travel_times(octree, stations) + try: + return await self._get_sptree_model(phase).get_travel_times( + octree, + stations, + ) + except KeyError as exc: + raise ValueError(f"Phase {phase} is not defined.") from exc def get_arrivals( self, diff --git a/src/qseek/utils.py b/src/qseek/utils.py index b5f25642..ed40411a 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -386,17 +386,19 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: return wrapper -def human_readable_bytes(size: int | float) -> str: +def human_readable_bytes(size: int | float, decimal: bool = False) -> str: """Convert a size in bytes to a human-readable string representation. Args: size (int | float): The size in bytes. + decimal: If True, use decimal units (e.g. 1000 bytes per KB). + If False, use binary units (e.g. 1024 bytes per KiB). Returns: str: The human-readable string representation of the size. 
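
    Example:
        Illustrative outputs, assuming pydantic's ``ByteSize.human_readable``
        formatting::

            human_readable_bytes(10_000_000)                # e.g. '9.5MiB'
            human_readable_bytes(10_000_000, decimal=True)  # e.g. '10.0MB'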
""" - return ByteSize(size).human_readable() + return ByteSize.human_readable(size, decimal=decimal) def datetime_now() -> datetime: From d2fc94701d78c9a67a693fb20aeb8a17a8c45325 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Fri, 31 May 2024 13:41:35 +0000 Subject: [PATCH 12/26] renaming distance weights --- ...spatial_weights.py => distance_weights.py} | 18 +++++++-------- src/qseek/search.py | 15 ++++++------ src/qseek/waveforms/squirrel.py | 23 ++++++++++--------- 3 files changed, 29 insertions(+), 27 deletions(-) rename src/qseek/{spatial_weights.py => distance_weights.py} (90%) diff --git a/src/qseek/spatial_weights.py b/src/qseek/distance_weights.py similarity index 90% rename from src/qseek/spatial_weights.py rename to src/qseek/distance_weights.py index bb241230..93cfd9d3 100644 --- a/src/qseek/spatial_weights.py +++ b/src/qseek/distance_weights.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -class SpatialWeights(BaseModel): +class DistanceWeights(BaseModel): exponent: float = Field( default=3.0, description="Exponent of the spatial decay function. Default is 3.", @@ -113,19 +113,19 @@ async def get_weights(self, octree: Octree, stations: Stations) -> np.ndarray: try: distances[idx] = self._node_lut[node.hash()][station_indices] except KeyError: - cache_hits, cache_misses = self._node_lut.get_stats() - total_hits = cache_hits + cache_misses - cache_hit_rate = cache_hits / (total_hits or 1) - logger.debug( - "node LUT cache fill level %.1f%%, cache hit rate %.1f%%", - self.lut_fill_level() * 100, - cache_hit_rate * 100, - ) fill_nodes.append(node) continue if fill_nodes: self.fill_lut(fill_nodes) + cache_hits, cache_misses = self._node_lut.get_stats() + total_hits = cache_hits + cache_misses + cache_hit_rate = cache_hits / (total_hits or 1) + logger.debug( + "node LUT cache fill level %.1f%%, cache hit rate %.1f%%", + self.lut_fill_level() * 100, + cache_hit_rate * 100, + ) return await self.get_weights(octree, stations) return self.calc_weights(distances) diff --git a/src/qseek/search.py b/src/qseek/search.py index 2b5e29de..386aac32 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -22,6 +22,7 @@ ) from qseek.corrections.corrections import StationCorrectionType +from qseek.distance_weights import DistanceWeights from qseek.features import FeatureExtractorType from qseek.images.images import ImageFunctions, WaveformImages from qseek.magnitudes import EventMagnitudeCalculatorType @@ -34,7 +35,6 @@ from qseek.pre_processing.frequency_filters import Bandpass from qseek.pre_processing.module import Downsample, PreProcessing from qseek.signals import Signal -from qseek.spatial_weights import SpatialWeights from qseek.stats import RuntimeStats, Stats from qseek.tracers.tracers import RayTracer, RayTracers from qseek.utils import ( @@ -238,8 +238,9 @@ class Search(BaseModel): default=RayTracers(root=[tracer() for tracer in RayTracer.get_subclasses()]), description="List of ray tracers for travel time calculation.", ) - spatial_weights: SpatialWeights | None = Field( - default=SpatialWeights(), + distance_weights: DistanceWeights | None = Field( + default=DistanceWeights(), + alias="spatial_weights", description="Spatial weights for distance weighting.", ) station_corrections: StationCorrectionType | None = Field( @@ -463,8 +464,8 @@ async def prepare(self) -> None: self.data_provider.prepare(self.stations) await self.pre_processing.prepare() - if self.spatial_weights: - self.spatial_weights.prepare(self.stations, self.octree) + if self.distance_weights: + 
self.distance_weights.prepare(self.stations, self.octree) if self.station_corrections: await self.station_corrections.prepare( @@ -722,8 +723,8 @@ async def calculate_semblance( weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) weights[traveltimes_bad] = 0.0 - if parent.spatial_weights: - weights *= await parent.spatial_weights.get_weights(octree, image.stations) + if parent.distance_weights: + weights *= await parent.distance_weights.get_weights(octree, image.stations) with np.errstate(divide="ignore", invalid="ignore"): weights /= weights.sum(axis=1, keepdims=True) diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py index 73ff248a..739a851c 100644 --- a/src/qseek/waveforms/squirrel.py +++ b/src/qseek/waveforms/squirrel.py @@ -14,6 +14,7 @@ PositiveInt, PrivateAttr, computed_field, + constr, model_validator, ) from pyrocko.squirrel import Squirrel @@ -53,7 +54,7 @@ def __init__( async def prefetch_worker(self) -> None: logger.info( - "start prefetching data, queue size %d", + "start prefetching waveforms, queue size %d", self.queue.maxsize, ) @@ -70,7 +71,7 @@ async def load_data() -> None | Batch: await self.queue.put(batch) await asyncio.create_task(load_data()) - logger.debug("loading waveform batches to finish") + logger.debug("done loading waveforms") class SquirrelStats(Stats): @@ -135,11 +136,12 @@ class PyrockoSquirrel(WaveformProvider): "[ISO8601](https://en.wikipedia.org/wiki/ISO_8601).", ) - channel_selector: str = Field( - default="*", - max_length=3, - description="Channel selector for waveforms, " - "use e.g. `EN?` for selection of all accelerometer data.", + channel_selector: list[constr(to_upper=True, max_length=2, min_length=2)] | None = ( + Field( + default=None, + description="Channel selector for waveforms, " + "use e.g. 
`EN` for selection of all accelerometer data.", + ) ) async_prefetch_batches: PositiveInt = Field( default=10, @@ -220,12 +222,11 @@ async def iter_batches( tinc=window_increment.total_seconds(), tpad=window_padding.total_seconds(), want_incomplete=False, - codes=[ - (*nsl, self.channel_selector) for nsl in self._stations.get_all_nsl() - ], + codes=[(*nsl, "*") for nsl in self._stations.get_all_nsl()], ) prefetcher = SquirrelPrefetcher( - iterator, queue_size=self.async_prefetch_batches + iterator, + queue_size=self.async_prefetch_batches, ) stats.set_queue(prefetcher.queue) From cd609d56eb2db443a6f06f2518c2cf537938fd0c Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Fri, 31 May 2024 13:54:30 +0000 Subject: [PATCH 13/26] bugfixes --- src/qseek/pre_processing/downsample.py | 9 ++++++++- src/qseek/pre_processing/frequency_filters.py | 20 +++++++++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/qseek/pre_processing/downsample.py b/src/qseek/pre_processing/downsample.py index 14ced59b..b1098e2f 100644 --- a/src/qseek/pre_processing/downsample.py +++ b/src/qseek/pre_processing/downsample.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging from typing import TYPE_CHECKING, Literal from pydantic import Field, PositiveFloat @@ -10,6 +11,8 @@ if TYPE_CHECKING: from qseek.waveforms.base import WaveformBatch +logger = logging.getLogger(__name__) + class Downsample(BatchPreProcessing): """Downsample the traces to a new sampling frequency.""" @@ -26,7 +29,11 @@ async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: def worker() -> None: for trace in self.select_traces(batch): if trace.deltat < desired_deltat: - trace.downsample_to(deltat=desired_deltat, allow_upsample_max=5) + try: + trace.downsample_to(deltat=desired_deltat, allow_upsample_max=5) + except Exception as e: + logger.exception("Failed to downsample trace: %s", e) + ... await asyncio.to_thread(worker) return batch diff --git a/src/qseek/pre_processing/frequency_filters.py b/src/qseek/pre_processing/frequency_filters.py index e6eb2df8..3d526184 100644 --- a/src/qseek/pre_processing/frequency_filters.py +++ b/src/qseek/pre_processing/frequency_filters.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging from typing import TYPE_CHECKING, Literal from pydantic import Field, PositiveFloat, field_validator @@ -12,6 +13,9 @@ from qseek.waveforms.base import WaveformBatch +logger = logging.getLogger(__name__) + + class Bandpass(BatchPreProcessing): """Apply a bandpass filter to the traces.""" @@ -43,12 +47,16 @@ def _check_bandpass(cls, value) -> Range: async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: def worker() -> None: for trace in self.select_traces(batch): - trace.bandpass( - order=self.corners, - corner_hp=self.bandpass[0], - corner_lp=self.bandpass[1], - demean=self.demean, - ) + try: + trace.bandpass( + order=self.corners, + corner_hp=self.bandpass[0], + corner_lp=self.bandpass[1], + demean=self.demean, + ) + except Exception as e: + logger.exception("Failed to apply bandpass filter: %s", e) + ... 
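        # Note: the bare `...` above is a harmless no-op placeholder. Failures
        # are only logged, so a single bad trace cannot abort the whole batch,
        # and the blocking Pyrocko filtering in `worker` is handed off to a
        # worker thread below to keep the event loop responsive.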
await asyncio.to_thread(worker) return batch From 4884298b0acca5b4a62723995ad33cd43a31f226 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Fri, 31 May 2024 13:55:18 +0000 Subject: [PATCH 14/26] bugfixes --- src/qseek/waveforms/squirrel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py index 739a851c..48ee2172 100644 --- a/src/qseek/waveforms/squirrel.py +++ b/src/qseek/waveforms/squirrel.py @@ -140,7 +140,7 @@ class PyrockoSquirrel(WaveformProvider): Field( default=None, description="Channel selector for waveforms, " - "use e.g. `EN` for selection of all accelerometer data.", + "use e.g. `['EN']` for selection of all accelerometer data.", ) ) async_prefetch_batches: PositiveInt = Field( @@ -223,6 +223,7 @@ async def iter_batches( tpad=window_padding.total_seconds(), want_incomplete=False, codes=[(*nsl, "*") for nsl in self._stations.get_all_nsl()], + channel_priorities=self.channel_selector, ) prefetcher = SquirrelPrefetcher( iterator, From a138aea00b9eec4ec73419e6a2227553c9e3d3a4 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Thu, 13 Jun 2024 10:27:56 +0000 Subject: [PATCH 15/26] fixes and improvements --- src/qseek/distance_weights.py | 4 +- src/qseek/images/base.py | 9 +- src/qseek/images/phase_net.py | 31 ++++-- src/qseek/models/catalog.py | 3 + src/qseek/models/detection.py | 6 +- src/qseek/octree.py | 27 ++++- src/qseek/pre_processing/base.py | 11 ++ src/qseek/pre_processing/downsample.py | 63 +++++++++-- src/qseek/pre_processing/frequency_filters.py | 104 ++++++++++++++---- src/qseek/search.py | 35 ++++-- src/qseek/tracers/base.py | 4 + src/qseek/waveforms/squirrel.py | 14 ++- 12 files changed, 244 insertions(+), 67 deletions(-) diff --git a/src/qseek/distance_weights.py b/src/qseek/distance_weights.py index 93cfd9d3..90eba4ea 100644 --- a/src/qseek/distance_weights.py +++ b/src/qseek/distance_weights.py @@ -66,7 +66,7 @@ def calc_weights(self, distances: np.ndarray) -> np.ndarray: ) + self.waterlevel def prepare(self, stations: Stations, octree: Octree) -> None: - logger.info("preparing spatial weights") + logger.info("preparing distance weights") bytes_per_node = stations.n_stations * np.float32().itemsize lru_cache_size = int(self.lut_cache_size / bytes_per_node) @@ -80,7 +80,7 @@ def prepare(self, stations: Stations, octree: Octree) -> None: self.fill_lut(nodes=list(octree)) def fill_lut(self, nodes: Sequence[Node]) -> None: - logger.debug("filling weight lut for %d nodes", len(nodes)) + logger.debug("filling distance weight LUT for %d nodes", len(nodes)) distances = self.get_distances(nodes) for node, sta_distances in zip(nodes, distances, strict=True): sta_distances = sta_distances.astype(np.float32) diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py index 6db88af7..49c1d06e 100644 --- a/src/qseek/images/base.py +++ b/src/qseek/images/base.py @@ -43,7 +43,13 @@ def get_blinding(self, sampling_rate: float) -> timedelta: """ raise NotImplementedError("must be implemented by subclass") - def get_provided_phases(self) -> tuple[PhaseDescription, ...]: ... + def get_provided_phases(self) -> tuple[PhaseDescription, ...]: + """Get the phases provided by the image function. + + Returns: + tuple[PhaseDescription, ...]: The phases provided by the image function. 
+ """ + raise NotImplementedError("must be implemented by subclass") @dataclass @@ -52,6 +58,7 @@ class WaveformImage: phase: PhaseDescription weight: float traces: list[Trace] + detection_half_width: float stations: Stations = Field(default_factory=lambda: Stations.model_construct()) @property diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index c658c3b1..cdd47c32 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -150,7 +150,7 @@ class PhaseNet(ImageFunction): description="Window overlap in samples.", ) torch_use_cuda: bool | int = Field( - default=False, + default=True, description="Use CUDA for inference. If `True` use default device, if `int` use" " the specified device.", ) @@ -159,7 +159,7 @@ class PhaseNet(ImageFunction): description="Number of CPU threads to use if only CPU is used.", ) batch_size: int = Field( - default=64, + default=128, ge=64, description="Batch size for inference, larger values can improve performance.", ) @@ -204,10 +204,16 @@ def _prepare(self) -> None: torch.set_num_threads(self.torch_cpu_threads) self._phase_net = PhaseNetSeisBench.from_pretrained(self.model) if self.torch_use_cuda: - if isinstance(self.torch_use_cuda, bool): - self._phase_net.cuda() - else: - self._phase_net.cuda(self.torch_use_cuda) + try: + if isinstance(self.torch_use_cuda, bool): + self._phase_net.cuda() + else: + self._phase_net.cuda(self.torch_use_cuda) + except RuntimeError as exc: + logger.warning( + "failed to use CUDA for PhaseNet model, using CPU.", + exc_info=exc, + ) self._phase_net.eval() try: logger.info("compiling PhaseNet model...") @@ -224,13 +230,18 @@ def get_blinding(self, sampling_rate: float) -> timedelta: ) return timedelta(seconds=blinding_samples / sampling_rate) + def _detection_half_width(self) -> float: + """Half width of the detection window in seconds.""" + # The 0.2 seconds is the default value from SeisBench training + return 0.2 / self.rescale_input + @alog_call async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: stream = Stream(tr.to_obspy_trace() for tr in traces) - if self.rescale_input > 1: + if self.rescale_input != 1: scale = self.rescale_input for tr in stream: - tr.stats.sampling_rate = tr.stats.sampling_rate / scale + tr.stats.sampling_rate /= scale annotations: Stream = await asyncio.to_thread( self.phase_net.annotate, @@ -243,7 +254,7 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: if self.rescale_input > 1: scale = self.rescale_input for tr in annotations: - tr.stats.sampling_rate = tr.stats.sampling_rate * scale + tr.stats.sampling_rate *= scale blinding_samples = self.phase_net.default_args["blinding"][0] # 100 Hz is the native sampling rate of PhaseNet blinding_seconds = (blinding_samples / 100.0) * (1.0 - 1 / scale) @@ -259,12 +270,14 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: image_function=self, weight=self.weights["P"], phase=self.phase_map["P"], + detection_half_width=self._detection_half_width(), traces=[tr for tr in annotated_traces if tr.channel.endswith("P")], ) annotation_s = PhaseNetImage( image_function=self, weight=self.weights["S"], phase=self.phase_map["S"], + detection_half_width=self._detection_half_width(), traces=[tr for tr in annotated_traces if tr.channel.endswith("S")], ) diff --git a/src/qseek/models/catalog.py b/src/qseek/models/catalog.py index e25c7f64..a997b921 100644 --- a/src/qseek/models/catalog.py +++ b/src/qseek/models/catalog.py @@ -102,6 +102,9 @@ async def 
filter_events_by_time( start_time (datetime | None): Start time of the time range. end_time (datetime | None): End time of the time range. """ + if not self.events: + return + events = [] if start_time is not None and min(det.time for det in self.events) < start_time: logger.info("filtering detections after start time %s", start_time) diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py index 89526cf4..a070c15b 100644 --- a/src/qseek/models/detection.py +++ b/src/qseek/models/detection.py @@ -3,6 +3,7 @@ import asyncio import logging from datetime import datetime, timedelta +from functools import cached_property from itertools import chain from pathlib import Path from random import uniform @@ -223,7 +224,8 @@ class EventReceivers(BaseModel): event_uid: UUID | None = None receivers: list[Receiver] = [] - @property + @computed_field + @cached_property def n_receivers(self) -> int: """Number of receivers in the receiver set.""" return len(self.receivers) @@ -470,7 +472,6 @@ class EventDetection(Location): default=0, description="Number of stations in the detection.", ) - distance_border: PositiveFloat = Field( ..., description="Distance to the nearest border in meters. " @@ -699,6 +700,7 @@ def get_csv_dict(self) -> dict[str, Any]: "north_shift": round(self.north_shift, 2), "distance_border": round(self.distance_border, 2), "semblance": self.semblance, + "n_stations": self.n_stations, } for magnitude in self.magnitudes: csv_line.update(magnitude.csv_row()) diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 2347bcc0..70145853 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -116,8 +116,14 @@ def split(self) -> tuple[Node, ...]: @property def coordinates(self) -> tuple[float, float, float]: + """Returns the node coordinates (east, north, depth).""" return self.east, self.north, self.depth + @property + def radius(self) -> float: + """Returns the radius of the sphere that can fit the node inside.""" + return np.sqrt(3 * (self.size / 2) ** 2) + def get_distance_border(self, with_surface: bool = False) -> float: """Distance to the closest EW, NS or bottom border of the tree. @@ -145,19 +151,30 @@ def get_distance_border(self, with_surface: bool = False) -> float: return min(border_distance, self.depth - tree.depth_bounds[0]) return border_distance - def is_inside_border(self, with_surface: bool = False) -> bool: + def is_inside_border( + self, + with_surface: bool = False, + border_width: float | Literal["root_node_size"] = "root_node_size", + ) -> bool: """Check if the node is within the root node border. Args: with_surface (bool, optional): If True, the surface is considered as a border. Defaults to False. + border_width (float, optional): Width of the border, if 0.0, + the octree's root node size is used. Defaults to 0.0. Returns: bool: True if the node is inside the root tree's border. """ if self.tree is None: raise AttributeError("parent tree not set") - return self.get_distance_border(with_surface) <= self.tree.root_node_size + border = ( + self.tree.root_node_size + if border_width == "root_node_size" + else border_width + ) + return self.get_distance_border(with_surface) <= border def can_split(self) -> bool: """Check if the node can be split. @@ -227,7 +244,7 @@ def as_location(self) -> Location: ) return self._location - def collides(self, other: Node) -> bool: + def is_colliding(self, other: Node) -> bool: """Check if two nodes collide. 
Args: @@ -254,7 +271,7 @@ def get_neighbours(self) -> list[Node]: return [ node for node in self.tree.iter_nodes() - if self.collides(node) and node is not self + if self.is_colliding(node) and node is not self ] def distance_to(self, other: Node) -> float: @@ -432,7 +449,7 @@ def _clear_cache(self) -> None: del self.n_nodes def reset(self) -> Self: - """Reset the octree to its initial state.""" + """Reset the octree to its initial state and return it.""" logger.debug("resetting tree") self._clear_cache() self._root_nodes = self.get_root_nodes(self.root_node_size) diff --git a/src/qseek/pre_processing/base.py b/src/qseek/pre_processing/base.py index 9b04fb68..32444ad1 100644 --- a/src/qseek/pre_processing/base.py +++ b/src/qseek/pre_processing/base.py @@ -1,8 +1,11 @@ from __future__ import annotations +from itertools import groupby from typing import TYPE_CHECKING, Literal +import numpy as np from pydantic import BaseModel, Field, field_validator +from pyrocko.trace import Trace from qseek.utils import NSL @@ -67,3 +70,11 @@ async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: list[Trace]: The processed list of traces. """ raise NotImplementedError + + +def group_traces(traces: list[Trace]) -> groupby[tuple[float, int], Trace]: + return groupby(traces, key=lambda trace: (trace.deltat, trace.ydata.size)) + + +def traces_data(traces: list[Trace], dtype=np.float64) -> np.ndarray: + return np.array([trace.ydata for trace in traces], dtype=dtype) diff --git a/src/qseek/pre_processing/downsample.py b/src/qseek/pre_processing/downsample.py index b1098e2f..c6142cfa 100644 --- a/src/qseek/pre_processing/downsample.py +++ b/src/qseek/pre_processing/downsample.py @@ -2,17 +2,56 @@ import asyncio import logging +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Literal +import numpy as np from pydantic import Field, PositiveFloat +from pyrocko.trace import _configure_downsampling +from pyrocko.util import decimate_coeffs +from scipy import signal -from qseek.pre_processing.base import BatchPreProcessing +from qseek.pre_processing.base import BatchPreProcessing, group_traces, traces_data if TYPE_CHECKING: + from pyrocko.trace import Trace + from qseek.waveforms.base import WaveformBatch logger = logging.getLogger(__name__) +THREAD_POOL = ThreadPoolExecutor(max_workers=4) + + +def downsample( + traces: list[Trace], + delta_t: float, + demean: bool = False, +) -> list[Trace]: + data = traces_data(traces) + trace_deltat = traces[0].deltat + + upscale_sratio, decimation_sequence = _configure_downsampling( + trace_deltat, delta_t, allow_upsample_max=5 + ) + + if demean: + data -= np.mean(data, axis=1, keepdims=True) + + if upscale_sratio > 1: + data = np.repeat(data, upscale_sratio, axis=1) + + for n_decimate in decimation_sequence: + b, a, n = decimate_coeffs(n_decimate, None, "fir-remez") + data = signal.lfilter(b, a, data, axis=1) + data = data[:, n // 2 :: n_decimate].copy() + + for trace, trace_data in zip(traces, data, strict=True): + trace.ydata = trace_data + trace.tmax = trace.tmin + (trace_data.size - 1) * delta_t + trace.deltat = delta_t + return traces + class Downsample(BatchPreProcessing): """Downsample the traces to a new sampling frequency.""" @@ -24,16 +63,22 @@ class Downsample(BatchPreProcessing): ) async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: - desired_deltat = 1 / self.sampling_frequency + desired_delta_t = 1.0 / self.sampling_frequency def worker() -> None: - for trace in self.select_traces(batch): - if 
trace.deltat < desired_deltat: - try: - trace.downsample_to(deltat=desired_deltat, allow_upsample_max=5) - except Exception as e: - logger.exception("Failed to downsample trace: %s", e) - ... + traces = self.select_traces(batch) + trace_groups = [] + for (delta_t, _), trace_group in group_traces(traces): + if desired_delta_t <= delta_t: + logger.debug("traces sampling rate is smaller than desired") + continue + trace_groups.append(list(trace_group)) + + THREAD_POOL.map( + downsample, + trace_groups, + [desired_delta_t] * len(trace_groups), + ) await asyncio.to_thread(worker) return batch diff --git a/src/qseek/pre_processing/frequency_filters.py b/src/qseek/pre_processing/frequency_filters.py index 3d526184..5129ebc3 100644 --- a/src/qseek/pre_processing/frequency_filters.py +++ b/src/qseek/pre_processing/frequency_filters.py @@ -2,20 +2,62 @@ import asyncio import logging +from functools import lru_cache from typing import TYPE_CHECKING, Literal +import numpy as np from pydantic import Field, PositiveFloat, field_validator +from scipy import signal -from qseek.pre_processing.base import BatchPreProcessing +from qseek.pre_processing.base import BatchPreProcessing, group_traces, traces_data from qseek.utils import Range if TYPE_CHECKING: + from pyrocko.trace import Trace + from qseek.waveforms.base import WaveformBatch logger = logging.getLogger(__name__) +@lru_cache +def butter_sos( + N: int, # noqa: N803 + Wn: float | tuple[float, float], # noqa: N803 + btype: Literal["lowpass", "highpass", "bandpass"], + fs: float, + dtype: np.dtype = float, +) -> np.ndarray: + return signal.butter( + N=N, + Wn=Wn, + btype=btype, + fs=fs, + output="sos", + ).astype(dtype) + + +def _sos_filter( + traces: list[Trace], + sos: np.ndarray, + demean: bool, + zero_phase: bool, +) -> list[Trace]: + data = traces_data(traces) + if demean: + data -= np.mean(data, axis=1, keepdims=True) + + if zero_phase: + data = signal.sosfiltfilt(sos, data, axis=1) + else: + data = signal.sosfilt(sos, data, axis=1) + + for trace, ydata in zip(traces, data, strict=True): + trace.set_ydata(ydata) + return traces + + class Bandpass(BatchPreProcessing): """Apply a bandpass filter to the traces.""" @@ -46,17 +88,15 @@ def _check_bandpass(cls, value) -> Range: async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: def worker() -> None: - for trace in self.select_traces(batch): - try: - trace.bandpass( - order=self.corners, - corner_hp=self.bandpass[0], - corner_lp=self.bandpass[1], - demean=self.demean, - ) - except Exception as e: - logger.exception("Failed to apply bandpass filter: %s", e) - ... + traces = self.select_traces(batch) + for (deltat, _), trace_group in group_traces(traces): + sos = butter_sos( + N=self.corners, + Wn=self.bandpass, + btype="bandpass", + fs=1.0 / deltat, + ) + _sos_filter(list(trace_group), sos, demean=self.demean, zero_phase=True) await asyncio.to_thread(worker) return batch @@ -82,12 +122,22 @@ class Highpass(BatchPreProcessing): async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: def worker() -> None: - for trace in self.select_traces(batch): - trace.highpass( - order=self.corners, - corner=self.frequency, - demean=self.demean, + traces = self.select_traces(batch) + for (deltat, _), trace_group in group_traces(traces): + sampling_rate = 1.0 / deltat + if self.frequency >= sampling_rate / 2: + logger.debug( + "Highpass frequency is higher than Nyquist frequency. " + "No filtering is applied." 
@@ -82,12 +122,22 @@ class Highpass(BatchPreProcessing):
 
     async def process_batch(self, batch: WaveformBatch) -> WaveformBatch:
         def worker() -> None:
-            for trace in self.select_traces(batch):
-                trace.highpass(
-                    order=self.corners,
-                    corner=self.frequency,
-                    demean=self.demean,
+            traces = self.select_traces(batch)
+            for (deltat, _), trace_group in group_traces(traces):
+                sampling_rate = 1.0 / deltat
+                if self.frequency >= sampling_rate / 2:
+                    logger.debug(
+                        "Highpass frequency is higher than Nyquist frequency. "
+                        "No filtering is applied."
+                    )
+                    continue
+                sos = butter_sos(
+                    N=self.corners,
+                    Wn=self.frequency,
+                    btype="highpass",
+                    fs=sampling_rate,
                 )
+                _sos_filter(list(trace_group), sos, demean=self.demean, zero_phase=True)
 
         await asyncio.to_thread(worker)
         return batch
@@ -113,12 +163,22 @@ class Lowpass(BatchPreProcessing):
 
     async def process_batch(self, batch: WaveformBatch) -> WaveformBatch:
         def worker() -> None:
-            for trace in self.select_traces(batch):
-                trace.lowpass(
-                    order=self.corners,
-                    corner=self.frequency,
-                    demean=self.demean,
+            traces = self.select_traces(batch)
+            for (deltat, _), trace_group in group_traces(traces):
+                sampling_rate = 1.0 / deltat
+                if self.frequency >= sampling_rate / 2:
+                    logger.debug(
+                        "Lowpass frequency is higher than Nyquist frequency. "
+                        "No filtering is applied."
+                    )
+                    continue
+                sos = butter_sos(
+                    N=self.corners,
+                    Wn=self.frequency,
+                    btype="lowpass",
+                    fs=sampling_rate,
                 )
+                _sos_filter(list(trace_group), sos, demean=self.demean, zero_phase=True)
 
         await asyncio.to_thread(worker)
         return batch
diff --git a/src/qseek/search.py b/src/qseek/search.py
index 386aac32..a27b142a 100644
--- a/src/qseek/search.py
+++ b/src/qseek/search.py
@@ -11,6 +11,7 @@
 import numpy as np
 import psutil
 from pydantic import (
+    AliasChoices,
     BaseModel,
     ByteSize,
     ConfigDict,
@@ -31,7 +32,7 @@
 from qseek.models.detection import EventDetection, PhaseDetection
 from qseek.models.detection_uncertainty import DetectionUncertainty
 from qseek.models.semblance import Semblance, SemblanceCache
-from qseek.octree import NodeSplitError, Octree
+from qseek.octree import Octree
 from qseek.pre_processing.frequency_filters import Bandpass
 from qseek.pre_processing.module import Downsample, PreProcessing
 from qseek.signals import Signal
@@ -240,7 +241,7 @@ class Search(BaseModel):
     )
     distance_weights: DistanceWeights | None = Field(
         default=DistanceWeights(),
-        alias="spatial_weights",
+        validation_alias=AliasChoices("spatial_weights", "distance_weights"),
        description="Spatial weights for distance weighting.",
     )
     station_corrections: StationCorrectionType | None = Field(
@@ -271,6 +272,11 @@
         " the octree. If `with_surface`, all events inside the boundaries of the volume"
         " are absorbed. If `without_surface`, events at the surface are not absorbed.",
     )
+    absorbing_boundary_width: float | Literal["root_node_size"] = Field(
+        default="root_node_size",
+        description="Width of the absorbing boundary around the octree volume. "
+        "If `root_node_size`, the width is set to the root node size of the octree.",
+    )
     node_peak_interpolation: bool = Field(
         default=True,
         description="Interpolate intranode locations for detected events using radial"
@@ -461,6 +467,9 @@ async def prepare(self) -> None:
             None
         """
         logger.info("preparing search components")
+        asyncio.get_running_loop().set_exception_handler(
+            lambda loop, context: logger.error(context)
+        )
         self.data_provider.prepare(self.stations)
         await self.pre_processing.prepare()
@@ -840,7 +849,8 @@ async def search(
             source_node = octree[node_idx]
 
             if parent.absorbing_boundary and source_node.is_inside_border(
-                with_surface=parent.absorbing_boundary == "with_surface"
+                with_surface=parent.absorbing_boundary == "with_surface",
+                border_width=parent.absorbing_boundary_width,
             ):
                 continue
             refine_nodes.update(source_node)
@@ -854,23 +864,25 @@ async def search(
         # refine_nodes is empty when all sources fall into smallest octree nodes
         if refine_nodes:
+            node_size_max = max(node.size for node in refine_nodes)
             new_level = 0
             for node in refine_nodes:
-                try:
-                    node.split()
-                    new_level = max(new_level, node.level + 1)
-                except NodeSplitError:
-                    continue
+                node.split()
+                new_level = max(new_level, node.level + 1)
             logger.info(
-                "detected %d energy burst%s - refined %d nodes, lowest level %d",
+                "detected %d energy burst%s - refined %d nodes, level %d (%.1f m)",
                 detection_idx.size,
                 "s" if detection_idx.size > 1 else "",
                 len(refine_nodes),
                 new_level,
+                node_size_max,
             )
             cache = semblance.get_cache()
             del semblance
-            return await self.search(octree, semblance_cache=cache)
+            return await self.search(
+                octree,
+                semblance_cache=cache,
+            )
 
         detections = []
         for time_idx, semblance_detection in zip(
@@ -883,7 +895,8 @@
             octree.map_semblance(semblance_event)
             source_node = octree[node_idx]
 
             if parent.absorbing_boundary and source_node.is_inside_border(
-                with_surface=parent.absorbing_boundary == "with_surface"
+                with_surface=parent.absorbing_boundary == "with_surface",
+                border_width=parent.absorbing_boundary_width,
             ):
                 continue
diff --git a/src/qseek/tracers/base.py b/src/qseek/tracers/base.py
index 656e73f8..ef1405a5 100644
--- a/src/qseek/tracers/base.py
+++ b/src/qseek/tracers/base.py
@@ -21,7 +21,11 @@
 @dataclass
 class ModelledArrival:
     phase: str
+    "Name of the phase"
+
     time: datetime
+    "Time of the arrival"
+
     tracer: str = ""
diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py
index 48ee2172..3fcb0d47 100644
--- a/src/qseek/waveforms/squirrel.py
+++ b/src/qseek/waveforms/squirrel.py
@@ -81,6 +81,7 @@ class SquirrelStats(Stats):
     bytes_per_seconds: float = 0.0
 
     _queue: asyncio.Queue[Batch | None] | None = PrivateAttr(None)
+    _position: int = 3
 
     def set_queue(self, queue: asyncio.Queue[Batch | None]) -> None:
         self._queue = queue
@@ -139,14 +140,17 @@ class PyrockoSquirrel(WaveformProvider):
     channel_selector: list[constr(to_upper=True, max_length=2, min_length=2)] | None = (
         Field(
             default=None,
-            description="Channel selector for waveforms, "
-            "use e.g. `['EN']` for selection of all accelerometer data.",
+            description="Channel selector for waveforms, " "e.g. `['HH', 'EN']`.",
         )
     )
     async_prefetch_batches: PositiveInt = Field(
         default=10,
         description="Queue size for asynchronous pre-fetcher.",
     )
+    n_threads: PositiveInt = Field(
+        default=8,
+        description="Number of threads for loading waveforms.",
+    )
 
     _squirrel: Squirrel | None = PrivateAttr(None)
     _stations: Stations = PrivateAttr(None)
@@ -162,13 +166,11 @@ def _validate_model(self) -> Self:
 
     def get_squirrel(self) -> Squirrel:
         if not self._squirrel:
-            logger.info(
-                "initializing squirrel waveform provider in environment %s",
-                self.environment,
-            )
+            logger.info("loading squirrel environment from %s", self.environment)
             squirrel = Squirrel(
                 env=str(self.environment.expanduser()) if self.environment else None,
                 persistent=self.persistent,
+                n_threads=self.n_threads,
             )
             paths = []
             for path in self.waveform_dirs:
+ " Get an overview with `qseek export list`", +) + +export.add_argument( + "format", + type=str, + help="Name of export module, or `list` to list available modules", +) + +export.add_argument( + "rundir", + type=Path, + help="path to existing qseek rundir", + nargs="?", +) + +export.add_argument( + "export_dir", + type=Path, + help="path to export directory", + nargs="?", +) + +export.add_argument( + "--force", + action="store_true", + default=False, + help="overwrite existing output directory", +) + + subparsers.add_parser( "clear-cache", help="clear the cach directory", @@ -374,6 +409,51 @@ async def start() -> None: logger.info("clearing cache directory %s", CACHE_DIR) shutil.rmtree(CACHE_DIR) + case "export": + from qseek.exporters.base import Exporter + + def show_table(): + table = Table(box=box.SIMPLE, header_style=None) + table.add_column("Exporter") + table.add_column("Description") + for exporter in Exporter.get_subclasses(): + table.add_row(exporter.__name__.lower(), exporter.__doc__) + console.print(table) + + if args.format == "list": + show_table() + parser.exit() + + if not args.rundir: + parser.error("rundir is required for export") + + if args.export_dir is None: + parser.error("export directory is required") + + if args.export_dir.exists(): + if not args.force: + parser.error(f"export directory {args.export_dir} already exists") + shutil.rmtree(args.export_dir) + + for exporter in Exporter.get_subclasses(): + if exporter.__name__.lower() == args.format.lower(): + exporter_instance = exporter() + asyncio.run( + exporter_instance.export( + rundir=args.rundir, + outdir=args.export_dir, + ) + ) + break + else: + available_exporters = ", ".join( + exporter.__name__ for exporter in Exporter.get_subclasses() + ) + parser.error( + f"unknown exporter: {args.format}" + f"choose fom: {available_exporters}" + ) + case "modules": from qseek.corrections.base import TravelTimeCorrections from qseek.features.base import FeatureExtractor diff --git a/src/qseek/exporters/__init__.py b/src/qseek/exporters/__init__.py new file mode 100644 index 00000000..3e842644 --- /dev/null +++ b/src/qseek/exporters/__init__.py @@ -0,0 +1,2 @@ +from qseek.exporters.simple import Simple # noqa +from qseek.exporters.velest import Velest # noqa diff --git a/src/qseek/exporters/base.py b/src/qseek/exporters/base.py new file mode 100644 index 00000000..c26e412c --- /dev/null +++ b/src/qseek/exporters/base.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from pathlib import Path + +from pydantic import BaseModel + + +class Exporter(BaseModel): + async def export(self, rundir: Path, outdir: Path) -> Path: + raise NotImplementedError + + @classmethod + def get_subclasses(cls) -> tuple[type[Exporter], ...]: + """Get the subclasses of this class. + + Returns: + list[type]: The subclasses of this class. + """ + return tuple(cls.__subclasses__()) diff --git a/src/qseek/exporters/simple.py b/src/qseek/exporters/simple.py new file mode 100644 index 00000000..63dde913 --- /dev/null +++ b/src/qseek/exporters/simple.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging +from pathlib import Path + +from rich import progress + +from qseek.exporters.base import Exporter +from qseek.search import Search +from qseek.utils import time_to_path + +logger = logging.getLogger(__name__) + + +class Simple(Exporter): + """Export simple travel times in CSV format (E. 
diff --git a/src/qseek/exporters/simple.py b/src/qseek/exporters/simple.py
new file mode 100644
index 00000000..63dde913
--- /dev/null
+++ b/src/qseek/exporters/simple.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+from rich import progress
+
+from qseek.exporters.base import Exporter
+from qseek.search import Search
+from qseek.utils import time_to_path
+
+logger = logging.getLogger(__name__)
+
+
+class Simple(Exporter):
+    """Export simple travel times in CSV format (E. Biondi, 2023)."""
+
+    async def export(self, rundir: Path, outdir: Path) -> Path:
+        logger.info("Export simple travel times in CSV format.")
+
+        search = Search.load_rundir(rundir)
+        catalog = search.catalog
+
+        traveltime_dir = outdir / "traveltimes"
+        outdir.mkdir(parents=True)
+        traveltime_dir.mkdir()
+
+        event_file = outdir / "events.csv"
+        self.search.stations.export_csv(outdir / "stations.csv")
+        await catalog.export_csv(event_file)
+
+        for ev in progress.track(
+            catalog,
+            description="Exporting travel times",
+            total=catalog.n_events,
+        ):
+            traveltime_file = traveltime_dir / f"{time_to_path(ev.time)}.csv"
+            with traveltime_file.open("w") as file:
+                file.write(f"# event_id: {ev.uid}\n")
+                file.write(f"# event_time: {ev.time}\n")
+                file.write(f"# event_lat: {ev.lat}\n")
+                file.write(f"# event_lon: {ev.lon}\n")
+                file.write(f"# event_depth: {ev.effective_depth}\n")
+                file.write(f"# event_semblance: {ev.semblance}\n")
+                file.write("# traveltime observations:\n")
+                file.write(
+                    "lat,lon,elevation,network,station,location,phase,confidence,traveltime\n"
+                )
+
+                for receiver in ev.receivers:
+                    for phase, arrival in receiver.phase_arrivals.items():
+                        if arrival.observed is None:
+                            continue
+                        traveltime = arrival.observed.time - ev.time
+                        file.write(
+                            f"{receiver.lat},{receiver.lon},{receiver.effective_elevation},{receiver.network},{receiver.station},{receiver.location},{phase},{arrival.observed.detection_value},{traveltime.total_seconds()}\n",
+                        )
+
+        return outdir
diff --git a/src/qseek/exporters/velest.py b/src/qseek/exporters/velest.py
new file mode 100644
index 00000000..7286650a
--- /dev/null
+++ b/src/qseek/exporters/velest.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+import rich
+from rich.prompt import FloatPrompt
+
+from qseek.exporters.base import Exporter
+from qseek.search import Search
+
+logger = logging.getLogger(__name__)
+
+
+class Velest(Exporter):
+    """Create a VELEST project folder for 1D velocity model estimation."""
+
+    min_pick_semblance: float = 0.3
+    n_picks: dict[str, int] = {}
+    n_events: int = 0
+
+    async def export(self, rundir: Path, outdir: Path) -> Path:
+        rich.print("Exporting qseek search to VELEST project folder")
+        min_pick_semblance = FloatPrompt.ask("Minimum pick confidence", default=0.3)
+
+        self.min_pick_semblance = min_pick_semblance
+
+        outdir.mkdir()
+        search = Search.load_rundir(rundir)
+        catalog = search.catalog  # noqa
+
+        export_info = outdir / "export_info.json"
+        export_info.write_text(self.model_dump_json(indent=2))
+
+        return outdir
diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py
index cdd47c32..a5e2bf55 100644
--- a/src/qseek/images/phase_net.py
+++ b/src/qseek/images/phase_net.py
@@ -137,7 +137,7 @@ class PhaseNet(ImageFunction):
     image: Literal["PhaseNet"] = "PhaseNet"
     model: ModelName = Field(
-        default="ethz",
+        default="original",
         description="SeisBench pre-trained PhaseNet model to use. "
         "Choose from `ethz`, `geofon`, `instance`, `iquique`, `lendb`, `neic`, `obs`,"
         " `original`, `scedc`, `stead`."
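The travel-time files written above are plain CSV with `#`-prefixed header comments, so downstream tools can load them directly. A sketch using pandas (not a dependency of this project) with a hypothetical file name:

```python
import pandas as pd

# One file per event, named after the event time by time_to_path().
picks = pd.read_csv("traveltimes/2024-03-18T10-32-54.csv", comment="#")
p_picks = picks[picks.phase.str.endswith("P")]
print(p_picks[["station", "confidence", "traveltime"]].describe())
```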
diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py
index a070c15b..acf0e288 100644
--- a/src/qseek/models/detection.py
+++ b/src/qseek/models/detection.py
@@ -695,13 +695,20 @@ def get_csv_dict(self) -> dict[str, Any]:
             "time": self.time,
             "lat": round(self.effective_lat, 6),
             "lon": round(self.effective_lon, 6),
-            "depth_ellipsoid": round(self.effective_depth, 2),
+            "depth": round(self.effective_depth, 2),
             "east_shift": round(self.east_shift, 2),
             "north_shift": round(self.north_shift, 2),
             "distance_border": round(self.distance_border, 2),
             "semblance": self.semblance,
             "n_stations": self.n_stations,
         }
+        if self.uncertainty:
+            csv_line.update(
+                {
+                    "uncertainty_horizontal": self.uncertainty.horizontal,
+                    "uncertainty_vertical": self.uncertainty.depth,
+                }
+            )
         for magnitude in self.magnitudes:
             csv_line.update(magnitude.csv_row())
         return csv_line
diff --git a/src/qseek/models/detection_uncertainty.py b/src/qseek/models/detection_uncertainty.py
index 5b2b386c..6fcd4431 100644
--- a/src/qseek/models/detection_uncertainty.py
+++ b/src/qseek/models/detection_uncertainty.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING
 
 import numpy as np
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, computed_field
 from typing_extensions import Self
 
 if TYPE_CHECKING:
@@ -63,3 +63,15 @@ def from_event(
             north=(float(min_offsets[1]), float(max_offsets[1])),
             depth=(float(min_offsets[2]), float(max_offsets[2])),
         )
+
+    @computed_field
+    def total(self) -> float:
+        """Calculate the total uncertainty in [m]."""
+        return float(
+            np.sqrt(sum(self.east) ** 2 + sum(self.north) ** 2 + sum(self.depth) ** 2)
+        )
+
+    @computed_field
+    def horizontal(self) -> float:
+        """Calculate the horizontal uncertainty in [m]."""
+        return float(np.sqrt(sum(self.east) ** 2 + sum(self.north) ** 2))

From b79b62b9dfe0b75c36159d13e332da41ab83b19d Mon Sep 17 00:00:00 2001
From: Marius Isken
Date: Sun, 7 Jul 2024 12:59:50 +0000
Subject: [PATCH 17/26] minor improvements

---
 src/qseek/images/base.py           |  7 ++++---
 src/qseek/images/images.py         |  7 ++++---
 src/qseek/models/catalog.py        |  2 +-
 src/qseek/models/station.py        | 19 +++++++++++--------
 src/qseek/pre_processing/module.py |  6 ++++--
 src/qseek/search.py                | 12 ++++++------
 src/qseek/utils.py                 |  1 +
 src/qseek/waveforms/squirrel.py    | 21 +++++----------------
 8 files changed, 36 insertions(+), 39 deletions(-)

diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py
index 49c1d06e..4d28062d 100644
--- a/src/qseek/images/base.py
+++ b/src/qseek/images/base.py
@@ -89,15 +89,16 @@ def resample(self, sampling_rate: float, max_normalize: bool = False) -> None:
             max_normalize (bool): Normalize by maximum value to keep the scale of
                 the maximum detection. Defaults to False.
         """
+        if self.sampling_rate == sampling_rate:
+            return
+
         downsample = self.sampling_rate > sampling_rate
         for tr in self.traces:
-            if max_normalize:
+            if max_normalize and downsample:
                 # We can use maximum here since the PhaseNet output is single-sided
                 _, max_value = tr.max()
             resample(tr, sampling_rate)
 
             if max_normalize and downsample:
                 tr.ydata /= tr.ydata.max()
                 tr.ydata *= max_value
diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py
index a9f3a77f..d4451004 100644
--- a/src/qseek/images/images.py
+++ b/src/qseek/images/images.py
@@ -12,7 +12,7 @@
 from qseek.images.base import ImageFunction
 from qseek.images.phase_net import PhaseNet
 from qseek.stats import Stats
-from qseek.utils import PhaseDescription, datetime_now, human_readable_bytes
+from qseek.utils import QUEUE_SIZE, PhaseDescription, datetime_now, human_readable_bytes
 
 if TYPE_CHECKING:
     from pyrocko.trace import Trace
@@ -72,7 +72,9 @@ def _populate_table(self, table: Table) -> None:
 class ImageFunctions(RootModel):
     root: list[ImageFunctionType] = [PhaseNet()]
 
-    _queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr()
+    _queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr(
+        asyncio.Queue(maxsize=QUEUE_SIZE)
+    )
     _processed_images: int = PrivateAttr(0)
     _stats: ImageFunctionsStats = PrivateAttr(default_factory=ImageFunctionsStats)
 
@@ -81,7 +83,6 @@ def model_post_init(self, __context: Any) -> None:
         phases = self.get_phases()
         if len(set(phases)) != len(phases):
             raise ValueError("A phase was provided twice")
-        self._queue = asyncio.Queue(maxsize=16)
         self._stats.set_queue(self._queue)
 
     async def process_traces(self, traces: list[Trace]) -> WaveformImages:
diff --git a/src/qseek/models/catalog.py b/src/qseek/models/catalog.py
index a997b921..630a26b4 100644
--- a/src/qseek/models/catalog.py
+++ b/src/qseek/models/catalog.py
@@ -62,7 +62,7 @@ def new_detection(self, detection: EventDetection):
         self.max_semblance = max(self.max_semblance, detection.semblance)
 
     def _populate_table(self, table: Table) -> None:
-        table.add_row("No. Detections", f"[bold]{self.n_detections} :dim_button:")
+        table.add_row("No. Detections", f"[bold]{self.n_detections} :fire:")
         table.add_row("Maximum semblance", f"{self.max_semblance:.4f}")
diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py
index 10742598..dc60f607 100644
--- a/src/qseek/models/station.py
+++ b/src/qseek/models/station.py
@@ -165,7 +165,6 @@
         raise ValueError("no stations available, add waveforms to start detection")
 
     def __iter__(self) -> Iterator[Station]:
-        # TODO: this is inefficient
         return (sta for sta in self.stations if sta.nsl.pretty not in self.blacklist)
 
     @property
@@ -186,14 +185,18 @@ def select_from_traces(self, traces: Iterable[Trace]) -> Stations:
 
         Returns:
             Stations: Containing only selected stations.
         """
+        available_stations = tuple(self)
+        available_nsls = tuple(sta.nsl for sta in available_stations)
+
         selected_stations = []
-        for nsl in ((tr.network, tr.station, tr.location) for tr in traces):
-            for sta in self:
-                if sta.nsl == nsl:
-                    selected_stations.append(sta)
-                    break
-            else:
-                raise ValueError(f"could not find a station for {'.'.join(nsl)} ")
+        for nsl in {(tr.network, tr.station, tr.location) for tr in traces}:
+            try:
+                sta_idx = available_nsls.index(nsl)
+                selected_stations.append(available_stations[sta_idx])
+            except ValueError as exc:
+                raise ValueError(
+                    f"could not find a station for {'.'.join(nsl)} "
+                ) from exc
         return Stations.model_construct(stations=selected_stations)
 
     def get_centroid(self) -> Location:
diff --git a/src/qseek/pre_processing/module.py b/src/qseek/pre_processing/module.py
index ec1f504e..5cabb57b 100644
--- a/src/qseek/pre_processing/module.py
+++ b/src/qseek/pre_processing/module.py
@@ -16,7 +16,7 @@
     Lowpass,
 )
 from qseek.stats import Stats
-from qseek.utils import datetime_now, human_readable_bytes
+from qseek.utils import QUEUE_SIZE, datetime_now, human_readable_bytes
 
 if TYPE_CHECKING:
     from rich.table import Table
@@ -71,7 +71,9 @@ class PreProcessing(RootModel):
         "The first module is the first to be applied.",
     )
 
-    _queue: asyncio.Queue[WaveformBatch | None] = asyncio.Queue(maxsize=12)
+    _queue: asyncio.Queue[WaveformBatch | None] = PrivateAttr(
+        asyncio.Queue(maxsize=QUEUE_SIZE)
+    )
     _stats: PreProcessingStats = PrivateAttr(default_factory=PreProcessingStats)
 
     def model_post_init(self, __context: Any) -> None:
diff --git a/src/qseek/search.py b/src/qseek/search.py
index a27b142a..3b5d48fe 100644
--- a/src/qseek/search.py
+++ b/src/qseek/search.py
@@ -500,12 +500,6 @@ async def start(self, force_rundir: bool = False) -> None:
 
         await self.prepare()
 
-        logger.info("starting search")
-        stats = self._stats
-        stats.reset_start_time()
-
-        processing_start = datetime_now()
-
         if self._progress.time_progress:
             logger.info("continuing search from %s", self._progress.time_progress)
             await self._catalog.check(repair=True)
@@ -513,6 +507,8 @@
                 start_time=None,
                 end_time=self._progress.time_progress,
             )
+        else:
+            logger.info("starting search")
 
         batches = self.data_provider.iter_batches(
             window_increment=self.window_length,
@@ -522,6 +518,10 @@
         )
         processed_batches = self.pre_processing.iter_batches(batches)
 
+        stats = self._stats
+        stats.reset_start_time()
+
+        processing_start = datetime_now()
         console = asyncio.create_task(RuntimeStats.live_view())
 
         async for images, batch in self.image_functions.iter_images(processed_batches):
diff --git a/src/qseek/utils.py b/src/qseek/utils.py
index ed40411a..80ba1e10 100644
--- a/src/qseek/utils.py
+++ b/src/qseek/utils.py
@@ -41,6 +41,7 @@
 
 PhaseDescription = Annotated[str, constr(pattern=r"[a-zA-Z]*:[a-zA-Z]*")]
 
+QUEUE_SIZE = 16
 CACHE_DIR = Path.home() / ".cache" / "qseek"
 if not CACHE_DIR.exists():
     logger.info("creating cache dir %s", CACHE_DIR)
diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py
index 3fcb0d47..87398f65 100644
--- a/src/qseek/waveforms/squirrel.py
+++ b/src/qseek/waveforms/squirrel.py
@@ -22,7 +22,7 @@
 from qseek.models.station import Stations
 from qseek.stats import Stats
-from qseek.utils import datetime_now, human_readable_bytes, to_datetime
+from qseek.utils import QUEUE_SIZE, datetime_now, human_readable_bytes, to_datetime
 from qseek.waveforms.base import WaveformBatch, WaveformProvider
 
 if TYPE_CHECKING:
@@ -40,14 +40,10 @@ class SquirrelPrefetcher:
     _fetched_batches: int
     _task: asyncio.Task[None]
 
-    def __init__(
-        self,
-        iterator: Iterator[Batch],
-        queue_size: int = 8,
-    ) -> None:
+    def __init__(self, iterator: Iterator[Batch]) -> None:
         self.iterator = iterator
-        self.queue = asyncio.Queue(maxsize=queue_size)
-        self._load_queue = asyncio.Queue(maxsize=queue_size)
+        self.queue = asyncio.Queue(maxsize=QUEUE_SIZE)
+        self._load_queue = asyncio.Queue(maxsize=QUEUE_SIZE)
         self._fetched_batches = 0
         self._task = asyncio.create_task(self.prefetch_worker())
 
@@ -143,10 +139,6 @@ class PyrockoSquirrel(WaveformProvider):
             description="Channel selector for waveforms, " "e.g. `['HH', 'EN']`.",
         )
     )
-    async_prefetch_batches: PositiveInt = Field(
-        default=10,
-        description="Queue size for asynchronous pre-fetcher.",
-    )
     n_threads: PositiveInt = Field(
         default=8,
         description="Number of threads for loading waveforms.",
     )
@@ -227,10 +219,7 @@ async def iter_batches(
             codes=[(*nsl, "*") for nsl in self._stations.get_all_nsl()],
             channel_priorities=self.channel_selector,
         )
-        prefetcher = SquirrelPrefetcher(
-            iterator,
-            queue_size=self.async_prefetch_batches,
-        )
+        prefetcher = SquirrelPrefetcher(iterator)
         stats.set_queue(prefetcher.queue)
 
         while True:

From fbdc17f6a30eae3f82f675b479ccfa2f4d3a81a5 Mon Sep 17 00:00:00 2001
From: Marius Isken
Date: Sun, 7 Jul 2024 13:02:07 +0000
Subject: [PATCH 18/26] dependencies: numpy less 2

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 2752ed72..8c72277b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ keywords = [
 ]
 
 dependencies = [
-    "numpy>=1.17.3",
+    "numpy>=1.17.3, <2",
     "scipy>=1.8.0",
     "pyrocko>=2022.06.10",
     "seisbench>=0.5.0",

From 679579aaebef935acdc95600353af898b25c0b3c Mon Sep 17 00:00:00 2001
From: Marius Isken
Date: Sun, 7 Jul 2024 13:53:11 +0000
Subject: [PATCH 19/26] exporter: simple reworks

---
 src/qseek/exporters/simple.py | 54 ++++++++++++++++++++++++++---------
 src/qseek/search.py           | 10 +++----
 2 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/src/qseek/exporters/simple.py b/src/qseek/exporters/simple.py
index 63dde913..d9715071 100644
--- a/src/qseek/exporters/simple.py
+++ b/src/qseek/exporters/simple.py
@@ -3,7 +3,7 @@
 import logging
 from pathlib import Path
 
-from rich import progress
+from rich import progress, prompt
 
 from qseek.exporters.base import Exporter
 from qseek.search import Search
 from qseek.utils import time_to_path
 
 logger = logging.getLogger(__name__)
 
 
 class Simple(Exporter):
     """Export simple travel times in CSV format (E. Biondi, 2023)."""
 
-    async def export(self, rundir: Path, outdir: Path) -> Path:
+    min_confidence_value: float = 0.3
+    min_semblance_value: float = 0.5
+
+    async def export(
+        self,
+        rundir: Path,
+        outdir: Path,
+    ) -> Path:
         logger.info("Export simple travel times in CSV format.")
 
+        self.min_semblance_value = prompt.FloatPrompt.ask(
+            "Minimum event semblance value",
+            default=self.min_semblance_value,
+        )
+        self.min_confidence_value = prompt.FloatPrompt.ask(
+            "Minimum pick confidence value",
+            default=self.min_confidence_value,
+        )
+
         search = Search.load_rundir(rundir)
         catalog = search.catalog
 
         traveltime_dir = outdir / "traveltimes"
         outdir.mkdir(parents=True)
         traveltime_dir.mkdir()
 
-        event_file = outdir / "events.csv"
-        self.search.stations.export_csv(outdir / "stations.csv")
-        await catalog.export_csv(event_file)
+        search.stations.export_csv(outdir / "stations.csv")
+        await catalog.export_csv(outdir / "events.csv")
+        (outdir / "simple-export.json").write_text(self.model_dump_json(indent=2))
 
         for ev in progress.track(
             catalog,
             description="Exporting travel times",
             total=catalog.n_events,
         ):
             traveltime_file = traveltime_dir / f"{time_to_path(ev.time)}.csv"
 
+            if ev.semblance < self.min_semblance_value:
+                continue
+
+            observed_arrivals = [
+                (receiver, phase, arrival)
+                for receiver in ev.receivers
+                for phase, arrival in receiver.phase_arrivals.items()
+                if arrival.observed is not None
+                and arrival.observed.detection_value > self.min_confidence_value
+            ]
+
+            if not observed_arrivals:
+                continue
+
             with traveltime_file.open("w") as file:
                 file.write(f"# event_id: {ev.uid}\n")
                 file.write(f"# event_time: {ev.time}\n")
                 file.write(f"# event_lat: {ev.lat}\n")
                 file.write(f"# event_lon: {ev.lon}\n")
                 file.write(f"# event_depth: {ev.effective_depth}\n")
                 file.write(f"# event_semblance: {ev.semblance}\n")
                 file.write("# traveltime observations:\n")
                 file.write(
                     "lat,lon,elevation,network,station,location,phase,confidence,traveltime\n"
                 )
 
-                for receiver in ev.receivers:
-                    for phase, arrival in receiver.phase_arrivals.items():
-                        if arrival.observed is None:
-                            continue
-                        traveltime = arrival.observed.time - ev.time
-                        file.write(
-                            f"{receiver.lat},{receiver.lon},{receiver.effective_elevation},{receiver.network},{receiver.station},{receiver.location},{phase},{arrival.observed.detection_value},{traveltime.total_seconds()}\n",
-                        )
+                for receiver, phase, arrival in observed_arrivals:
+                    traveltime = arrival.observed.time - ev.time
+                    file.write(
+                        f"{receiver.lat},{receiver.lon},{receiver.effective_elevation},{receiver.network},{receiver.station},{receiver.location},{phase},{arrival.observed.detection_value},{traveltime.total_seconds()}\n",
+                    )
 
         return outdir
diff --git a/src/qseek/search.py b/src/qseek/search.py
index 3b5d48fe..6ccf85b5 100644
--- a/src/qseek/search.py
+++ b/src/qseek/search.py
@@ -68,7 +68,7 @@ class SearchStats(Stats):
     batch_time: datetime = datetime.min
     batch_count: int = 0
     batch_count_total: int = 0
-    processed_duration: timedelta = timedelta(seconds=0.0)
+    processed_time: timedelta = timedelta(seconds=0.0)
     processed_bytes: int = 0
     processing_time: timedelta = timedelta(seconds=0.0)
     latest_processing_rate: float = 0.0
@@ -95,8 +95,8 @@ def time_remaining(self) -> timedelta:
         if not remaining_batches:
             return timedelta()
 
-        duration = datetime_now() - self._search_start
-        return duration / self.batch_count * remaining_batches
+        elapsed_time = datetime_now() - self._search_start
+        return (elapsed_time / self.batch_count) * remaining_batches
 
     @computed_field
     @property
@@ -115,7 +115,7 @@ def processing_rate(self) -> float:
     def processing_speed(self) -> timedelta:
         if not self.processing_time:
             return timedelta(seconds=0.0)
-        return self.processed_duration / self.processing_time.total_seconds()
+        return self.processed_time / self.processing_time.total_seconds()
 
     @computed_field
     @property
@@ -152,7 +152,7 @@ def add_processed_batch(
         self.batch_count_total = batch.n_batches
         self.batch_time = batch.end_time
         self.processed_bytes += batch.cumulative_bytes
-        self.processed_duration += batch.duration
+        self.processed_time += batch.duration
         self.processing_time += duration
         self.latest_processing_rate = batch.cumulative_bytes / duration.total_seconds()
         self.latest_processing_speed = batch.duration / duration.total_seconds()

From 6759f231a5825463a9e3bd28f329710599dfa80f Mon Sep 17 00:00:00 2001
From: Marius Isken
Date: Wed, 10 Jul 2024 13:57:44 +0000
Subject: [PATCH 20/26] stats: better statistics

---
 src/qseek/apps/qseek.py                   |  5 ++-
 src/qseek/exporters/simple.py             | 12 ++++++++
 src/qseek/models/detection.py             |  4 ++-
 src/qseek/models/detection_uncertainty.py |  5 +++
 src/qseek/models/location.py              | 11 +++++--
 src/qseek/models/station.py               | 37 ++++++++++++++---------
 src/qseek/search.py                       | 29 ++++++++++++++----
 src/qseek/stats.py                        |  3 +-
 src/qseek/waveforms/base.py               | 11 +++++++
 9 files changed, 90 insertions(+), 27 deletions(-)

diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py
index b3acbd93..ffeec36b 100644
--- a/src/qseek/apps/qseek.py
+++ b/src/qseek/apps/qseek.py
@@ -417,7 +417,10 @@ def show_table():
                 table.add_column("Exporter")
                 table.add_column("Description")
                 for exporter in Exporter.get_subclasses():
-                    table.add_row(exporter.__name__.lower(), exporter.__doc__)
+                    table.add_row(
+                        f"[bold]{exporter.__name__.lower()}",
+                        exporter.__doc__,
+                    )
                 console.print(table)
 
             if args.format == "list":
diff --git a/src/qseek/exporters/simple.py b/src/qseek/exporters/simple.py
index d9715071..f2100319 100644
--- a/src/qseek/exporters/simple.py
+++ b/src/qseek/exporters/simple.py
@@ -45,6 +45,9 @@ async def export(
         await catalog.export_csv(outdir / "events.csv")
         (outdir / "simple-export.json").write_text(self.model_dump_json(indent=2))
 
+        n_observed_arrivals = 0
+        n_events = 0
+
         for ev in progress.track(
             catalog,
             description="Exporting travel times",
@@ -66,6 +69,9 @@
             if not observed_arrivals:
                 continue
 
+            n_observed_arrivals += len(observed_arrivals)
+            n_events += 1
+
             with traveltime_file.open("w") as file:
@@ -84,4 +90,10 @@
                         f"{receiver.lat},{receiver.lon},{receiver.effective_elevation},{receiver.network},{receiver.station},{receiver.location},{phase},{arrival.observed.detection_value},{traveltime.total_seconds()}\n",
                     )
 
+        logger.info(
+            "Exported %d observed arrivals from %d events.",
+            n_observed_arrivals,
+            n_events,
+        )
+
         return outdir
diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py
index acf0e288..54a537f1 100644
--- a/src/qseek/models/detection.py
+++ b/src/qseek/models/detection.py
@@ -706,9 +706,11 @@
             csv_line.update(
                 {
                     "uncertainty_horizontal": self.uncertainty.horizontal,
-                    "uncertainty_vertical": self.uncertainty.depth,
+                    "uncertainty_vertical": self.uncertainty.vertical,
                 }
             )
+        csv_line["WKT_geom"] = self.as_wkt()
+
         for magnitude in self.magnitudes:
             csv_line.update(magnitude.csv_row())
         return csv_line
diff --git a/src/qseek/models/detection_uncertainty.py b/src/qseek/models/detection_uncertainty.py
index 6fcd4431..4ad8d01b 100644
--- a/src/qseek/models/detection_uncertainty.py
+++ b/src/qseek/models/detection_uncertainty.py
@@ -75,3 +75,8 @@ def total(self) -> float:
     def horizontal(self) -> float:
         """Calculate the horizontal uncertainty in [m]."""
         return float(np.sqrt(sum(self.east) ** 2 + sum(self.north) ** 2))
+
+    @computed_field
+    def vertical(self) -> float:
+        """Calculate the vertical uncertainty in [m]."""
+        return float(sum(self.depth))
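A note on why `@computed_field` is used here rather than a plain `@property`: pydantic v2 includes computed fields when serializing, so `total`, `horizontal` and `vertical` show up in the detection JSON and CSV without being stored. A reduced stand-in demonstrating the mechanics:

```python
from pydantic import BaseModel, computed_field


class Uncertainty(BaseModel):  # reduced stand-in for DetectionUncertainty
    east: tuple[float, float]
    north: tuple[float, float]

    @computed_field
    def horizontal(self) -> float:
        # Same formula as above: summing the (min, max) offset tuples.
        return (sum(self.east) ** 2 + sum(self.north) ** 2) ** 0.5


print(Uncertainty(east=(-80.0, 120.0), north=(-50.0, 60.0)).model_dump())
# {'east': (-80.0, 120.0), 'north': (-50.0, 60.0), 'horizontal': 41.23105625617661}
```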
""" with filename.open("w") as f: - f.write("network,station,location,latitude,longitude,elevation,depth\n") + f.write( + "network,station,location,latitude,longitude,elevation,depth,WKT_geom\n" + ) for sta in self: f.write( f"{sta.network},{sta.station},{sta.location}," - f"{sta.lat},{sta.lon},{sta.elevation},{sta.depth}\n" + f"{sta.effective_lat},{sta.effective_lon},{sta.elevation}," + f"{sta.depth}{sta.as_wkt()}\n" ) def export_vtk(self, reference: Location | None = None) -> None: ... diff --git a/src/qseek/search.py b/src/qseek/search.py index 6ccf85b5..70d43e8e 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -74,6 +74,12 @@ class SearchStats(Stats): latest_processing_rate: float = 0.0 latest_processing_speed: timedelta = timedelta(seconds=0.0) + current_stations: int = 0 + total_stations: int = 0 + + current_networks: int = 0 + total_networks: int = 0 + memory_total: ByteSize = Field( default_factory=lambda: ByteSize(psutil.virtual_memory().total) ) @@ -156,6 +162,9 @@ def add_processed_batch( self.processing_time += duration self.latest_processing_rate = batch.cumulative_bytes / duration.total_seconds() self.latest_processing_speed = batch.duration / duration.total_seconds() + self.current_stations = batch.n_stations + self.current_networks = batch.n_networks + self._batch_processing_times.append(duration) if show_log: self.log() @@ -182,23 +191,28 @@ def tts(duration: timedelta) -> str: "Project", f"[bold]{self.project_name}[/bold]", ) - table.add_row( - "Resources", - f"CPU {self.cpu_percent:.1f}%, " - f"RAM {human_readable_bytes(self.memory_used, decimal=True)}" - f"/{self.memory_total.human_readable(decimal=True)}", - ) table.add_row( "Progress ", f"[bold]{self.processed_percent:.1f}%[/bold]" f" ([bold]{self.batch_count}[/bold]/{self.batch_count_total or '?'}," f' {self.batch_time.strftime("%Y-%m-%d %H:%M:%S")})', ) + table.add_row( + "Stations", + f"{self.current_stations}/{self.total_stations}" + f" ({self.current_networks}/{self.total_networks} networks)", + ) table.add_row( "Processing rate", f"{human_readable_bytes(self.processing_rate)}/s" f" ({tts(self.processing_speed)} tr/s)", ) + table.add_row( + "Resources", + f"CPU {self.cpu_percent:.1f}%, " + f"RAM {human_readable_bytes(self.memory_used, decimal=True)}" + f"/{self.memory_total.human_readable(decimal=True)}", + ) table.add_row( "Remaining Time", f"{tts(self.time_remaining)}, " @@ -492,7 +506,10 @@ async def prepare(self) -> None: for magnitude in self.magnitudes: await magnitude.prepare(self.octree, self.stations) await self.init_boundaries() + self._stats.project_name = self._rundir.name + self._stats.total_stations = self.stations.n_stations + self._stats.total_networks = self.stations.n_networks async def start(self, force_rundir: bool = False) -> None: if not self.has_rundir(): diff --git a/src/qseek/stats.py b/src/qseek/stats.py index fe216acb..40734f52 100644 --- a/src/qseek/stats.py +++ b/src/qseek/stats.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, PrivateAttr, create_model from pydantic.fields import ComputedFieldInfo, FieldInfo -from rich.console import Group from rich.live import Live from rich.panel import Panel from rich.progress import Progress @@ -57,7 +56,7 @@ def generate_grid() -> Table: stats._populate_table(table) grid = table.grid(expand=True) grid.add_row(PROGRESS) - grid.add_row(Group(Panel(table, title="QSeek"))) + grid.add_row(Panel(table, title="QSeek")) return grid with Live( diff --git a/src/qseek/waveforms/base.py b/src/qseek/waveforms/base.py index 4aa2bec2..e1b6b04e 
100644 --- a/src/qseek/waveforms/base.py +++ b/src/qseek/waveforms/base.py @@ -46,8 +46,19 @@ def cumulative_duration(self) -> timedelta: @property def cumulative_bytes(self) -> int: + """Cumulative size of the traces in the batch in bytes.""" return sum(tr.ydata.nbytes for tr in self.traces) + @property + def n_stations(self) -> int: + """Number of unique stations in the batch.""" + return len({(tr.network, tr.station, tr.location) for tr in self.traces}) + + @property + def n_networks(self) -> int: + """Number of unique networks in the batch.""" + return len({tr.network for tr in self.traces}) + def is_empty(self) -> bool: """Check if the batch is empty. From 083d0e63220feb7713b830d5c3ad455660dca42a Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Tue, 16 Jul 2024 11:02:29 +0000 Subject: [PATCH 21/26] refactoring corrections --- .pre-commit-config.yaml | 2 +- src/qseek/apps/qseek.py | 70 +++++------------------------------ src/qseek/corrections/base.py | 16 ++------ src/qseek/models/catalog.py | 15 +++++++- src/qseek/models/station.py | 2 +- src/qseek/octree.py | 7 +++- src/qseek/utils.py | 4 +- 7 files changed, 37 insertions(+), 79 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9a64928e..87bdc9eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.4 + rev: v0.5.2 hooks: - id: ruff - id: ruff-format diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index ffeec36b..c542b01b 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -105,22 +105,6 @@ help="show semblance trace in snuffler", ) -station_corrections = subparsers.add_parser( - "corrections", - help="analyse and extract station corrections from a run", - description="Analyze and plot station corrections from a finished run", -) -station_corrections.add_argument( - "--plot", - action="store_true", - default=False, - help="plot station correction results and save to rundir", -) -corrections_rundir = station_corrections.add_argument( - "rundir", - type=Path, - help="path of existing run", -) features_extract = subparsers.add_parser( "feature-extraction", @@ -151,11 +135,10 @@ description="Show all available modules", ) modules.add_argument( - "--json", - "-j", + "name", + nargs="?", type=str, - help="print module's JSON config", - default="", + help="Name of the module to print JSON config for.", ) serve = subparsers.add_parser( @@ -192,9 +175,9 @@ export.add_argument( "export_dir", + nargs="?", type=Path, help="path to export directory", - nargs="?", ) export.add_argument( @@ -232,7 +215,6 @@ continue_rundir.completer = DirectoriesCompleter() snuffler_rundir.completer = DirectoriesCompleter() features_rundir.completer = DirectoriesCompleter() - corrections_rundir.completer = DirectoriesCompleter() dump_dir.completer = DirectoriesCompleter() argcomplete.autocomplete(parser) @@ -246,7 +228,6 @@ def main() -> None: load_insights() from rich import box from rich.progress import Progress - from rich.prompt import IntPrompt from rich.table import Table from qseek.console import console @@ -362,40 +343,6 @@ async def worker() -> None: asyncio.run(worker(), debug=loop_debug) - case "corrections": - import json - - from qseek.corrections.base import TravelTimeCorrections - - rundir = Path(args.rundir) - - corrections_modules = TravelTimeCorrections.get_subclasses() - - console.print("[bold]Available travel time corrections modules") - 
for imodule, module in enumerate(corrections_modules): - console.print(f"{imodule}: {module.__name__}") - - module_choice = IntPrompt.ask( - "Choose station corrections module", - choices=[str(i) for i in range(len(corrections_modules))], - default="0", - console=console, - ) - travel_time_corrections = corrections_modules[int(module_choice)] - corrections = asyncio.run( - travel_time_corrections.setup(rundir, console), debug=loop_debug - ) - - search = json.loads((rundir / "search.json").read_text()) - search["station_corrections"] = corrections.model_dump(mode="json") - - new_config_file = rundir.parent / f"{rundir.name}-corrections.json" - console.print("writing new config file") - console.print( - f"to use this config file, run [bold]qseek search {new_config_file}" - ) - new_config_file.write_text(json.dumps(search, indent=2)) - case "serve": search = Search.load_rundir(args.rundir) webserver = WebServer(search) @@ -479,10 +426,10 @@ def show_table(): TravelTimeCorrections, ) - if args.json: + if args.name: for module in module_classes: for subclass in module.get_subclasses(): - if subclass.__name__ == args.json: + if subclass.__name__ == args.name: console.print_json(subclass().model_dump_json(indent=2)) parser.exit() else: @@ -501,9 +448,10 @@ def is_insight(module: type) -> bool: table.add_section() console.print(table) - console.print("🔑 indicates an insight module\n") + console.print("Insight module are marked by 🔑\n") console.print( - "Use `qseek modules --json ` to print the JSON schema" + "Use [bold]qseek modules [/bold] " + "to print the JSON schema" ) case "dump-schemas": diff --git a/src/qseek/corrections/base.py b/src/qseek/corrections/base.py index 7c5543e3..587d4572 100644 --- a/src/qseek/corrections/base.py +++ b/src/qseek/corrections/base.py @@ -3,13 +3,11 @@ from typing import TYPE_CHECKING, Iterable, Literal from pydantic import BaseModel -from typing_extensions import Self if TYPE_CHECKING: from pathlib import Path import numpy as np - from rich.console import Console from qseek.models.station import Stations from qseek.octree import Node, Octree @@ -75,7 +73,7 @@ async def prepare( stations: Stations, octree: Octree, phases: Iterable[PhaseDescription], - rundir: Path | None = None, + rundir: Path, ) -> None: """Prepare the station for the corrections. @@ -83,14 +81,8 @@ async def prepare( stations (Stations): The station to prepare. octree (Octree): The octree to use for the preparation. phases (Iterable[PhaseDescription]): The phases to prepare the station for. - rundir (Path | None, optional): The rundir to use for the delay. + rundir (Path): The rundir to use for the delay. Defaults to None. + import_rundir (Path): The import rundir to use + for extracting the delays. """ - ... 
- - @classmethod - async def setup(cls, rundir: Path, console: Console | None = None) -> Self: - """Prepare the station corrections for the console.""" - if console: - console.print("This module does not require any preparation.") - return cls() diff --git a/src/qseek/models/catalog.py b/src/qseek/models/catalog.py index 630a26b4..72707cac 100644 --- a/src/qseek/models/catalog.py +++ b/src/qseek/models/catalog.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Any, Iterator @@ -154,6 +154,19 @@ def save_semblance_trace(self, trace: Trace) -> None: append=True, ) + @classmethod + def last_modification(cls, rundir: Path) -> datetime: + """Last modification of the event file. + + Returns: + datetime: Last modification of the event file. + """ + detection_file = rundir / FILENAME_DETECTIONS + return datetime.fromtimestamp( + detection_file.stat().st_mtime, + tz=timezone.utc, + ) + @classmethod def load_rundir(cls, rundir: Path) -> EventCatalog: """Load detections from files in the detections directory.""" diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index f587befd..2887bf56 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -259,7 +259,7 @@ def export_csv(self, filename: Path) -> None: f.write( f"{sta.network},{sta.station},{sta.location}," f"{sta.effective_lat},{sta.effective_lon},{sta.elevation}," - f"{sta.depth}{sta.as_wkt()}\n" + f"{sta.depth},{sta.as_wkt()}\n" ) def export_vtk(self, reference: Location | None = None) -> None: ... diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 70145853..0a2aaeea 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -317,7 +317,7 @@ def __hash__(self) -> int: return hash(self.hash()) -class Octree(BaseModel): +class Octree(BaseModel, Iterator[Node]): location: Location = Field( default=Location(lat=0.0, lon=0.0), description="The reference location of the octree.", @@ -436,6 +436,11 @@ def __iter__(self) -> Iterator[Node]: for node in self._root_nodes: yield from node + def __next__(self) -> Node: + for node in self: + return node + raise StopIteration + def __getitem__(self, idx: int) -> Node: for inode, node in enumerate(self): if inode == idx: diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 80ba1e10..16977fb9 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -130,7 +130,7 @@ def parse(cls, nsl: str | NSL | list[str] | tuple[str, str, str]) -> NSL: ValueError: If the NSL string is empty or invalid. """ if not nsl: - raise ValueError("invalid empty NSL") + raise ValueError(f"invalid empty NSL: {nsl}") if type(nsl) is _NSL: return nsl if isinstance(nsl, (list, tuple)): @@ -146,7 +146,7 @@ def parse(cls, nsl: str | NSL | list[str] | tuple[str, str, str]) -> NSL: return cls(parts[0], parts[1], "") raise ValueError( f"invalid NSL `{nsl}`, expecting `..`, " - "e.g. `6A.STA130.00`, `6A.STA130` or `.STA130`" + "e.g. 
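The widened error message reflects what `NSL.parse` accepts after this patch. For reference, the accepted spellings, following the parsing logic shown above:

```python
from qseek.utils import NSL

NSL.parse("6A.STA130.00")        # full network.station.location
NSL.parse("6A.STA130")           # location defaults to ""
NSL.parse(".STA130")             # empty network
NSL.parse("6A.")                 # network-only form, now accepted
NSL.parse(("6A", "STA130", ""))  # tuples and lists work as well
```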
From 766ffc7814ba84a70e29889e6c91842491cc824f Mon Sep 17 00:00:00 2001
From: Marius Isken
Date: Tue, 16 Jul 2024 15:46:37 +0000
Subject: [PATCH 22/26] adding velest exporter

---
 src/qseek/exporters/velest.py | 276 +++++++++++++++++++++++++++++++++-
 1 file changed, 268 insertions(+), 8 deletions(-)

diff --git a/src/qseek/exporters/velest.py b/src/qseek/exporters/velest.py
index 7286650a..a5ca4a47 100644
--- a/src/qseek/exporters/velest.py
+++ b/src/qseek/exporters/velest.py
@@ -2,34 +2,294 @@
 
 import logging
 from pathlib import Path
+from typing import NamedTuple
 
+import numpy as np
 import rich
-from rich.prompt import FloatPrompt
+from rich.prompt import FloatPrompt, IntPrompt
 
 from qseek.exporters.base import Exporter
+from qseek.models.detection import EventDetection, PhaseDetection, Receiver
+from qseek.models.station import Location, Station
 from qseek.search import Search
 
 logger = logging.getLogger(__name__)
 
+KM = 1000.0
+
+CONFIDENCE_QUALITY_BINS = [1.0, 0.8, 0.6, 0.4, 0.0]
+
+CONTROL_FILE_TPL = """velest parameters must be modified according to documentation
+{ref_lat} {ref_lon} 0 0.0 0 0.00 1
+{n_earthquakes} 0 0.0
+{isingle:1d} 0
+{max_distance_station} 0 {min_depth} 0.20 5.00 {use_station_correction:1d}
+2 0.75 {vp_vs_ratio} 1
+0.01 0.01 0.01 {velocity_damping} {station_correction_damping}
+1 0 0 {use_elevation:1d} {use_station_correction:1d}
+1 1 2 0
+0 0 0 0 0 0 0
+0.001 {iteration_number} {invertratio}
+{model_file}
+stations_velest.sta
+
+regionsnamen.dat
+regionskoord.dat
+
+
+{phase_file}
+
+{mainout_file}
+{outcheck_file}
+{finalcnv_file}
+{stacorrection_file}
+"""
+
+
+class VelestControlFile(NamedTuple):
+    ref_lat: float
+    ref_lon: float  # should be negative for East
+    n_earthquakes: int
+    isingle: bool = True
+    min_depth: float = -0.2
+    vp_vs_ratio: float = 1.65
+    iteration_number: int = 99
+    invertratio: int = 0
+    model_file: str = "model.mod"
+    phase_file: str = "phase_velest.pha"
+    mainout_file: str = "main.out"
+    outcheck_file: str = "log.out"
+    finalcnv_file: str = "final.cnv"
+    stacorrection_file: str = "stacor.dat"
+    velocity_damping: float = 1.0  # Damping parameter for the velocity
+    station_correction_damping: float = 0.1  # Damping parameter for the station
+    max_distance_station: float = 200.0
+    use_elevation: bool = False
+    use_station_correction: bool = False
+    allow_low_velocity: bool = False
+
+    def write_config_file(self, file: Path):
+        file.write_text(CONTROL_FILE_TPL.format(**self._asdict()))
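To see what the template produces, the NamedTuple can be exercised on its own. Values here are placeholders; note the sign convention flagged in the `ref_lon` comment (the exporter below passes `-search.octree.location.lon`):

```python
from pathlib import Path

# Assuming VelestControlFile as defined in the diff above; values are illustrative.
control = VelestControlFile(
    ref_lat=48.2,
    ref_lon=-8.9,  # negated east longitude, per the VELEST convention
    n_earthquakes=120,
    isingle=False,  # joint inversion instead of single-event mode
)
control.write_config_file(Path("velest.cmn"))
print(Path("velest.cmn").read_text().splitlines()[1])  # -> "48.2 -8.9 0 0.0 0 0.00 1"
```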
border(m)", + default=self.distance_border, + ) + self.min_p_phase_confidence = FloatPrompt.ask( + "Minimum pick probability for P phase", + default=self.min_p_phase_confidence, + ) + self.min_s_phase_confidence = FloatPrompt.ask( + "Minimum pick probability for S phase", + default=self.min_s_phase_confidence, + ) + self.max_traveltime_delay = FloatPrompt.ask( + "Maximum travel time delay", + default=self.max_traveltime_delay, + ) outdir.mkdir() search = Search.load_rundir(rundir) - catalog = search.catalog # noqa + phases = search.image_functions.get_phases() + for phase in phases: + if phase.endswith("P"): + phase_p = phase + if phase.endswith("S"): + phase_s = phase + + catalog = search.catalog + + # export station file + stations = search.stations.stations + station_file = outdir / "stations_velest.sta" + self.export_station(stations=stations, filename=station_file) + + # export phase file + phase_file = outdir / "phase_velest.pha" + n_earthquakes = 0 + for event in catalog: + if event.semblance < self.min_event_semblance: + continue + if event.receivers.n_observations(phase_p) < self.min_receivers_number: + continue + if event.distance_border < self.min_distance_to_border: + continue + + observed_arrivals: list[tuple[Receiver, PhaseDetection]] = [] + + for receiver in event.receivers: + for _phase, detection in receiver.phase_arrivals.items(): + if detection.observed is None: + continue + observed = detection.observed + if ( + detection.phase == phase_p + and observed.detection_value <= self.min_p_phase_confidence + ): + continue + if ( + detection.phase == phase_s + and observed.detection_value <= self.min_s_phase_confidence + ): + continue + if ( + detection.traveltime_delay.total_seconds() + > self.max_traveltime_delay + ): + continue + observed_arrivals.append((receiver, detection)) + + count_p, count_s = self.export_phases_slim( + phase_file, event, observed_arrivals + ) + self.n_picks_p += count_p + self.n_picks_s += count_s + n_earthquakes += 1 + self.n_events = n_earthquakes + + # export control file + control_file = outdir / "velest.cmn" + control_file_parameters = VelestControlFile( + ref_lat=search.octree.location.lat, + ref_lon=-search.octree.location.lon, + n_earthquakes=n_earthquakes, + ) + control_file_parameters.write_config_file(control_file) + # export velocity model file + dep = search.ray_tracers.root[0].earthmodel.layered_model.profile("z") + vp = search.ray_tracers.root[0].earthmodel.layered_model.profile("vp") + vs = search.ray_tracers.root[0].earthmodel.layered_model.profile("vs") + dep_velest = [] + vp_velest = [] + vs_velest = [] + for d, vpi, vsi in zip(dep, vp, vs, strict=True): + if float(d) / KM not in dep_velest: + dep_velest.append(float(d) / KM) + vp_velest.append(float(vpi) / KM) + vs_velest.append(float(vsi) / KM) + velmod_file = outdir / "model.mod" + make_velmod_file(velmod_file, vp_velest, vs_velest, dep_velest) export_info = outdir / "export_info.json" export_info.write_text(self.model_dump_json(indent=2)) - return outdir + + def export_phases_slim( + self, + outfile: Path, + event: EventDetection, + observed_arrivals: list[tuple[Receiver, PhaseDetection]], + ): + mag = event.magnitude.average if event.magnitude is not None else 0.0 + lat, lon = velest_location(event) + write_out = ( + f"{event.time:%y%m%d %H%M %S.%f}"[:-4] + + f" {lat} {lon} {event.depth/1000:7.2f} {mag:5.2f}\n" + ) + count_p = 0 + count_s = 0 + for rec, dectection in observed_arrivals: + quality_weight = ( + np.digitize( + dectection.observed.detection_value, + 
CONFIDENCE_QUALITY_BINS, + ) + - 1 + ) + if dectection.phase.endswith("P"): + phase = "P" + count_p += 1 + else: + phase = "S" + count_s += 1 + traveltime = (dectection.observed.time - event.time).total_seconds() + write_out += ( + f" {rec.station:6s} {phase:1s} " + f"{quality_weight:1d} {traveltime:7.2f}\n" + ) + write_out += "\n" + + if count_p or count_s: + with outfile.open("a") as file: + file.write(write_out) + else: + logger.warning("Event {event.time}: No phases observed") + + return count_p, count_s + + def export_station(self, stations: list[Station], filename: Path) -> None: + with filename.open("w") as fpout: + fpout.write("(a6,f7.4,a1,1x,f8.4,a1,1x,i4,1x,i1,1x,i3,1x,f5.2,2x,f5.2)\n") + station_index = 1 + for station in stations: + lat, lon = velest_location(station) + fpout.write( + f"{station.station:6s}{lat} {lon} {int(station.elevation):4d} " + f"1 {station_index:3d} 0.00 0.00\n" + ) + station_index += 1 + fpout.write("\n") + + +def velest_location(location: Location) -> tuple[str, str]: + """Return VELEST formatted latitude and longitude. + + Args: + location: Location object. + + Returns: + tuple: VELEST formatted latitude and longitude. + """ + if location.effective_lat < 0: + velest_lat = f"{location.effective_lat:7.4f}S" + else: + velest_lat = f"{location.effective_lat:7.4f}N" + if location.effective_lon < 0: + velest_lon = f"{location.effective_lon:8.4f}W" + else: + velest_lon = f"{location.effective_lon:8.4f}E" + return velest_lat, velest_lon + + +def make_velmod_file( + modname: Path, + velocity_p: list[float], + velocity_s: list[float], + depths: list[float], +) -> None: + nlayer = len(depths) + vdamp = 1.0 + with modname.open("w") as fp: + fp.write("initial 1D-model for velest\n") + # the second line - indicate the number of layers for Vp + fp.write(f"{nlayer} vel,depth,vdamp,phase (f5.2,5x,f7.2,2x,f7.3,3x,a1)\n") + + for vel, depth in zip(velocity_p, depths, strict=True): + fp.write(f"{vel:5.2f} {depth:7.2f} {vdamp:7.3f}\n") + + fp.write("%3d\n" % nlayer) + for vel, depth in zip(velocity_s, depths, strict=True): + fp.write(f"{vel:5.2f} {depth:7.2f} {vdamp:7.3f}\n") From 2a0b9786686a3928c33b6adf8bd836c2ef4cb534 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Wed, 17 Jul 2024 14:27:59 +0000 Subject: [PATCH 23/26] project update --- README.md | 7 +- pyproject.toml | 6 +- src/qseek/images/images.py | 6 +- .../images/{phase_net.py => seisbench.py} | 94 ++++++++++++++----- src/qseek/stats.py | 2 +- 5 files changed, 85 insertions(+), 30 deletions(-) rename src/qseek/images/{phase_net.py => seisbench.py} (78%) diff --git a/README.md b/README.md index d9c7ae19..0d244cab 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,12 @@ Qseek is an earthquake detection and localisation framework based on stacking an Key features are of the earthquake detection and localisation framework are: -* Earthquake phase detection using machine-learning pickers from [SeisBench](https://github.com/seisbench/seisbench) +* Earthquake phase detection using machine-learning model from [SeisBench](https://github.com/seisbench/seisbench), pre-trained on different data sets. 
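A note on the pick weighting in `export_phases_slim` above: `np.digitize` with the decreasing `CONFIDENCE_QUALITY_BINS` edges maps detection confidences onto VELEST quality classes, 0 being the best. A quick check of the mapping:

```python
import numpy as np

CONFIDENCE_QUALITY_BINS = [1.0, 0.8, 0.6, 0.4, 0.0]  # decreasing edges, as above

for confidence in (0.95, 0.7, 0.5, 0.2):
    quality = np.digitize(confidence, CONFIDENCE_QUALITY_BINS) - 1
    print(f"confidence {confidence} -> quality {quality}")
# confidence 0.95 -> quality 0
# confidence 0.7 -> quality 1
# confidence 0.5 -> quality 2
# confidence 0.2 -> quality 3
```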
From 2a0b9786686a3928c33b6adf8bd836c2ef4cb534 Mon Sep 17 00:00:00 2001
From: Marius Isken
Date: Wed, 17 Jul 2024 14:27:59 +0000
Subject: [PATCH 23/26] project update

---
 README.md                                 |  7 +-
 pyproject.toml                            |  6 +-
 src/qseek/images/images.py                |  6 +-
 .../images/{phase_net.py => seisbench.py} | 94 ++++++++++++++-----
 src/qseek/stats.py                        |  2 +-
 5 files changed, 85 insertions(+), 30 deletions(-)
 rename src/qseek/images/{phase_net.py => seisbench.py} (78%)

diff --git a/README.md b/README.md
index d9c7ae19..0d244cab 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,12 @@ Qseek is an earthquake detection and localisation framework based on stacking an

 Key features are of the earthquake detection and localisation framework are:

-* Earthquake phase detection using machine-learning pickers from [SeisBench](https://github.com/seisbench/seisbench)
+* Earthquake phase detection using machine-learning models from [SeisBench](https://github.com/seisbench/seisbench), pre-trained on different data sets.
+  * [PhaseNet (Zhu and Beroza, 2018)](https://doi.org/10.1093/gji/ggy423)
+  * [EQTransformer (Mousavi et al., 2020)](https://doi.org/10.1038/s41467-020-17591-w)
+  * [GPD (Ross et al., 2018)](https://doi.org/10.1785/0120180080)
+  * [OBSTransformer (Niksejel and Zhang, 2024)](https://doi.org/10.1093/gji/ggae049)
+  * [LFEDetect](https://doi.org/10.1093/gji/ggae049)
 * Octree localisation approach for efficient and accurate search
 * Different velocity models:
   * Constant velocity
diff --git a/pyproject.toml b/pyproject.toml
index 8c72277b..b455f823 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,9 +67,9 @@ classifiers = [
 dev = ["pre-commit>=3.4", "ruff>=0.3.0", "pytest>=7.4", "pytest-asyncio>=0.21"]

 docs = [
-    "mkdocs-material>=9.5.13",
-    "mkdocstrings[python]>=0.23",
-    "markdown-exec>=1.8.0",
+    "mkdocs-material>=9.5",
+    "mkdocstrings[python]>=0.25",
+    "markdown-exec>=1.9",
 ]

 completion = ["argcomplete>=3.2"]
diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py
index d4451004..fc342129 100644
--- a/src/qseek/images/images.py
+++ b/src/qseek/images/images.py
@@ -10,7 +10,7 @@
 from pydantic import Field, PositiveInt, PrivateAttr, RootModel, computed_field

 from qseek.images.base import ImageFunction
-from qseek.images.phase_net import PhaseNet
+from qseek.images.seisbench import SeisBench
 from qseek.stats import Stats
 from qseek.utils import QUEUE_SIZE, PhaseDescription, datetime_now, human_readable_bytes

@@ -27,7 +27,7 @@


 ImageFunctionType = Annotated[
-    Union[PhaseNet, ImageFunction],
+    Union[SeisBench, ImageFunction],
     Field(..., discriminator="image"),
 ]

@@ -70,7 +70,7 @@ def _populate_table(self, table: Table) -> None:


 class ImageFunctions(RootModel):
-    root: list[ImageFunctionType] = [PhaseNet()]
+    root: list[ImageFunctionType] = [SeisBench()]

     _queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr(
         asyncio.Queue(maxsize=QUEUE_SIZE)
diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/seisbench.py
similarity index 78%
rename from src/qseek/images/phase_net.py
rename to src/qseek/images/seisbench.py
index a5e2bf55..46c46a1f 100644
--- a/src/qseek/images/phase_net.py
+++ b/src/qseek/images/seisbench.py
@@ -23,20 +23,42 @@

 if TYPE_CHECKING:
     from pyrocko.trace import Trace
-    from seisbench.models import PhaseNet as PhaseNetSeisBench
+    from seisbench.models import WaveformModel

+ModelName = Literal[
+    "PhaseNet",
+    "EQTransformer",
+    "GPD",
+    "OBSTransformer",
+    "LFEDetect",
+]
+
+
+PreTrainedName = Literal[
+    "cascadia",
+    "cms",
     "diting",
+    "dummy",
     "ethz",
     "geofon",
     "instance",
     "iquique",
+    "jcms",
+    "jcs",
+    "jms",
     "lendb",
+    "mexico",
+    "nankai",
     "neic",
     "obs",
+    "obst2024",
     "original",
+    "original_nonconservative",
+    "san_andreas",
     "scedc",
     "stead",
+    "volpick",
 ]

 PhaseName = Literal["P", "S"]
@@ -132,13 +154,19 @@ def search_phase_arrival(
     )


-class PhaseNet(ImageFunction):
+class SeisBench(ImageFunction):
     """SeisBench image function. For more details see SeisBench documentation."""

-    image: Literal["PhaseNet"] = "PhaseNet"
+    image: Literal["SeisBench"] = "SeisBench"
+    model: ModelName = Field(
+        default="PhaseNet",
+        description="The model to use for the image function. Currently only `PhaseNet`",
+    )
+
+    pretrained: PreTrainedName = Field(
         default="original",
-        description="SeisBench pre-trained PhaseNet model to use. "
+        description="SeisBench pre-trained model to use. "
         "Choose from `ethz`, `geofon`, `instance`, `iquique`, `lendb`, `neic`, `obs`,"
         " `original`, `scedc`, `stead`."
" For more details see SeisBench documentation", @@ -189,46 +217,68 @@ class PhaseNet(ImageFunction): description="Weights for each phase.", ) - _phase_net: PhaseNetSeisBench = PrivateAttr(None) + _seisbench_model: WaveformModel = PrivateAttr(None) @property - def phase_net(self) -> PhaseNetSeisBench: - if self._phase_net is None: + def seisbench_model(self) -> WaveformModel: + if self._seisbench_model is None: self._prepare() - return self._phase_net + return self._seisbench_model def _prepare(self) -> None: + import seisbench.models as sbm import torch - from seisbench.models import PhaseNet as PhaseNetSeisBench torch.set_num_threads(self.torch_cpu_threads) - self._phase_net = PhaseNetSeisBench.from_pretrained(self.model) + + match self.model: + case "PhaseNet": + model = sbm.PhaseNet + case "EQTransformer": + model = sbm.EQTransformer + case "GPD": + model = sbm.GPD + case "OBSTransformer": + model = sbm.OBSTransformer + case "LFEDetect": + model = sbm.LFEDetect + case _: + raise ValueError(f"Model `{self.model}` not available.") + + self._seisbench_model = model.from_pretrained(self.pretrained) if self.torch_use_cuda: try: if isinstance(self.torch_use_cuda, bool): - self._phase_net.cuda() + self._seisbench_model.cuda() else: - self._phase_net.cuda(self.torch_use_cuda) + self._seisbench_model.cuda(self.torch_use_cuda) except RuntimeError as exc: logger.warning( "failed to use CUDA for PhaseNet model, using CPU.", exc_info=exc, ) - self._phase_net.eval() + self._seisbench_model.eval() try: logger.info("compiling PhaseNet model...") - self._phase_net = torch.compile(self._phase_net, mode="max-autotune") + self._seisbench_model = torch.compile( + self._seisbench_model, + mode="max-autotune", + ) except RuntimeError as exc: logger.warning( - "failed to compile PhaseNet model, using uncompiled model.", + "failed to compile SeisBench model, using uncompiled model.", exc_info=exc, ) + def get_blinding_samples(self) -> tuple[int, int]: + try: + return self.seisbench_model.default_args["blinding"] + except KeyError: + return self.seisbench_model._annotate_args["blinding"][1] + def get_blinding(self, sampling_rate: float) -> timedelta: - blinding_samples = ( - max(self.phase_net.default_args["blinding"]) / self.rescale_input - ) - return timedelta(seconds=blinding_samples / sampling_rate) + scaled_blinding_samples = max(self.get_blinding_samples()) / self.rescale_input + return timedelta(seconds=scaled_blinding_samples / sampling_rate) def _detection_half_width(self) -> float: """Half width of the detection window in seconds.""" @@ -244,7 +294,7 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: tr.stats.sampling_rate /= scale annotations: Stream = await asyncio.to_thread( - self.phase_net.annotate, + self.seisbench_model.annotate, stream, overlap=self.window_overlap_samples, batch_size=self.batch_size, @@ -255,7 +305,7 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: scale = self.rescale_input for tr in annotations: tr.stats.sampling_rate *= scale - blinding_samples = self.phase_net.default_args["blinding"][0] + blinding_samples = max(self.get_blinding_samples()) # 100 Hz is the native sampling rate of PhaseNet blinding_seconds = (blinding_samples / 100.0) * (1.0 - 1 / scale) tr.stats.starttime -= blinding_seconds @@ -263,7 +313,7 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: annotated_traces: list[Trace] = [ tr.to_pyrocko_trace() for tr in annotations - if not tr.stats.channel.endswith("N") + if 
tr.stats.channel.endswith("P") or tr.stats.channel.endswith("S") ] annotation_p = PhaseNetImage( diff --git a/src/qseek/stats.py b/src/qseek/stats.py index 40734f52..4179a33e 100644 --- a/src/qseek/stats.py +++ b/src/qseek/stats.py @@ -54,7 +54,7 @@ def generate_grid() -> Table: ) table.add_section() stats._populate_table(table) - grid = table.grid(expand=True) + grid = Table.grid(expand=True) grid.add_row(PROGRESS) grid.add_row(Panel(table, title="QSeek")) return grid From fee5e362fb6ca4206c6c2c71d490382a152b119e Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Wed, 17 Jul 2024 14:30:54 +0000 Subject: [PATCH 24/26] remaing interpolation function --- src/qseek/octree.py | 5 ++--- src/qseek/search.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 0a2aaeea..464458c5 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -632,7 +632,7 @@ def total_number_nodes(self) -> int: """ return len(self._root_nodes) * (8 ** (self.n_levels - 1)) - async def interpolate_max_location( + async def interpolate_max_semblance( self, peak_node: Node, ) -> Location: @@ -664,8 +664,7 @@ async def interpolate_max_location( rbf = scipy.interpolate.RBFInterpolator( neighbor_coords[:, :3], neighbor_semblance, - kernel="thin_plate_spline", - degree=1, + kernel="cubic", ) bound = peak_node.size / 1.5 res = await asyncio.to_thread( diff --git a/src/qseek/search.py b/src/qseek/search.py index 70d43e8e..6fc38d2b 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -918,7 +918,7 @@ async def search( continue if parent.node_peak_interpolation: - source_location = await octree.interpolate_max_location(source_node) + source_location = await octree.interpolate_max_semblance(source_node) else: source_location = source_node.as_location() From f8edfc9e0dda8b2e82361af1db29137726f23187 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Wed, 17 Jul 2024 16:46:50 +0200 Subject: [PATCH 25/26] cleanup docs --- README.md | 2 +- docs/components/image_function.md | 8 +++++--- docs/components/seismic_data.md | 2 +- docs/index.md | 15 ++++++++++----- mkdocs.yml | 2 +- pyproject.toml | 1 + src/qseek/apps/qseek.py | 2 +- src/qseek/images/seisbench.py | 3 ++- 8 files changed, 22 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0d244cab..b6c67f3a 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Key features are of the earthquake detection and localisation framework are: * [EQTransformer (Mousavi et al., 2020)](https://doi.org/10.1038/s41467-020-17591-w) * [GPD (Ross et al., 2018)](https://doi.org/10.1785/0120180080) * [OBSTransformer (Niksejel and Zahng, 2024)](https://doi.org/10.1093/gji/ggae049) - * [LFEDetect](https://doi.org/10.1093/gji/ggae049) + * LFEDetect * Octree localisation approach for efficient and accurate search * Different velocity models: * Constant velocity diff --git a/docs/components/image_function.md b/docs/components/image_function.md index 99e2c90c..591b009e 100644 --- a/docs/components/image_function.md +++ b/docs/components/image_function.md @@ -2,14 +2,16 @@ For image functions this version of Qseek relies heavily on machine learning pickers delivered by [SeisBench](https://github.com/seisbench/seisbench). -## PhaseNet Image Function +## SeisBench Image Function + +SeisBench offers access to a variety of machine learning phase pickers pre-trained on various data sets. !!! abstract "Citation PhaseNet" *Zhu, Weiqiang, and Gregory C. Beroza. 
"PhaseNet: A Deep-Neural-Network-Based Seismic Arrival Time Picking Method." arXiv preprint arXiv:1803.03211 (2018).* ```python exec='on' from qseek.utils import generate_docs -from qseek.images.phase_net import PhaseNet +from qseek.images.seisbench import SeisBench -print(generate_docs(PhaseNet())) +print(generate_docs(SeisBench())) ``` diff --git a/docs/components/seismic_data.md b/docs/components/seismic_data.md index 84aa1011..80019d0e 100644 --- a/docs/components/seismic_data.md +++ b/docs/components/seismic_data.md @@ -10,7 +10,7 @@ To prepare your data for EQ detection and localisation, **organize it in a MiniS from qseek.utils import generate_docs from qseek.waveforms.squirrel import PyrockoSquirrel -print(generate_docs(PyrockoSquirrel())) +print(generate_docs(PyrockoSquirrel(persistent="docs"))) ``` ## Meta Data diff --git a/docs/index.md b/docs/index.md index 121e5cc4..55b63f73 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,14 +14,19 @@ The detector is leveraging [Pyrocko](https://pyrocko.org) and [SeisBench](https: ## Features * [x] Earthquake phase detection using machine-learning pickers from [SeisBench](https://github.com/seisbench/seisbench) + * [x] [PhaseNet (Zhu and Beroza, 2018](https://doi.org/10.1093/gji/ggy423) + * [x] [EQTransformer (Mousavi et al., 2020)](https://doi.org/10.1038/s41467-020-17591-w) + * [x] [GPD (Ross et al., 2018)](https://doi.org/10.1785/0120180080) + * [x] [OBSTransformer (Niksejel and Zahng, 2024)](https://doi.org/10.1093/gji/ggae049) + * [x] LFEDetect * [x] Octree localisation approach for efficient and accurate search * [x] Different velocity models: - * [x] Constant velocity - * [x] 1D Layered velocity model - * [x] 3D fast-marching velocity model (NonLinLoc compatible) + * [x] Constant velocity + * [x] 1D Layered velocity model + * [x] 3D fast-marching velocity model (NonLinLoc compatible) * [x] Extraction of earthquake event features: - * [x] Local magnitudes - * [x] Ground motion attributes + * [x] Local magnitudes + * [x] Ground motion attributes * [x] Automatic extraction of modelled and picked travel times * [x] Calculation and application of station corrections / station delay times * [ ] Real-time analytics on streaming data (e.g. SeedLink) diff --git a/mkdocs.yml b/mkdocs.yml index 8dafcf9e..fa0aed37 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -77,7 +77,7 @@ nav: - Getting Started 🚀: getting_started.md - Visualising Detections: visualizing_results.md - Benchmark: benchmark.md - - Usage: + - Configuration: - The Search: components/configuration.md - Seismic Data: components/seismic_data.md - Ray Tracer: components/ray_tracer.md diff --git a/pyproject.toml b/pyproject.toml index b455f823..31ee57aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "psutil>=5.9", "aiofiles>=23.0", "typer >=0.12.3", + "scikit-fmm >= 2024.05", ] classifiers = [ diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index c542b01b..0031e4ff 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -155,7 +155,7 @@ export = subparsers.add_parser( "export", - help="Export detections to different output formats", + help="export detections to different output formats", description="Export detections to different output formats." 
" Get an overview with `qseek export list`", ) diff --git a/src/qseek/images/seisbench.py b/src/qseek/images/seisbench.py index 46c46a1f..51664c62 100644 --- a/src/qseek/images/seisbench.py +++ b/src/qseek/images/seisbench.py @@ -161,7 +161,8 @@ class SeisBench(ImageFunction): model: ModelName = Field( default="PhaseNet", - description="The model to use for the image function. Currently only `PhaseNet`", + description="The model to use for the image function. Currently supported " + "models are `PhaseNet`, `EQTransformer`, `GPD`, `OBSTransformer`, `LFEDetect`.", ) pretrained: PreTrainedName = Field( From cd4eaf3edcce5e56baeafce7742d9268f298e10e Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Wed, 17 Jul 2024 15:00:38 +0000 Subject: [PATCH 26/26] removing GPD again --- README.md | 1 - docs/index.md | 1 - src/qseek/images/seisbench.py | 5 ++--- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b6c67f3a..d1134caa 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,6 @@ Key features are of the earthquake detection and localisation framework are: * Earthquake phase detection using machine-learning model from [SeisBench](https://github.com/seisbench/seisbench), pre-trained on different data sets. * [PhaseNet (Zhu and Beroza, 2018](https://doi.org/10.1093/gji/ggy423) * [EQTransformer (Mousavi et al., 2020)](https://doi.org/10.1038/s41467-020-17591-w) - * [GPD (Ross et al., 2018)](https://doi.org/10.1785/0120180080) * [OBSTransformer (Niksejel and Zahng, 2024)](https://doi.org/10.1093/gji/ggae049) * LFEDetect * Octree localisation approach for efficient and accurate search diff --git a/docs/index.md b/docs/index.md index 55b63f73..0ce562b8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -16,7 +16,6 @@ The detector is leveraging [Pyrocko](https://pyrocko.org) and [SeisBench](https: * [x] Earthquake phase detection using machine-learning pickers from [SeisBench](https://github.com/seisbench/seisbench) * [x] [PhaseNet (Zhu and Beroza, 2018](https://doi.org/10.1093/gji/ggy423) * [x] [EQTransformer (Mousavi et al., 2020)](https://doi.org/10.1038/s41467-020-17591-w) - * [x] [GPD (Ross et al., 2018)](https://doi.org/10.1785/0120180080) * [x] [OBSTransformer (Niksejel and Zahng, 2024)](https://doi.org/10.1093/gji/ggae049) * [x] LFEDetect * [x] Octree localisation approach for efficient and accurate search diff --git a/src/qseek/images/seisbench.py b/src/qseek/images/seisbench.py index 51664c62..ff36cca1 100644 --- a/src/qseek/images/seisbench.py +++ b/src/qseek/images/seisbench.py @@ -29,7 +29,6 @@ ModelName = Literal[ "PhaseNet", "EQTransformer", - "GPD", "OBSTransformer", "LFEDetect", ] @@ -255,12 +254,12 @@ def _prepare(self) -> None: self._seisbench_model.cuda(self.torch_use_cuda) except RuntimeError as exc: logger.warning( - "failed to use CUDA for PhaseNet model, using CPU.", + "failed to use CUDA for SeisBench model, using CPU.", exc_info=exc, ) self._seisbench_model.eval() try: - logger.info("compiling PhaseNet model...") + logger.info("compiling SeisBench model...") self._seisbench_model = torch.compile( self._seisbench_model, mode="max-autotune",