
Commit

update japan
zhuwq0 committed Oct 18, 2024
1 parent a0f9dd0 commit 6bf3c4e
Showing 1 changed file with 8 additions and 8 deletions.
examples/japan/run_adloc_cc.py
@@ -168,14 +168,14 @@
 optimizer = optim.Adam(
     [
         {"params": travel_time.event_loc.parameters(), "lr": lr},  # learning rate for event_loc
-        {"params": travel_time.event_time.parameters(), "lr": lr * 0.1},  # learning rate for event_time
+        {"params": travel_time.event_time.parameters(), "lr": lr},  # learning rate for event_time
     ],
     lr=lr,
 )
 scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=lr * 0.1)
-scaler = optim.lr_scheduler.ReduceLROnPlateau(
-    optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.9, patience=3, threshold=0.05
-)
+# scaler = optim.lr_scheduler.ReduceLROnPlateau(
+#     optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.95, patience=3, threshold=0.05
+# )
 valid_index = np.ones(len(pairs), dtype=bool)

 if ddp:
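
This hunk raises event_time to the same learning rate as event_loc and retires the ReduceLROnPlateau "scaler" in favor of cosine annealing alone. A minimal runnable sketch of the surviving setup, with a hypothetical stand-in for the TravelTime model (ADLoc's real class differs):

import torch.nn as nn
import torch.optim as optim

class TravelTime(nn.Module):  # hypothetical stand-in, not ADLoc's actual model
    def __init__(self, num_events):
        super().__init__()
        self.event_loc = nn.Embedding(num_events, 3)   # x, y, z per event
        self.event_time = nn.Embedding(num_events, 1)  # origin-time correction

travel_time = TravelTime(num_events=100)
lr, EPOCHS = 0.01, 100  # assumed values for illustration
optimizer = optim.Adam(
    [
        {"params": travel_time.event_loc.parameters(), "lr": lr},
        {"params": travel_time.event_time.parameters(), "lr": lr},  # previously lr * 0.1
    ],
    lr=lr,
)
# decays every group's lr from lr down to lr * 0.1 over EPOCHS steps
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=lr * 0.1)
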
@@ -217,7 +217,7 @@
 # torch.nn.utils.clip_grad_norm_(travel_time.parameters(), 1.0)
 optimizer.step()
 scheduler.step()
-scaler.step(loss)
+# scaler.step(loss)
 with torch.no_grad():
     raw_travel_time.event_loc.weight.data[:, 2].clamp_(
         min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1
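
The dropped scaler.step(loss) call highlights a PyTorch subtlety: ReduceLROnPlateau.step() expects the monitored metric, unlike CosineAnnealingLR.step(), which takes no argument. A self-contained demo of that behavior, with invented loss values:

import torch
import torch.optim as optim

opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
plateau = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.9, patience=3, threshold=0.05)
for loss in [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]:
    plateau.step(loss)  # pass the metric; no parameter update happens here
print(opt.param_groups[0]["lr"])  # 0.9 after the loss stalls past `patience` epochs
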
@@ -247,8 +247,8 @@
 weight = np.concatenate(weight)
 # threshold_time = 6.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 2.0  # s
 threshold_time = 6.0 * (EPOCHS - 1 - epoch) / EPOCHS + 2.0  # s
-print(f"Scaler: {scaler.get_last_lr()[0]}")
-threshold_time *= scaler.get_last_lr()[0]
+# print(f"Scaler: {scaler.get_last_lr()[0]}")
+# threshold_time *= scaler.get_last_lr()[0]
 # valid_index = np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * threshold_time
 # weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])[valid_index]) ** 2, weights=weight[valid_index]))
 weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])) ** 2, weights=weight))
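
With the scaler gone, the time-residual cutoff now decays linearly on its own. A quick check of the schedule's endpoints, assuming EPOCHS = 100 for illustration:

EPOCHS = 100  # assumed value
for epoch in (0, EPOCHS // 2, EPOCHS - 1):
    threshold_time = 6.0 * (EPOCHS - 1 - epoch) / EPOCHS + 2.0
    print(f"epoch {epoch}: threshold_time = {threshold_time:.2f} s")
# epoch 0: 7.94 s, epoch 50: 4.94 s, epoch 99: 2.00 s
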
@@ -284,7 +284,7 @@
 )
 # threshold_space = 9.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 1.0  # km
 threshold_space = 9.0 * (EPOCHS - 1 - epoch) / EPOCHS + 1.0  # km
-threshold_space *= scaler.get_last_lr()[0]
+# threshold_space *= scaler.get_last_lr()[0]
 num_picks = len(pairs_df)
 pairs_df = pairs_df[pairs_df["dist_km"] < threshold_space]
 print(f"Filter by space: {num_picks} -> {len(pairs_df)} using threshold {threshold_space:.2f}")