From 6bf3c4ed63deedfbeb0ddadfffd1ae20a815001b Mon Sep 17 00:00:00 2001 From: zhuwq Date: Thu, 17 Oct 2024 21:48:03 -0700 Subject: [PATCH] update japn --- examples/japan/run_adloc_cc.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/japan/run_adloc_cc.py b/examples/japan/run_adloc_cc.py index 7a3038e..b253760 100644 --- a/examples/japan/run_adloc_cc.py +++ b/examples/japan/run_adloc_cc.py @@ -168,14 +168,14 @@ optimizer = optim.Adam( [ {"params": travel_time.event_loc.parameters(), "lr": lr}, # learning rate for event_loc - {"params": travel_time.event_time.parameters(), "lr": lr * 0.1}, # learning rate for event_time + {"params": travel_time.event_time.parameters(), "lr": lr}, # learning rate for event_time ], lr=lr, ) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=lr * 0.1) - scaler = optim.lr_scheduler.ReduceLROnPlateau( - optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.9, patience=3, threshold=0.05 - ) + # scaler = optim.lr_scheduler.ReduceLROnPlateau( + # optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.95, patience=3, threshold=0.05 + # ) valid_index = np.ones(len(pairs), dtype=bool) if ddp: @@ -217,7 +217,7 @@ # torch.nn.utils.clip_grad_norm_(travel_time.parameters(), 1.0) optimizer.step() scheduler.step() - scaler.step(loss) + # scaler.step(loss) with torch.no_grad(): raw_travel_time.event_loc.weight.data[:, 2].clamp_( min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1 @@ -247,8 +247,8 @@ weight = np.concatenate(weight) # threshold_time = 6.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 2.0 # s threshold_time = 6.0 * (EPOCHS - 1 - epoch) / EPOCHS + 2.0 # s - print(f"Scaler: {scaler.get_last_lr()[0]}") - threshold_time *= scaler.get_last_lr()[0] + # print(f"Scaler: {scaler.get_last_lr()[0]}") + # threshold_time *= scaler.get_last_lr()[0] # valid_index = np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * threshold_time # weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])[valid_index]) ** 2, weights=weight[valid_index])) weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])) ** 2, weights=weight)) @@ -284,7 +284,7 @@ ) # threshold_space = 9.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 1.0 # km threshold_space = 9.0 * (EPOCHS - 1 - epoch) / EPOCHS + 1.0 # km - threshold_space *= scaler.get_last_lr()[0] + # threshold_space *= scaler.get_last_lr()[0] num_picks = len(pairs_df) pairs_df = pairs_df[pairs_df["dist_km"] < threshold_space] print(f"Filter by space: {num_picks} -> {len(pairs_df)} using threshold {threshold_space:.2f}")