From 637a22b2427b7e9d5dfea25304aa98ddafa8bbc2 Mon Sep 17 00:00:00 2001
From: ashors1
Date: Tue, 14 Nov 2023 12:45:03 -0800
Subject: [PATCH] add support for glam with repeated_layer=False

---
 praxis/layers/glam.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/praxis/layers/glam.py b/praxis/layers/glam.py
index 0a944a63..53063c70 100644
--- a/praxis/layers/glam.py
+++ b/praxis/layers/glam.py
@@ -208,6 +208,7 @@ def GlamUniTransformerLmHParams(
     z_loss_weight=1e-4,
     combine_qkv=False,
     bidirectional=False,
+    repeat=True,
     num_pipeline_stages=1,
     num_pipeline_microbatches=1,
     model_type=LanguageModelType.CAUSAL,
@@ -322,14 +323,19 @@ def GlamUniTransformerLmHParams(
   num_blocks = num_transformer_layers // 2 if moe else num_transformer_layers
 
   if num_pipeline_stages == 1:
-    p.stacked_transformer_tpl = pax_fiddle.Config(
-        transformers.StackedTransformerRepeated,
-        name='decoder',
-        unroll_in_decode=True,
-        block=glam_p,
-        x_times=num_blocks,
-        checkpoint_policy=checkpoint_policy,
-    )
+    if repeat:
+      p.stacked_transformer_tpl = pax_fiddle.Config(
+          transformers.StackedTransformerRepeated,
+          name='decoder',
+          unroll_in_decode=True,
+          block=glam_p,
+          x_times=num_blocks,
+          checkpoint_policy=checkpoint_policy,
+      )
+    else:
+      glam_p.num_layers = num_transformer_layers
+      glam_p.moe_layers = list(range(0, glam_p.num_layers, 2))
+      p.stacked_transformer_tpl = glam_p
   else:
     assert num_blocks % num_pipeline_stages == 0
     glam_p.num_layers = num_transformer_layers // num_pipeline_stages
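
Note (commentary, not part of the patch): a minimal sketch of the layer layout
the new repeat=False branch produces, using a hypothetical layer count. With
repeat=False and num_pipeline_stages=1, the block config lists all
num_transformer_layers layers explicitly and marks every other index as a MoE
layer, instead of repeating a 2-layer block num_blocks times.

    # Sketch only: mirrors the new else-branch above with a hypothetical
    # layer count; not an actual praxis API call.
    num_transformer_layers = 8

    num_layers = num_transformer_layers
    moe_layers = list(range(0, num_layers, 2))
    print(moe_layers)  # [0, 2, 4, 6]: MoE at even indices, dense in between

Callers opt into this path by passing repeat=False; the default repeat=True
keeps the existing StackedTransformerRepeated behavior unchanged.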