From fb5f4fc964e53d7c522615f48ef0440c089dac50 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Tue, 13 Feb 2024 09:07:03 -0800
Subject: [PATCH] [BE] Normalized to use `model_args: ModelArgs`

---
 torchtrain/models/llama/model.py | 82 ++++++++++++++++----------------
 1 file changed, 41 insertions(+), 41 deletions(-)

diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 7485d32b6..484733302 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -165,7 +165,7 @@ class Attention(nn.Module):
     Multi-head attention module.
 
     Args:
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
         n_kv_heads (int): Number of key and value heads.
@@ -182,18 +182,18 @@ class Attention(nn.Module):
 
     """
 
-    def __init__(self, args: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        self.n_heads = model_args.n_heads
+        self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
         self.n_rep = self.n_heads // self.n_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = model_args.dim // model_args.n_heads
 
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
 
     def forward(
         self,
@@ -290,18 +290,18 @@ class RotaryEmbedding(nn.Module):
     RotaryEmbedding Module
     """
 
-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.model_args = model_args
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
 
         self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
            # of models is 4096.
            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
            # or fine-tuning.
-            self.params.dim // self.params.n_heads,
-            self.params.max_seq_len * 2,
+            self.model_args.dim // self.model_args.n_heads,
+            self.model_args.max_seq_len * 2,
         )
 
     def forward(self, tokens: torch.Tensor):
@@ -327,7 +327,7 @@ class TransformerBlock(nn.Module):
 
     Args:
         layer_id (int): Identifier for the layer.
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
         n_heads (int): Number of attention heads.
@@ -341,21 +341,21 @@ class TransformerBlock(nn.Module):
 
     """
 
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
+        self.n_heads = model_args.n_heads
+        self.dim = model_args.dim
+        self.attention = Attention(model_args)
         self.feed_forward = FeedForward(
-            dim=args.dim,
-            hidden_dim=4 * args.dim,
-            multiple_of=args.multiple_of,
-            ffn_dim_multiplier=args.ffn_dim_multiplier,
+            dim=model_args.dim,
+            hidden_dim=4 * model_args.dim,
+            multiple_of=model_args.multiple_of,
+            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
 
     def forward(
         self,
@@ -383,10 +383,10 @@ class Transformer(nn.Module):
     Transformer Module
 
     Args:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
         vocab_size (int): Vocabulary size.
         n_layers (int): Number of layers in the model.
         tok_embeddings (ParallelEmbedding): Token embeddings.
@@ -397,21 +397,21 @@
 
     """
 
-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.vocab_size = params.vocab_size
-        self.n_layers = params.n_layers
+        self.model_args = model_args
+        self.vocab_size = model_args.vocab_size
+        self.n_layers = model_args.n_layers
 
-        self.embeddings = RotaryEmbedding(params)
+        self.embeddings = RotaryEmbedding(model_args)
 
         self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+        for layer_id in range(model_args.n_layers):
+            self.layers.append(TransformerBlock(layer_id, model_args))
 
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
 
     def forward(self, tokens: torch.Tensor):
         """
@@ -426,7 +426,7 @@ def forward(self, tokens: torch.Tensor):
         """
         h, freqs_cis = self.embeddings(tokens)
         # fold batch and sequence dimension for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.params.dim)
+        h = h.view(-1, self.model_args.dim)
 
         for layer in self.layers:
             h = layer(h, freqs_cis)
@@ -435,17 +435,17 @@ def forward(self, tokens: torch.Tensor):
         # unfold batch and sequence dimension
         bsz = tokens.shape[0]
         bs_seqlen = h.shape[0]
-        h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
+        h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
         output = self.output(h).float()
         return output
 
     @classmethod
-    def from_model_args(cls, model_args: ModelArgs):
+    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
         """
         Initialize a Transformer model from a ModelArgs object.
 
         Args:
-            model_args (ModelArgs): Model configuration parameters.
+            model_args (ModelArgs): Model configuration arguments.
 
         Returns:
             Transformer: Transformer model.
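For context, a minimal usage sketch of the API touched by this patch. It assumes `ModelArgs` and `Transformer` are both importable from `torchtrain.models.llama.model` and that `ModelArgs` accepts the fields referenced in the diff; the concrete values below are illustrative placeholders, not repository defaults.

import torch

from torchtrain.models.llama.model import ModelArgs, Transformer

# Illustrative configuration: only fields referenced in this patch are set,
# and the values are placeholders rather than the repository's defaults.
model_args = ModelArgs(
    dim=256,
    n_layers=2,
    n_heads=8,
    n_kv_heads=8,
    vocab_size=32000,
    multiple_of=256,
    ffn_dim_multiplier=None,
    norm_eps=1e-5,
    max_seq_len=2048,
)

# Transformer.from_model_args (shown at the end of the diff) builds the model
# from a single ModelArgs object.
model = Transformer.from_model_args(model_args)

# Dummy forward pass: (batch, seq_len) token IDs -> (batch, seq_len, vocab_size) logits.
tokens = torch.randint(0, model_args.vocab_size, (2, 16))
logits = model(tokens)
print(logits.shape)

After the rename, every module in the file takes the same `model_args: ModelArgs` parameter, so constructor calls such as `Attention(model_args)` and `TransformerBlock(layer_id, model_args)` read consistently across `Attention`, `RotaryEmbedding`, `TransformerBlock`, and `Transformer`.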