[BE] Normalized to use model_args: ModelArgs (#58)
Some modules used `args: ModelArgs`, others `params: ModelArgs`, and
others `model_args: ModelArgs`. This PR normalizes everything to use
`model_args: ModelArgs` for consistency. (`params` is easily confused
with `nn.Parameter`s, and `model_args` is more explicit than `args`.)
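
For illustration, a minimal sketch of the convention this change settles on, assuming a dataclass-style `ModelArgs` reduced to two fields for brevity (the real class in the repo defines more):

```python
from dataclasses import dataclass

from torch import nn


@dataclass
class ModelArgs:
    # Illustrative subset of fields; the real ModelArgs defines more.
    dim: int = 256
    n_heads: int = 8


class Attention(nn.Module):
    # Before this PR: some modules took `args: ModelArgs`, others `params: ModelArgs`.
    # After: every module takes `model_args: ModelArgs`.
    def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.model_args = model_args
        self.head_dim = model_args.dim // model_args.n_heads
        self.wq = nn.Linear(model_args.dim, model_args.dim, bias=False)


attn = Attention(ModelArgs())  # the same spelling now works for every module
```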

**Test Plan**
```
./run_llama_train.sh
```


<details>
<summary> Output </summary>

```
+ TRAINER_DIR=/home/andgu/local/torchtrain
+ MODEL=debugmodel
+ NGPU=8
+ PP=1
+ SP=1
+ DP=-1
+ LOG_RANK=0
+ CHECKPOINT_FOLDER=
+ CHECKPOINT_INTERVAL=5
+ torchrun --nproc_per_node=8 --local-ranks-filter 0 --role rank --tee 3 train.py --steps 10 --compile --pp_degree 1 --sp_degree 1 --dp_degree -1
[2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] 
[2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] *****************************************
[2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] *****************************************
[rank0]:2024-02-13 09:53:33,644 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [8]
[rank0]:2024-02-13 09:53:36,955 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model
[rank0]:2024-02-13 09:53:36,955 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2
[rank0]:/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
[rank0]:  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
[rank0]:2024-02-13 09:53:41,571 - root - INFO - Applied FSDP to the model...
[rank0]:2024-02-13 09:53:41,572 - root - INFO - Gradient scaling not enabled.
[rank0]:2024-02-13 09:53:41,572 - root - INFO - Compiling model llama with torch.compile...
[rank0]:2024-02-13 09:53:43,892 - root - INFO - Profiling active.  Traces will be saved at ./torchtrain/outputs/profiling/traces
[rank0]:NCCL version 2.19.3+cuda12.0
[rank0]:[rank0]:[2024-02-13 09:53:43,995] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1697: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-02-13 09:54:06,085 - root - INFO - step: 1, current loss: 10.54707145690918, lr: [0.0002666666666666667]
[rank0]:2024-02-13 09:54:06,153 - root - INFO - step: 2, current loss: 10.481386184692383, lr: [0.0005333333333333334]
[rank0]:2024-02-13 09:54:06,222 - root - INFO - step: 3, current loss: 10.334623336791992, lr: [0.0008]
[rank0]:2024-02-13 09:54:06,288 - root - INFO - step: 4, current loss: 10.121940612792969, lr: [0.0007]
[rank0]:2024-02-13 09:54:06,355 - root - INFO - step: 5, current loss: 9.922933578491211, lr: [0.0006000000000000001]
[rank0]:2024-02-13 09:54:06,422 - root - INFO - step: 6, current loss: 9.710294723510742, lr: [0.0005]
[rank0]:2024-02-13 09:54:06,487 - root - INFO - step: 7, current loss: 9.587849617004395, lr: [0.0004]
[rank0]:2024-02-13 09:54:06,773 - root - INFO - step: 8, current loss: 9.474313735961914, lr: [0.00030000000000000003]
[rank0]:STAGE:2024-02-13 09:54:06 3243810:3243810 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
[rank0]:2024-02-13 09:54:06,845 - root - INFO - step: 9, current loss: 9.282522201538086, lr: [0.0002]
[rank0]:[rank0]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
[rank0]:STAGE:2024-02-13 09:54:06 3243810:3243810 ActivityProfilerController.cpp:320] Completed Stage: Collection
[rank0]:STAGE:2024-02-13 09:54:06 3243810:3243810 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
[rank0]:2024-02-13 09:54:06,999 - root - INFO - exporting profile traces to ./torchtrain/outputs/profiling/traces/iteration_10
[rank0]:2024-02-13 09:54:07,002 - root - INFO - step: 10, current loss: 9.34823989868164, lr: [0.0001]
```
</details>
awgu authored Feb 13, 2024
1 parent 07fbe13 commit 08b607e
Showing 1 changed file with 49 additions and 41 deletions: `torchtrain/models/llama/model.py`

```diff
@@ -165,7 +165,7 @@ class Attention(nn.Module):
     Multi-head attention module.
 
     Args:
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
         n_kv_heads (int): Number of key and value heads.
@@ -182,18 +182,26 @@ class Attention(nn.Module):
     """
 
-    def __init__(self, args: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        self.n_heads = model_args.n_heads
+        self.n_kv_heads = (
+            model_args.n_heads
+            if model_args.n_kv_heads is None
+            else model_args.n_kv_heads
+        )
         self.n_rep = self.n_heads // self.n_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = model_args.dim // model_args.n_heads
 
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        self.wq = nn.Linear(
+            model_args.dim, model_args.n_heads * self.head_dim, bias=False
+        )
+        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(
+            model_args.n_heads * self.head_dim, model_args.dim, bias=False
+        )
 
     def forward(
         self,
@@ -290,18 +298,18 @@ class RotaryEmbedding(nn.Module):
     RotaryEmbedding Module
     """
 
-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.model_args = model_args
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
 
         self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
             # of models is 4096.
             # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
             # or fine-tuning.
-            self.params.dim // self.params.n_heads,
-            self.params.max_seq_len * 2,
+            self.model_args.dim // self.model_args.n_heads,
+            self.model_args.max_seq_len * 2,
         )
 
     def forward(self, tokens: torch.Tensor):
@@ -327,7 +335,7 @@ class TransformerBlock(nn.Module):
     Args:
         layer_id (int): Identifier for the layer.
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
         n_heads (int): Number of attention heads.
@@ -341,21 +349,21 @@ class TransformerBlock(nn.Module):
     """
 
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
+        self.n_heads = model_args.n_heads
+        self.dim = model_args.dim
+        self.attention = Attention(model_args)
         self.feed_forward = FeedForward(
-            dim=args.dim,
-            hidden_dim=4 * args.dim,
-            multiple_of=args.multiple_of,
-            ffn_dim_multiplier=args.ffn_dim_multiplier,
+            dim=model_args.dim,
+            hidden_dim=4 * model_args.dim,
+            multiple_of=model_args.multiple_of,
+            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
 
     def forward(
         self,
@@ -383,10 +391,10 @@ class Transformer(nn.Module):
     Transformer Module
 
     Args:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
         vocab_size (int): Vocabulary size.
         n_layers (int): Number of layers in the model.
         tok_embeddings (ParallelEmbedding): Token embeddings.
@@ -397,21 +405,21 @@ class Transformer(nn.Module):
     """
 
-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.vocab_size = params.vocab_size
-        self.n_layers = params.n_layers
+        self.model_args = model_args
+        self.vocab_size = model_args.vocab_size
+        self.n_layers = model_args.n_layers
 
-        self.embeddings = RotaryEmbedding(params)
+        self.embeddings = RotaryEmbedding(model_args)
 
         self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+        for layer_id in range(model_args.n_layers):
+            self.layers.append(TransformerBlock(layer_id, model_args))
 
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
 
     def forward(self, tokens: torch.Tensor):
         """
@@ -426,7 +434,7 @@ def forward(self, tokens: torch.Tensor):
         """
         h, freqs_cis = self.embeddings(tokens)
         # fold batch and sequence dimension for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.params.dim)
+        h = h.view(-1, self.model_args.dim)
 
         for layer in self.layers:
             h = layer(h, freqs_cis)
@@ -435,17 +443,17 @@ def forward(self, tokens: torch.Tensor):
         # unfold batch and sequence dimension
         bsz = tokens.shape[0]
         bs_seqlen = h.shape[0]
-        h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
+        h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
         output = self.output(h).float()
         return output
 
     @classmethod
-    def from_model_args(cls, model_args: ModelArgs):
+    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
         """
         Initialize a Transformer model from a ModelArgs object.
 
         Args:
-            model_args (ModelArgs): Model configuration parameters.
+            model_args (ModelArgs): Model configuration arguments.
 
         Returns:
             Transformer: Transformer model.
```
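
For reference, a hedged sketch of how the renamed argument is threaded through model construction. The `ModelArgs` field names and `from_model_args` signature come from the diff above; the import path, constructor keywords, and defaults are assumptions, not verified repo API:

```python
# Hypothetical usage sketch; assumes ModelArgs is a dataclass-style config
# exposing these fields and that both names are importable as shown.
from torchtrain.models.llama import ModelArgs, Transformer

model_args = ModelArgs(dim=256, n_layers=2, n_heads=8, vocab_size=32000)

# Every module in model.py now reads its configuration from `model_args`.
model = Transformer(model_args)
# The classmethod keeps the same spelling:
model = Transformer.from_model_args(model_args)
```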
