🐛 [Fix] momentum scaling bugs
henrytsui000 committed Dec 23, 2024
1 parent b96c8ea commit c0e2436
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions yolo/utils/model_utils.py
@@ -23,6 +23,8 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
"""
Linearly interpolates between start and end values.
start * (1 - step / total) + end * (step / total)
Parameters:
start (float): The starting value.
end (float): The ending value.
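
For context, here is a minimal sketch of the interpolation helper matching the signature shown in the hunk header above; the body is an assumption based on the docstring formula, not necessarily the repository's exact implementation:

from typing import Union

def lerp(start: float, end: float, step: Union[int, float], total: int = 1) -> float:
    # start * (1 - step/total) + end * (step/total)
    return start + (end - start) * step / total

# With the corrected arguments used below, momentum warms up from 0.8 to 0.937 over 3 epochs:
print([round(lerp(0.8, 0.937, min(e, 3), 3), 4) for e in range(5)])
# -> [0.8, 0.8457, 0.8913, 0.937, 0.937]
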
@@ -88,8 +90,8 @@ def next_epoch(self, batch_num, epoch_idx):
# 0.8  : Warm-up (start) momentum
# 0.937: Normal momentum
# 3    : Number of warm-up epochs
- self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
- self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
+ self.min_mom = lerp(0.8, 0.937, min(epoch_idx, 3), 3)
+ self.max_mom = lerp(0.8, 0.937, min(epoch_idx + 1, 3), 3)
self.batch_num = batch_num
self.batch_idx = 0

@@ -99,8 +101,9 @@ def next_batch(self):
for lr_idx, param_group in enumerate(self.param_groups):
min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
# param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
lr_dict[f"momentum/{lr_idx}"] = param_group["momentum"]
return lr_dict

optimizer_class.next_batch = next_batch
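
Taken together, next_epoch and next_batch give a per-batch momentum ramp. The following self-contained sketch (the helper momentum_at and the constant names are hypothetical, with values copied from the comments above; this is not the repository's API) shows the schedule produced by the corrected argument order:

from typing import Union

def lerp(start: float, end: float, step: Union[int, float], total: int = 1) -> float:
    return start + (end - start) * step / total

WARMUP_EPOCHS = 3            # warm-up length used in next_epoch
START_MOM, NORMAL_MOM = 0.8, 0.937

def momentum_at(epoch_idx: int, batch_idx: int, batch_num: int) -> float:
    # Hypothetical helper mirroring next_epoch (per-epoch bounds) plus next_batch (per-batch lerp).
    min_mom = lerp(START_MOM, NORMAL_MOM, min(epoch_idx, WARMUP_EPOCHS), WARMUP_EPOCHS)
    max_mom = lerp(START_MOM, NORMAL_MOM, min(epoch_idx + 1, WARMUP_EPOCHS), WARMUP_EPOCHS)
    return lerp(min_mom, max_mom, batch_idx, batch_num)

for epoch in range(4):
    print(epoch, [round(momentum_at(epoch, b, batch_num=4), 3) for b in range(4)])
# Momentum now rises monotonically from 0.8 toward 0.937 during epochs 0-2
# and stays flat at 0.937 from epoch 3 onward; before the fix it decayed
# from 0.937 down to 0.8 instead.
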
