Commit

sgd implementation
kothasuhas committed Jan 23, 2025
1 parent a3cdaab commit 3a5d4ff
Showing 1 changed file with 30 additions and 0 deletions.
src/levanter/optim/config.py
@@ -322,3 +322,33 @@ def _optimizer(learning_rate):
            return optimizer

        return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps))


@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
    momentum: float = 0.9
    nesterov: bool = False
    max_grad_norm: Optional[float] = 1.0

    def build(self, num_train_steps):
        """Creates the SGD optimizer."""
        def _optimizer(learning_rate):
            components = []

            if self.max_grad_norm:
                components.append(optax.clip_by_global_norm(self.max_grad_norm))

            if self.momentum > 0:
                components.append(optax.trace(decay=self.momentum, nesterov=self.nesterov))

            if self.weight_decay > 0:
                components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask()))

            # negate the learning rate so the chained update is a descent step
            components.append(optax.scale(-learning_rate))

            optimizer = optax.chain(*components)
            return optimizer

        return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps))
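
For context (not part of the commit), below is a minimal sketch of what the transform assembled by SGDConfig.build amounts to for a single fixed learning rate, leaving out the optax.inject_hyperparams/scheduler wrapping and the weight-decay mask. The hyperparameter values and the toy parameter/gradient trees are made up for illustration.

    import jax.numpy as jnp
    import optax

    # Hypothetical values standing in for SGDConfig's fields.
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    max_grad_norm = 1.0

    # Same component order as SGDConfig.build, with a fixed learning rate.
    opt = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.trace(decay=momentum, nesterov=False),
        optax.add_decayed_weights(weight_decay),
        optax.scale(-learning_rate),
    )

    # Toy parameter and gradient trees to show one update step.
    params = {"w": jnp.array([1.0, -2.0])}
    grads = {"w": jnp.array([0.5, 0.5])}

    state = opt.init(params)
    updates, state = opt.update(grads, state, params)  # params needed for weight decay
    params = optax.apply_updates(params, updates)

Note that because optax.add_decayed_weights comes after optax.trace in the chain, the decay term is added to the momentum-filtered update rather than accumulated into the momentum buffer, and it is still scaled by the learning rate via optax.scale.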
