Commit

make model_conf for model control via sh, add rdzv endpoint
lessw2020 committed Feb 12, 2024
1 parent e14a350 commit c5da96d
Showing 3 changed files with 29 additions and 30 deletions.
16 changes: 9 additions & 7 deletions run_llama_train.sh
@@ -8,7 +8,8 @@ set -ex
# e.g.
# LOG_RANK=0,1 NGPU=4 SP=2 ./run_llama_train.sh

MODEL_CONF=${MODEL_CONF:-"7B"}
MODEL=${MODEL:-"llama"}
MODEL_CONF=${MODEL_CONF:-"debugmodel"}
NGPU=${NGPU:-"8"}
PP=${PP:-"1"}
SP=${SP:-"1"}
@@ -22,9 +23,10 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
# Please adjust this to a longer interval period. The unit of measurement is in steps.
CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}

torchrun --nproc_per_node=${NGPU} train.py --steps 25
# --compile --local-ranks-filter ${LOG_RANK}
# --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
\
#--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP}
#--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --steps 10 \
--model ${MODEL} --model_conf ${MODEL_CONF} \
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
# --compile \
# --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
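With this change the model family and configuration are selected from the shell via the MODEL and MODEL_CONF environment variables and passed through to train.py, and the torchrun rendezvous point is pinned to localhost:5972. A usage sketch for the updated script (values are illustrative; the "7B" config name is an assumption about what the model registry provides):

# debug-sized run on 4 GPUs, logging ranks 0 and 1
LOG_RANK=0,1 NGPU=4 MODEL=llama MODEL_CONF=debugmodel ./run_llama_train.sh

# larger run on 8 GPUs, assuming a "7B" entry exists in the config registry
MODEL=llama MODEL_CONF=7B NGPU=8 ./run_llama_train.sh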
41 changes: 19 additions & 22 deletions torchtrain/models/llama/model.py
@@ -77,7 +77,6 @@ def reset_parameters(self):
torch.nn.init.ones_(self.weight)



def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
Expand Down Expand Up @@ -186,7 +185,6 @@ class Attention(nn.Module):
"""

def __init__(self, args: ModelArgs):

super().__init__()
self.n_heads = args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
@@ -203,12 +201,14 @@ def reset_parameters(self, init_std):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=0.02,)
std=0.02,
)

nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=init_std,)
self.wo.weight,
mean=0.0,
std=init_std,
)

def forward(
self,
@@ -284,7 +284,6 @@ def __init__(
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):

super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
@@ -299,18 +298,19 @@ def __init__(
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))

def reset_parameters(self,init_std):
def reset_parameters(self, init_std):
nn.init.trunc_normal_(
self.w1.weight,
mean=0.0,
std=0.02,)
self.w1.weight,
mean=0.0,
std=0.02,
)

for item in (self.w2, self.w3):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=init_std,)

std=init_std,
)


class RotaryEmbedding(nn.Module):
@@ -370,7 +370,6 @@ class TransformerBlock(nn.Module):
"""

def __init__(self, layer_id: int, args: ModelArgs):

super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
@@ -386,7 +385,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

self.weight_init_std = 0.02 / (2 * self.num_layers)**0.5
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

def forward(
self,
@@ -409,15 +408,14 @@ def forward(
return out

def reset_parameters(self):
""" reset params and norms for entire block """
"""reset params and norms for entire block"""
self.attention_norm.reset_parameters()
self.ffn_norm.reset_parameters()

self.attention.reset_parameters(self.weight_init_std)
self.feed_forward.reset_parameters(self.weight_init_std)



class Transformer(nn.Module):
"""
Transformer Module
@@ -438,7 +436,6 @@ class Transformer(nn.Module):
"""

def __init__(self, params: ModelArgs):

super().__init__()
self.params = params
self.vocab_size = params.vocab_size
@@ -457,7 +454,9 @@ def __init__(self, params: ModelArgs):
self.reset_parameters()
rank0_log(f"{self.params=}")

def reset_parameters(self,):
def reset_parameters(
self,
):
for layer in self.layers:
layer.reset_parameters()
self.norm.reset_parameters()
@@ -470,9 +469,7 @@ def reset_parameters(self,):
a=-cutoff_factor * final_out_std,
b=cutoff_factor * final_out_std,
)
rank0_log(f"Model params initialized via reset_params")


rank0_log("Model params initialized via reset_params")

def forward(self, tokens: torch.Tensor):
"""
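The reset_parameters cleanup above keeps the depth-scaled initialization: each TransformerBlock computes weight_init_std = 0.02 / (2 * num_layers) ** 0.5 and passes it to the attention and feed-forward reset_parameters, which apply truncated-normal init (a fixed std of 0.02 for some projections, the depth-scaled value for the rest). A minimal standalone sketch of that scheme, in Python; the layer below is a placeholder, not the module's actual projections:

import torch.nn as nn

n_layers = 32                               # assumed model depth
init_std = 0.02 / (2 * n_layers) ** 0.5     # per-block std, as in TransformerBlock
proj = nn.Linear(4096, 4096, bias=False)    # stand-in for an output projection
nn.init.trunc_normal_(proj.weight, mean=0.0, std=init_std)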
2 changes: 1 addition & 1 deletion train.py
@@ -204,7 +204,7 @@ def main(args):
parser.add_argument(
"--model_conf",
type=str,
default="40B",
default="debugmodel",
help="which model config to train",
)
parser.add_argument("--dataset", type=str, default="alpaca", help="dataset to use")
