[BE] Normalized to use model_args: ModelArgs
awgu committed Feb 13, 2024
1 parent da50d34 commit fb5f4fc
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions torchtrain/models/llama/model.py
@@ -165,7 +165,7 @@ class Attention(nn.Module):
     Multi-head attention module.

     Args:
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.

     Attributes:
         n_kv_heads (int): Number of key and value heads.
@@ -182,18 +182,18 @@ class Attention(nn.Module):
     """

-    def __init__(self, args: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        self.n_heads = model_args.n_heads
+        self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
         self.n_rep = self.n_heads // self.n_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = model_args.dim // model_args.n_heads

-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)

     def forward(
         self,
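
Context, not part of the diff: when `n_kv_heads < n_heads`, the module implements grouped-query attention, and `self.n_rep` records how many query heads share each key/value head. The repeat helper that consumes `n_rep` lives elsewhere in `model.py` and is untouched by this commit; the sketch below follows the conventional Llama implementation, so treat the name and body as an assumption rather than a quote from this file.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times along the head dimension.

    (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

# e.g. n_heads=8, n_kv_heads=2 gives n_rep=4: two KV heads each serve four query heads.
kv = torch.randn(1, 16, 2, 64)
assert repeat_kv(kv, 4).shape == (1, 16, 8, 64)
```
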
@@ -290,18 +290,18 @@ class RotaryEmbedding(nn.Module):
     RotaryEmbedding Module
     """

-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.model_args = model_args
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

         self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
             # of models is 4096.
             # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
             # or fine-tuning.
-            self.params.dim // self.params.n_heads,
-            self.params.max_seq_len * 2,
+            self.model_args.dim // self.model_args.n_heads,
+            self.model_args.max_seq_len * 2,
         )

     def forward(self, tokens: torch.Tensor):
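
Context, not part of the diff: `precompute_freqs_cis` is defined earlier in `model.py` and is not changed by this commit; the call above passes the per-head dimension (`dim // n_heads`) and `max_seq_len * 2` positions. A minimal sketch, assuming the standard Llama rotary table:

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute complex rotary frequencies of shape (end, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude complex64

# Mirrors the call site above: one row per position, one column per rotated pair.
freqs_cis = precompute_freqs_cis(4096 // 32, 2048 * 2)
assert freqs_cis.shape == (4096, 64)
```
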
@@ -327,7 +327,7 @@ class TransformerBlock(nn.Module):
     Args:
         layer_id (int): Identifier for the layer.
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.

     Attributes:
         n_heads (int): Number of attention heads.
@@ -341,21 +341,21 @@ class TransformerBlock(nn.Module):
     """

-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
+        self.n_heads = model_args.n_heads
+        self.dim = model_args.dim
+        self.attention = Attention(model_args)
         self.feed_forward = FeedForward(
-            dim=args.dim,
-            hidden_dim=4 * args.dim,
-            multiple_of=args.multiple_of,
-            ffn_dim_multiplier=args.ffn_dim_multiplier,
+            dim=model_args.dim,
+            hidden_dim=4 * model_args.dim,
+            multiple_of=model_args.multiple_of,
+            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)

     def forward(
         self,
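
Context, not part of the diff: `hidden_dim=4 * model_args.dim` is only a starting value; the `FeedForward` module is defined elsewhere in `model.py` and is assumed to follow the usual Llama sizing rule, shrinking to 2/3 of that, applying `ffn_dim_multiplier` if set, and rounding up to a multiple of `multiple_of`. A sketch of that arithmetic under that assumption:

```python
from typing import Optional

def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    """Llama-style sizing of the SwiGLU hidden dimension, starting from 4 * dim."""
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# dim=4096, multiple_of=256, no multiplier -> 11008, the familiar Llama-2-7B FFN width.
assert ffn_hidden_dim(4096, 256, None) == 11008
```
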
@@ -383,10 +383,10 @@ class Transformer(nn.Module):
     Transformer Module

     Args:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.

     Attributes:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
         vocab_size (int): Vocabulary size.
         n_layers (int): Number of layers in the model.
         tok_embeddings (ParallelEmbedding): Token embeddings.
@@ -397,21 +397,21 @@ class Transformer(nn.Module):
     """

-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.vocab_size = params.vocab_size
-        self.n_layers = params.n_layers
+        self.model_args = model_args
+        self.vocab_size = model_args.vocab_size
+        self.n_layers = model_args.n_layers

-        self.embeddings = RotaryEmbedding(params)
+        self.embeddings = RotaryEmbedding(model_args)

         self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+        for layer_id in range(model_args.n_layers):
+            self.layers.append(TransformerBlock(layer_id, model_args))

-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

     def forward(self, tokens: torch.Tensor):
         """
@@ -426,7 +426,7 @@ def forward(self, tokens: torch.Tensor):
         """
        h, freqs_cis = self.embeddings(tokens)
         # fold batch and sequence dimension for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.params.dim)
+        h = h.view(-1, self.model_args.dim)

         for layer in self.layers:
             h = layer(h, freqs_cis)
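
For illustration, not part of the diff: the fold collapses `(bsz, seqlen, dim)` into `(bsz * seqlen, dim)` so collectives such as all-gather and reduce-scatter see a single flat dimension, and the unfold in the next hunk reverses it. A worked shape example:

```python
import torch

bsz, seqlen, dim = 2, 8, 16
h = torch.randn(bsz, seqlen, dim)

# fold: (bsz, seqlen, dim) -> (bsz * seqlen, dim), as in forward() above
h_flat = h.view(-1, dim)
assert h_flat.shape == (16, 16)

# unfold: recover (bsz, seqlen, dim) from the flattened batch*sequence dimension
bs_seqlen = h_flat.shape[0]
h_back = h_flat.view(bsz, bs_seqlen // bsz, dim)
assert torch.equal(h_back, h)
```
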
@@ -435,17 +435,17 @@ def forward(self, tokens: torch.Tensor):
         # unfold batch and sequence dimension
         bsz = tokens.shape[0]
         bs_seqlen = h.shape[0]
-        h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
+        h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
         output = self.output(h).float()
         return output

     @classmethod
-    def from_model_args(cls, model_args: ModelArgs):
+    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
         """
         Initialize a Transformer model from a ModelArgs object.

         Args:
-            model_args (ModelArgs): Model configuration parameters.
+            model_args (ModelArgs): Model configuration arguments.

         Returns:
             Transformer: Transformer model.
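
For illustration, not part of the diff: a hypothetical way to build a model through the renamed `from_model_args` entry point. The field names follow the attributes referenced in this diff; the sizes are made up, the import path is inferred from the file's location, and it is assumed that the remaining `ModelArgs` fields have defaults.

```python
import torch

# Assumed import path, mirroring torchtrain/models/llama/model.py in the repository.
from torchtrain.models.llama.model import ModelArgs, Transformer

# Hypothetical configuration, purely illustrative.
model_args = ModelArgs(
    dim=256,
    n_layers=4,
    n_heads=8,
    vocab_size=32_000,
    max_seq_len=1024,
)
model = Transformer.from_model_args(model_args)

tokens = torch.randint(0, model_args.vocab_size, (2, 128))
logits = model(tokens)  # expected shape: (2, 128, vocab_size)
```
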