Commit

Merge pull request #60 from nvjax-svc-0:patch/glam_without_repeat_layer
PiperOrigin-RevId: 633222135
pax authors committed May 13, 2024
2 parents 8053fa8 + 637a22b commit f204f0e
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions praxis/layers/glam.py
@@ -208,6 +208,7 @@ def GlamUniTransformerLmHParams(
     z_loss_weight=1e-4,
     combine_qkv=False,
     bidirectional=False,
+    repeat=True,
     num_pipeline_stages=1,
     num_pipeline_microbatches=1,
     model_type=LanguageModelType.CAUSAL,
@@ -322,14 +323,19 @@ def GlamUniTransformerLmHParams(
   num_blocks = num_transformer_layers // 2 if moe else num_transformer_layers

   if num_pipeline_stages == 1:
-    p.stacked_transformer_tpl = pax_fiddle.Config(
-        transformers.StackedTransformerRepeated,
-        name='decoder',
-        unroll_in_decode=True,
-        block=glam_p,
-        x_times=num_blocks,
-        checkpoint_policy=checkpoint_policy,
-    )
+    if repeat:
+      p.stacked_transformer_tpl = pax_fiddle.Config(
+          transformers.StackedTransformerRepeated,
+          name='decoder',
+          unroll_in_decode=True,
+          block=glam_p,
+          x_times=num_blocks,
+          checkpoint_policy=checkpoint_policy,
+      )
+    else:
+      glam_p.num_layers = num_transformer_layers
+      glam_p.moe_layers = list(range(0, glam_p.num_layers, 2))
+      p.stacked_transformer_tpl = glam_p
   else:
     assert num_blocks % num_pipeline_stages == 0
     glam_p.num_layers = num_transformer_layers // num_pipeline_stages
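
The change adds a repeat flag to GlamUniTransformerLmHParams: with a single pipeline stage, repeat=True keeps the existing StackedTransformerRepeated wrapper, while repeat=False uses the per-block template directly as the stack, setting num_layers to the full depth and placing MoE layers at every even index. Below is a minimal standalone sketch of that decision logic, paraphrased from the diff for illustration; the helper name describe_stacking and the example argument values are assumptions of this note, not part of the praxis API.

```python
# Standalone sketch of the stacking choice introduced by this commit.
# Only the branching mirrors the diff above; nothing here calls praxis.

def describe_stacking(num_transformer_layers: int, moe: bool, repeat: bool,
                      num_pipeline_stages: int) -> str:
  """Describes which transformer stacking the GLaM config would end up with."""
  num_blocks = num_transformer_layers // 2 if moe else num_transformer_layers
  if num_pipeline_stages == 1:
    if repeat:
      # Pre-existing path: one block wrapped in StackedTransformerRepeated,
      # repeated x_times=num_blocks.
      return f'StackedTransformerRepeated(x_times={num_blocks})'
    # New path: the block template is used directly as the whole stack, with
    # MoE layers at every other (even) layer index.
    moe_layers = list(range(0, num_transformer_layers, 2))
    return (f'block template with num_layers={num_transformer_layers}, '
            f'moe_layers={moe_layers}')
  # The pipelined branch is untouched by this commit.
  return 'pipelined stack (unchanged)'


if __name__ == '__main__':
  print(describe_stacking(num_transformer_layers=8, moe=True, repeat=False,
                          num_pipeline_stages=1))
```

The non-repeated path trades the weight-sharing scan structure of StackedTransformerRepeated for an explicitly unrolled stack, which the PR title (patch/glam_without_repeat_layer) suggests is the point of the change.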
