[BE] Normalized to use model_args: ModelArgs #58

Merged Feb 13, 2024 (1 commit)
90 changes: 49 additions & 41 deletions torchtrain/models/llama/model.py
@@ -165,7 +165,7 @@ class Attention(nn.Module):
     Multi-head attention module.

     Args:
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.

     Attributes:
         n_kv_heads (int): Number of key and value heads.
@@ -182,18 +182,26 @@ class Attention(nn.Module):

     """

-    def __init__(self, args: ModelArgs):
+    def __init__(self, model_args: ModelArgs):

         super().__init__()
-        self.n_heads = args.n_heads
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        self.n_heads = model_args.n_heads
+        self.n_kv_heads = (
+            model_args.n_heads
+            if model_args.n_kv_heads is None
+            else model_args.n_kv_heads
+        )
         self.n_rep = self.n_heads // self.n_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = model_args.dim // model_args.n_heads

-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        self.wq = nn.Linear(
+            model_args.dim, model_args.n_heads * self.head_dim, bias=False
+        )
+        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(
+            model_args.n_heads * self.head_dim, model_args.dim, bias=False
+        )

     def forward(
         self,
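Note: `n_rep = n_heads // n_kv_heads` above is the grouped-query-attention repeat factor: each key/value head is shared by `n_rep` query heads. The helper that applies it is not part of this diff; the sketch below shows the usual Llama-style expansion, and the name `repeat_kv` is only an assumption for illustration.

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)."""
    if n_rep == 1:
        return x
    bs, seqlen, n_kv_heads, head_dim = x.shape
    return (
        x[:, :, :, None, :]
        .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
    )


# e.g. n_heads=8, n_kv_heads=2 -> n_rep=4: two KV heads serve eight query heads
keys = torch.randn(1, 16, 2, 64)
print(repeat_kv(keys, 4).shape)  # torch.Size([1, 16, 8, 64])
```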
@@ -290,18 +298,18 @@ class RotaryEmbedding(nn.Module):
     RotaryEmbedding Module
     """

-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.model_args = model_args
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

         self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
             # of models is 4096.
             # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
             # or fine-tuning.
-            self.params.dim // self.params.n_heads,
-            self.params.max_seq_len * 2,
+            self.model_args.dim // self.model_args.n_heads,
+            self.model_args.max_seq_len * 2,
         )

     def forward(self, tokens: torch.Tensor):
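Note: `precompute_freqs_cis` is called here but not touched by this PR. For context, the standard Llama reference implementation it presumably mirrors builds a table of complex rotary phases; below is a sketch under that assumption, with the conventional `theta=10000.0` base assumed.

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of head dimensions, geometrically spaced.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device).float()
    freqs = torch.outer(t, freqs)
    # Complex exponentials e^{i * position * freq}, shape (end, dim // 2).
    return torch.polar(torch.ones_like(freqs), freqs)


# Mirrors the call above: per-head dim, with a 2x sequence-length budget.
freqs_cis = precompute_freqs_cis(4096 // 32, 2048 * 2)
print(freqs_cis.shape)  # torch.Size([4096, 64])
```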
@@ -327,7 +335,7 @@ class TransformerBlock(nn.Module):

     Args:
         layer_id (int): Identifier for the layer.
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.

     Attributes:
         n_heads (int): Number of attention heads.
@@ -341,21 +349,21 @@ class TransformerBlock(nn.Module):

     """

-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, model_args: ModelArgs):

         super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
+        self.n_heads = model_args.n_heads
+        self.dim = model_args.dim
+        self.attention = Attention(model_args)
         self.feed_forward = FeedForward(
-            dim=args.dim,
-            hidden_dim=4 * args.dim,
-            multiple_of=args.multiple_of,
-            ffn_dim_multiplier=args.ffn_dim_multiplier,
+            dim=model_args.dim,
+            hidden_dim=4 * model_args.dim,
+            multiple_of=model_args.multiple_of,
+            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)

     def forward(
         self,
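Note: `FeedForward`'s internals are outside this diff. If it follows the Llama reference design, the `hidden_dim=4 * model_args.dim` value passed above is reduced to 2/3, optionally scaled by `ffn_dim_multiplier`, and rounded up to a multiple of `multiple_of` before the weights are allocated. A sketch of that sizing rule, stated as an assumption rather than what this PR ships:

```python
from typing import Optional


def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    # Llama-style SwiGLU sizing: start from 4 * dim, keep 2/3 of it,
    # optionally scale it, then round up to a multiple of `multiple_of`.
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


print(ffn_hidden_dim(4096, 256, None))   # 11008, the familiar Llama-2-7B FFN width
print(ffn_hidden_dim(8192, 4096, 1.3))   # 28672, the Llama-2-70B FFN width
```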
@@ -383,10 +391,10 @@ class Transformer(nn.Module):
     Transformer Module

     Args:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.

     Attributes:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
         vocab_size (int): Vocabulary size.
         n_layers (int): Number of layers in the model.
         tok_embeddings (ParallelEmbedding): Token embeddings.
@@ -397,21 +405,21 @@ class Transformer(nn.Module):

     """

-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):

         super().__init__()
-        self.params = params
-        self.vocab_size = params.vocab_size
-        self.n_layers = params.n_layers
+        self.model_args = model_args
+        self.vocab_size = model_args.vocab_size
+        self.n_layers = model_args.n_layers

-        self.embeddings = RotaryEmbedding(params)
+        self.embeddings = RotaryEmbedding(model_args)

         self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+        for layer_id in range(model_args.n_layers):
+            self.layers.append(TransformerBlock(layer_id, model_args))

-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

     def forward(self, tokens: torch.Tensor):
         """
@@ -426,7 +434,7 @@ def forward(self, tokens: torch.Tensor):
         """
         h, freqs_cis = self.embeddings(tokens)
         # fold batch and sequence dimension for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.params.dim)
+        h = h.view(-1, self.model_args.dim)

         for layer in self.layers:
             h = layer(h, freqs_cis)
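Note: the fold above flattens `(batch, seq_len, dim)` into `(batch * seq_len, dim)` so collectives such as allgather/reduce_scatter see a single leading dimension; the next hunk undoes it. A small self-contained shape walk-through:

```python
import torch

bsz, seqlen, dim = 2, 8, 16
h = torch.randn(bsz, seqlen, dim)

# fold batch and sequence dimensions into one
h = h.view(-1, dim)                      # (2, 8, 16) -> (16, 16)

# ... layers / collectives operate on the folded tensor here ...

# unfold back to (batch, seq_len, dim)
bs_seqlen = h.shape[0]
h = h.view(bsz, bs_seqlen // bsz, dim)   # (16, 16) -> (2, 8, 16)
print(h.shape)                           # torch.Size([2, 8, 16])
```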
@@ -435,17 +443,17 @@ def forward(self, tokens: torch.Tensor):
         # unfold batch and sequence dimension
         bsz = tokens.shape[0]
         bs_seqlen = h.shape[0]
-        h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
+        h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
         output = self.output(h).float()
         return output

     @classmethod
-    def from_model_args(cls, model_args: ModelArgs):
+    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
         """
         Initialize a Transformer model from a ModelArgs object.

         Args:
-            model_args (ModelArgs): Model configuration parameters.
+            model_args (ModelArgs): Model configuration arguments.

         Returns:
             Transformer: Transformer model.
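Note: a minimal usage sketch for the classmethod above. The `ModelArgs` field names follow the attributes referenced in this file, but the concrete values, the keyword-style constructor, and the assumption that `ModelArgs` lives in the same module are illustrative only.

```python
import torch

from torchtrain.models.llama.model import ModelArgs, Transformer

# Tiny illustrative configuration (placeholder values, not a shipped config).
model_args = ModelArgs(
    dim=256,
    n_layers=4,
    n_heads=8,
    vocab_size=32000,
    max_seq_len=1024,
)
model = Transformer.from_model_args(model_args)

tokens = torch.randint(0, model_args.vocab_size, (2, 128))  # (batch, seq_len)
logits = model(tokens)
print(logits.shape)  # expected: (2, 128, vocab_size)
```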