
Commit 2a57395

[do NOT land] CP+torch.compile debugging attempt
[ghstack-poisoned]
XilunWu committed Jan 15, 2025
1 parent 95677cb commit 2a57395
Showing 4 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchtitan/models/llama/__init__.py
@@ -14,7 +14,7 @@
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
dim=4096,
n_layers=32,
n_layers=1,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
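Note on the change above: shrinking the 8B config from 32 layers to one keeps every per-layer shape (dim, n_heads, n_kv_heads) intact while letting a compile or CP failure reproduce on a single TransformerBlock. A rough, hypothetical sketch of why this works (simplified stand-ins, not torchtitan's actual classes):

# Hypothetical sketch: the model builds one block per n_layers, so n_layers=1
# keeps 8B-sized per-layer shapes but only one block has to compile.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class DebugArgs:  # stand-in for torchtitan's ModelArgs
    dim: int = 4096
    n_layers: int = 1  # debug value from the diff; the real 8B config uses 32


def build_layers(args: DebugArgs) -> nn.ModuleList:
    # stand-in for the per-layer TransformerBlock construction
    return nn.ModuleList(nn.Linear(args.dim, args.dim) for _ in range(args.n_layers))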
7 changes: 6 additions & 1 deletion torchtitan/models/llama/model.py
@@ -169,6 +169,10 @@ def init_weights(self, init_std: float):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

+    @torch.compiler.disable
+    def SDPA(self, *args, **kwargs):
+        return F.scaled_dot_product_attention(*args, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -206,7 +210,8 @@ def forward(
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

         # we use causal mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        # output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = self.SDPA(xq, xk, xv, is_causal=True)
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
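For context: torch.compiler.disable marks a function as a graph break, so dynamo traces around it and the wrapped scaled_dot_product_attention always runs eagerly, keeping the (CP-patched) attention call out of the compiled graph. A minimal runnable sketch of the mechanism, assuming PyTorch 2.1+ (compiled here without fullgraph=True, since under fullgraph=True the graph break alone would raise):

import torch
import torch.nn.functional as F


@torch.compiler.disable
def eager_sdpa(q, k, v):
    # this frame is never captured by dynamo; it always executes eagerly
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


@torch.compile  # no fullgraph=True, so the graph break is tolerated
def attention(q, k, v):
    return eager_sdpa(q, k, v)


q = k = v = torch.randn(2, 8, 16, 64)  # (bs, n_heads, seqlen, head_dim)
out = attention(q, k, v)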
3 changes: 3 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -299,6 +299,9 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
+    # torch._inductor.config.force_disable_caches = True
+    torch._dynamo.config.suppress_errors = True
+
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
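Both toggles above are debug-only knobs in private, version-dependent namespaces. suppress_errors makes dynamo log a failed compile and fall back to eager for that frame rather than raise, which is presumably what lets the fullgraph=True blocks survive the graph break introduced by the SDPA wrapper; the commented force_disable_caches would force Inductor to recompile from scratch each run. A minimal sketch of the compile loop, with nn.Sequential standing in for model.layers:

import torch
import torch._dynamo
import torch.nn as nn

torch._dynamo.config.suppress_errors = True  # debug only: fall back to eager on dynamo errors
# torch._inductor.config.force_disable_caches = True  # debug only: defeat compile caches

layers = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))  # stand-in for model.layers
for layer_id, block in layers.named_children():
    # mirror apply_compile: compile each block and re-register it under the same name
    layers.register_module(layer_id, torch.compile(block, fullgraph=True))

out = layers(torch.randn(2, 16))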
1 change: 1 addition & 0 deletions torchtitan/utils.py
@@ -213,6 +213,7 @@ def context(cp_context: Optional[Generator[None, None, None]] = None):
         stack.enter_context(
             sdpa_kernel(
                 [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+                # [SDPBackend.CUDNN_ATTENTION]
             )
         )
         stack.enter_context(cp_context)
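For context: sdpa_kernel is a context manager that restricts which fused backends F.scaled_dot_product_attention may dispatch to; the commented line hints at testing the cuDNN backend in place of flash/memory-efficient. A minimal sketch of the same pattern (assumptions: PyTorch 2.3+ for the list-accepting sdpa_kernel, 2.4+ for SDPBackend.CUDNN_ATTENTION, and a CUDA device):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Allow only the flash / memory-efficient kernels, as in the diff above.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The commented-out alternative would pin SDPA to the cuDNN backend instead:
# with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
#     out = F.scaled_dot_product_attention(q, k, v, is_causal=True)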
