From 2a573955de37c027b23d51a4deaa6b77044520b2 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Wed, 15 Jan 2025 14:56:58 -0800
Subject: [PATCH] [do NOT land] CP+torch.compile debugging attempt

[ghstack-poisoned]
---
 torchtitan/models/llama/__init__.py          | 2 +-
 torchtitan/models/llama/model.py             | 7 ++++++-
 torchtitan/parallelisms/parallelize_llama.py | 3 +++
 torchtitan/utils.py                          | 1 +
 4 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py
index 3bb430d2..104a7b78 100644
--- a/torchtitan/models/llama/__init__.py
+++ b/torchtitan/models/llama/__init__.py
@@ -14,7 +14,7 @@
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
     "8B": ModelArgs(
         dim=4096,
-        n_layers=32,
+        n_layers=1,
         n_heads=32,
         n_kv_heads=8,
         ffn_dim_multiplier=1.3,
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 641ef6de..2462dcf3 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -169,6 +169,10 @@ def init_weights(self, init_std: float):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
 
+    @torch.compiler.disable
+    def SDPA(self, *args, **kwargs):
+        return F.scaled_dot_product_attention(*args, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -206,7 +210,8 @@
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
         # we use casual mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        # output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = self.SDPA(xq, xk, xv, is_causal=True)
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9728569a..3f1909ed 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -299,6 +299,9 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
+    # torch._inductor.config.force_disable_caches = True
+    torch._dynamo.config.suppress_errors = True
+
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 88663c00..51b47494 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -213,6 +213,7 @@ def context(cp_context: Optional[Generator[None, None, None]] = None):
                 stack.enter_context(
                     sdpa_kernel(
                         [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+                        # [SDPBackend.CUDNN_ATTENTION]
                     )
                 )
                 stack.enter_context(cp_context)
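
For reference, below is a minimal standalone sketch (not part of the patch) of the debugging trick the model.py hunk uses: wrapping scaled_dot_product_attention in a method decorated with @torch.compiler.disable so that SDPA always runs eagerly and torch.compile graph-breaks around it. The ToyAttention module and the tensor shapes are illustrative only and are not taken from torchtitan.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttention(nn.Module):
    # Hypothetical stand-in for torchtitan's Attention, kept to the parts the patch touches.
    @torch.compiler.disable  # forces a graph break; SDPA itself is never traced by dynamo
    def SDPA(self, *args, **kwargs):
        return F.scaled_dot_product_attention(*args, **kwargs)

    def forward(self, q, k, v):
        # q, k, v: (bs, n_heads, seqlen, head_dim), already transposed as in the patch
        return self.SDPA(q, k, v, is_causal=True)


if __name__ == "__main__":
    # Default fullgraph=False here: with fullgraph=True (as in apply_compile), the forced
    # graph break would raise unless torch._dynamo.config.suppress_errors is set, which is
    # presumably why the patch flips that flag as well.
    attn = torch.compile(ToyAttention())
    q = k = v = torch.randn(2, 8, 16, 64)
    out = attn(q, k, v)
    print(out.shape)  # torch.Size([2, 8, 16, 64])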