Skip to content

Commit

Permalink
Fix small bug in cut_templates.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Nov 7, 2023
1 parent eefe682 commit 061e977
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 103 deletions.
2 changes: 1 addition & 1 deletion PhaseNet
195 changes: 93 additions & 102 deletions slurm/cut_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def extract_template_numpy(
traveltime_fname,
traveltime_index_fname,
traveltime_type_fname,
arrivaltime_index_fname,
snr_fname,
mseed_path,
events,
Expand All @@ -56,9 +55,6 @@ def extract_template_numpy(
traveltime_type_array = np.memmap(
traveltime_type_fname, dtype=np.int32, mode="r+", shape=tuple(config["traveltime_shape"])
)
arrivaltime_index_array = np.memmap(
arrivaltime_index_fname, dtype=np.int64, mode="r+", shape=tuple(config["traveltime_shape"])
)
snr_array = np.memmap(snr_fname, dtype=np.float32, mode="r+", shape=tuple(config["snr_shape"]))

# %%
Expand Down Expand Up @@ -93,18 +89,18 @@ def extract_template_numpy(
trace.resample(config["sampling_rate"])
# trace.detrend("linear")
# trace.taper(max_percentage=0.05, type="cosine")
# trace.filter("bandpass", freqmin=1.0, freqmax=15.0, corners=2, zerophase=True)
trace.filter("bandpass", freqmin=1.0, freqmax=15.0, corners=4, zerophase=True)
waveforms_dict[f"{station_id}{c}"] = trace
except Exception as e:
print(e)
continue

# %%
picks["station_component_index"] = picks.apply(lambda x: f"{x.station_id}.{x.phase_type}", axis=1)
picks["station_phase_index"] = picks.apply(lambda x: f"{x.station_id}.{x.phase_type}", axis=1)

# %%
num_event = 0
for ii, event in tqdm(
for i, event in tqdm(
events_.iterrows(),
total=len(events_),
desc=f"Cutting event {year_jday}T{hour}",
Expand All @@ -119,19 +115,19 @@ def extract_template_numpy(
continue

picks_ = picks.loc[[event.event_index]]
picks_ = picks_.set_index("station_component_index")
picks_ = picks_.set_index("station_phase_index")

event_loc = event[["x_km", "y_km", "z_km"]].to_numpy().astype(np.float32)
event_loc = np.hstack((event_loc, [0]))[np.newaxis, :]
station_loc = stations[["x_km", "y_km", "z_km"]].to_numpy()
station_loc = stations[["x_km", "y_km", "z_km"]].to_numpy().astype(np.float32)

template_ = np.zeros((6, len(stations), config["nt"]), dtype=np.float32)
snr_ = np.zeros((6, len(stations)), dtype=np.float32)
traveltime_ = np.zeros((2, len(stations)), dtype=np.float32)
traveltime_index_ = np.zeros((2, len(stations)), dtype=np.int32)
traveltime_type_ = np.zeros((2, len(stations)), dtype=np.int32)
arrivaltime_index_ = np.zeros((2, len(stations)), dtype=np.int64)

for i, phase_type in enumerate(["P", "S"]):
for k, phase_type in enumerate(["P", "S"]):
traveltime = gamma.seismic_ops.calc_time(
event_loc,
station_loc,
Expand All @@ -140,56 +136,45 @@ def extract_template_numpy(
).squeeze()

phase_timestamp_pred = event["event_timestamp"] + traveltime
# predicted_phase_time = [events_.loc[event_index]["event_time"] + pd.Timedelta(seconds=x) for x in traveltime]

mean_shift = []
for j, station_id in enumerate(stations["station_id"]):
for j, station in stations.iterrows():
station_id = station["station_id"]
if f"{station_id}.{phase_type}" in picks_.index:
## TODO: check if multiple phases for the same station
phase_timestamp = picks_.loc[f"{station_id}.{phase_type}"]["phase_timestamp"]
phase_timestamp_pred[j] = phase_timestamp
mean_shift.append(phase_timestamp - (event["event_timestamp"] + traveltime[j]))

traveltime[j] = phase_timestamp - event["event_timestamp"]
traveltime_type_[i, j] = 1 # auto pick
arrivaltime_index_[i, j] = int(round(phase_timestamp * config["sampling_rate"]))
traveltime_type_[k, j] = 1 # auto pick
# traveltime[j] = phase_timestamp - event["event_timestamp"] # should define traveltime at the exact data point
else:
traveltime_type_[i, j] = 0 # theoretical pick
traveltime_type_[k, j] = 0 # theoretical pick

# if len(mean_shift) > 0:
# mean_shift = float(np.median(mean_shift))
# else:
# mean_shift = 0
# phase_timestamp_pred[traveltime_type_[i, :] == 0] += mean_shift
# traveltime[traveltime_type_[i, :] == 0] += mean_shift
traveltime_[i, :] = traveltime

for j, station in enumerate(stations.iloc):
for j, station in stations.iterrows():
station_id = station["station_id"]

empty_data = True
for c in station["component"]:
c_index = i * 3 + config["component_mapping"][c]
c_index = k * 3 + config["component_mapping"][c] # 012 for P, 345 for S

if f"{station_id}{c}" in waveforms_dict:
trace = waveforms_dict[f"{station_id}{c}"]

begin_time = (
phase_timestamp_pred[j]
- trace.stats.starttime.datetime.replace(tzinfo=timezone.utc).timestamp()
- config["time_before"]
trace_starttime = trace.stats.starttime.datetime.replace(tzinfo=timezone.utc).timestamp()

begin_time = phase_timestamp_pred[j] - trace_starttime - config["time_before"]
end_time = phase_timestamp_pred[j] - trace_starttime + config["time_after"]
begin_time_index = max(0, int(begin_time * trace.stats.sampling_rate))
end_time_index = max(0, int(end_time * trace.stats.sampling_rate))
traveltime_[k, j] = (
begin_time_index / trace.stats.sampling_rate
+ config["time_before"]
+ trace_starttime
- event["event_timestamp"]
) ## define traveltime at the exact data point
traveltime_index_[k, j] = begin_time_index + int(
config["time_before"] * trace.stats.sampling_rate
)
end_time = (
phase_timestamp_pred[j]
- trace.stats.starttime.datetime.replace(tzinfo=timezone.utc).timestamp()
+ config["time_after"]
)

trace_data = trace.data[
max(0, int(begin_time * trace.stats.sampling_rate)) : max(
0, int(end_time * trace.stats.sampling_rate)
)
].astype(np.float32)
trace_data = trace.data[begin_time_index:end_time_index].astype(np.float32)

if len(trace_data) < config["nt"]:
continue
Expand All @@ -198,48 +183,81 @@ def extract_template_numpy(
continue

empty_data = False
template_[c_index, j, : config["nt"]] = trace_data[: config["nt"]]
if traveltime_type_[k, j] == 1: ## only use auto picks
template_[c_index, j, : config["nt"]] = trace_data[: config["nt"]]
################## debuging ##################
# import matplotlib.pyplot as plt
# import scipy.interpolate

# if (i == 0) and (j in [3, 4, 5]):
# # template_[c_index, j, 1 : config["nt"]] = trace_data[: config["nt"] - 1]
# t = np.linspace(0, 1, (config["nt"] - 1) + 1)
# t_interp = np.linspace(0, 1, (config["nt"] - 1) * 10 + 1)
# x = trace_data[: config["nt"]]
# x_interp = scipy.interpolate.interp1d(t, x, kind="cubic")(t_interp)
# # print(x - x_interp[0::10])
# # plt.figure()
# # plt.plot(t, x)
# # plt.plot(t_interp, x_interp)
# # plt.plot(t, x_interp[0::10])
# # plt.savefig("debug.png")
# # raise
# template_[c_index, j, :] = np.roll(x_interp, -1)[::10]
################################################

s = np.std(trace_data[-int(config["time_after"] * config["sampling_rate"]) :])
n = np.std(trace_data[: int(config["time_before"] * config["sampling_rate"])])
if n == 0:
snr_[c_index, j] = 0
else:
snr_[c_index, j] = s / n

template_array[ii] += template_
traveltime_array[ii] += traveltime_
traveltime_index_array[ii] += np.round(traveltime_ * config["sampling_rate"]).astype(np.int32)
traveltime_type_array[ii] += traveltime_type_
arrivaltime_index_array[ii] += arrivaltime_index_
snr_array[ii] = +snr_
template_array[i] += template_
traveltime_array[i] += traveltime_
traveltime_index_array[i] += traveltime_index_
traveltime_type_array[i] += traveltime_type_
snr_array[i] = +snr_

with lock:
template_array.flush()
traveltime_array.flush()
traveltime_index_array.flush()
traveltime_type_array.flush()
arrivaltime_index_array.flush()
snr_array.flush()

# num_event += 1
# if num_event > 20:
# break

# %%
result_path = f"{region}/cctorch"
if not os.path.exists(f"{root_path}/{result_path}"):
os.makedirs(f"{root_path}/{result_path}")

# %%
print(json.dumps(config, indent=4, sort_keys=True))

# %%
picks = pd.read_csv(
f"{root_path}/{region}/gamma/gamma_picks.csv",
parse_dates=["phase_time"],
)
picks = picks[picks["event_index"] != -1]
picks["phase_timestamp"] = picks["phase_time"].apply(lambda x: x.timestamp())

################## debuging ##################
# tmp = picks[picks["event_index"] == 1528]
# tmp["event_index"] = 1527
# # tmp["phase_time"] = tmp["phase_time"] + pd.Timedelta(seconds=0.1)
# # tmp["phase_timestamp"] = tmp["phase_time"].apply(lambda x: x.timestamp())
# picks = pd.concat([picks[picks["event_index"] != 1527], tmp])
################################################

# %%
stations = pd.read_json(f"{root_path}/{region}/obspy/stations.json", orient="index")
stations["station_id"] = stations.index
# already set in obspy
# stations["x_km"] = stations.apply(
# lambda x: (x.longitude - config["longitude0"]) * np.cos(np.deg2rad(config["latitude0"])) * config["degree2km"],
# axis=1,
# )
# stations["y_km"] = stations.apply(lambda x: (x.latitude - config["latitude0"]) * config["degree2km"], axis=1)
# stations["z_km"] = stations.apply(lambda x: -x["elevation_m"] / 1e3, axis=1)
# %% filter stations without picks
stations = stations[stations["station_id"].isin(picks.groupby("station_id").size().index)]
stations.reset_index(drop=True, inplace=True) # index used in memmap array
stations.to_json(f"{root_path}/{result_path}/stations.json", orient="index", indent=4)
stations.to_csv(f"{root_path}/{result_path}/stations.csv", index=True)
print(stations.iloc[:5])

# %%
events = pd.read_csv(f"{root_path}/{region}/gamma/gamma_events.csv", parse_dates=["time"])
Expand All @@ -254,40 +272,24 @@ def extract_template_numpy(
events.rename(columns={"y(km)": "y_km"}, inplace=True)
if "z(km)" in events.columns:
events.rename(columns={"z(km)": "z_km"}, inplace=True)
events.reset_index(drop=True, inplace=True) # index used in memmap array
events.to_csv(f"{root_path}/{result_path}/events.csv", index=True)
print(events.iloc[:5])

################## debuging ####################
# mask = events["event_index"] == 1528
# events.loc[mask, "event_time"] = events.loc[mask, "event_time"] + pd.Timedelta(seconds=-0.001)
# events.loc[mask, "event_timestamp"] = events.loc[mask, "event_time"].apply(lambda x: x.timestamp())
# print(events[events["event_index"] == 1528], events[events["event_index"] == 1527])
################################################

# %%
if "event_index" not in events.columns:
event_index = events.index
else:
event_index = list(events["event_index"])
event_index_fname = f"{root_path}/{result_path}/event_index.txt"
with open(event_index_fname, "w") as f:
for i, idx in enumerate(event_index):
f.write(f"{i},{idx}\n")
for i, event in enumerate(events.iloc):
f.write(f"{i},{event['event_index']}\n")
config["cctorch"]["event_index_file"] = event_index_fname

# %%
picks = pd.read_csv(
f"{root_path}/{region}/gamma/gamma_picks.csv",
parse_dates=["phase_time"],
)
picks = picks[picks["event_index"] != -1]

## debuging
tmp = picks[picks["event_index"] == 1528]
tmp["event_index"] = 1527
picks = pd.concat([picks[picks["event_index"] != 1527], tmp])

picks["phase_timestamp"] = picks["phase_time"].apply(lambda x: x.timestamp())

picks_ = picks.groupby("station_id").size()
# station_id_ = picks_[picks_ > (picks_.sum() / len(picks_) * 0.1)].index
# stations = stations[stations["station_id"].isin(station_id_)]
stations = stations[stations["station_id"].isin(picks_.index)]

stations.to_json(f"{root_path}/{result_path}/stations_filtered.json", orient="index", indent=4)
stations.to_csv(f"{root_path}/{result_path}/stations_filtered.csv", index=True, index_label="station_id")

station_index_fname = f"{root_path}/{result_path}/station_index.txt"
with open(station_index_fname, "w") as f:
for i, sta in enumerate(stations.iloc):
Expand All @@ -297,19 +299,12 @@ def extract_template_numpy(
# %%
picks = picks.merge(stations, on="station_id")
picks = picks.merge(events, on="event_index", suffixes=("_station", "_event"))

# %%
# events["index"] = events["event_index"]
# events.set_index("index", inplace=True)
# picks["index"] = picks["event_index"]
# picks.set_index("index", inplace=True)
picks.set_index("event_index", inplace=True)

# %%
nt = int((config["cctorch"]["time_before"] + config["cctorch"]["time_after"]) * config["cctorch"]["sampling_rate"])
config["cctorch"]["nt"] = nt
nch = 6 ## For [P,S] phases and [E,N,Z] components
# nev = int(events.index.max()) + 1
nev = len(events)
nst = len(stations)
print(f"nev: {nev}, nch: {nch}, nst: {nst}, nt: {nt}")
Expand All @@ -324,20 +319,17 @@ def extract_template_numpy(
traveltime_fname = f"{root_path}/{result_path}/traveltime.dat"
traveltime_index_fname = f"{root_path}/{result_path}/traveltime_index.dat"
traveltime_type_fname = f"{root_path}/{result_path}/traveltime_type.dat"
arrivaltime_index_fname = f"{root_path}/{result_path}/arrivaltime_index.dat"
snr_fname = f"{root_path}/{result_path}/snr.dat"
config["cctorch"]["template_file"] = template_fname
config["cctorch"]["traveltime_file"] = traveltime_fname
config["cctorch"]["traveltime_index_file"] = traveltime_index_fname
config["cctorch"]["traveltime_type_file"] = traveltime_type_fname
config["cctorch"]["arrivaltime_index_file"] = arrivaltime_index_fname
config["cctorch"]["snr_file"] = snr_fname

template_array = np.memmap(template_fname, dtype=np.float32, mode="w+", shape=template_shape)
traveltime_array = np.memmap(traveltime_fname, dtype=np.float32, mode="w+", shape=traveltime_shape)
traveltime_index_array = np.memmap(traveltime_index_fname, dtype=np.int32, mode="w+", shape=traveltime_shape)
traveltime_type_array = np.memmap(traveltime_type_fname, dtype=np.int32, mode="w+", shape=traveltime_shape)
arrivaltime_index_array = np.memmap(arrivaltime_index_fname, dtype=np.int64, mode="w+", shape=traveltime_shape)
snr_array = np.memmap(snr_fname, dtype=np.float32, mode="w+", shape=snr_shape)

with open(f"{root_path}/{result_path}/config.json", "w") as f:
Expand All @@ -358,7 +350,6 @@ def extract_template_numpy(
traveltime_fname,
traveltime_index_fname,
traveltime_type_fname,
arrivaltime_index_fname,
snr_fname,
d,
events,
Expand Down

0 comments on commit 061e977

Please sign in to comment.