Skip to content

Commit

Permalink
unique layer init, clean up lr display
Browse files Browse the repository at this point in the history
  • Loading branch information
lessw2020 committed Feb 16, 2024
1 parent 196d56e commit 267c6b7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions torchtrain/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ModelArgs:

max_batch_size: int = 32
max_seq_len: int = 32768

unique_layer_init: bool = True # initialization uses each unique layer_id or total model layer count

class RMSNorm(torch.nn.Module):
"""
Expand Down Expand Up @@ -392,7 +392,11 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
self.num_layers = model_args.n_layers
self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

if model_args.unique_layer_init:
self.weight_init_std = 0.02 /(2 * (self.layer_id + 1)) ** 0.5
else:
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

def forward(
self,
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def main(args):
time_last_log = timer()

rank0_log(
f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {round(float(scheduler.get_last_lr()[0]), 8)}"
)
scheduler.step()

Expand Down

0 comments on commit 267c6b7

Please sign in to comment.