From 4bc146dd6054c6babe5b21e8a3cd8e1a5387d6c0 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 29 Oct 2024 14:13:27 -0700
Subject: [PATCH] [not for land] torch.compile individual linears

Summary:

Changes the torch.compile strategy to apply only to individual `Linear`
modules. This is not for land; it exists only to help reproduce an issue
with torch.compile.

Test Plan:

```
// run the debug model
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --activation_checkpoint.mode none

// expect to see this error:
// https://gist.github.com/vkuzo/2caa399a3ef7df2a79b9c1788c27ac7b
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/parallelisms/parallelize_llama.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index ed23936b..ac0f6c3b 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -76,6 +76,7 @@ def parallelize_llama(
                 "Please use rmsnorm or layernorm."
             )
         apply_compile(model)
+        print('after compile', model)
 
     if (
         parallel_dims.dp_shard_enabled
@@ -306,11 +307,23 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation
     efficient due to repeated structure. Alternatively one can compile the whole
     model (after applying DP).
     """
-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
-        model.layers.register_module(layer_id, transformer_block)
-    logger.info("Compiling each TransformerBlock with torch.compile")
+    # each transformer block
+    if False:
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = torch.compile(transformer_block, fullgraph=True)
+            model.layers.register_module(layer_id, transformer_block)
+        logger.info("Compiling each TransformerBlock with torch.compile")
+
+    # individual linear
+    if True:
+        for name, child in model.named_children():
+            if isinstance(child, torch.nn.Linear):
+                new_child = torch.compile(child)
+                setattr(model, name, new_child)
+            else:
+                apply_compile(child)
+        logger.info("Compiling each linear with torch.compile")
 
 
 def apply_fsdp(
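
Aside (not part of the patch): a minimal, self-contained sketch of the same per-`Linear` compile strategy used by `apply_compile` above, applied to a toy module. The helper name `compile_linears` and the toy layer sizes are illustrative assumptions, not from torchtitan.

```
# Illustrative sketch only (assumed helper name and toy shapes); mirrors the
# "individual linear" branch of apply_compile in the patch above.
import torch
import torch.nn as nn


def compile_linears(module: nn.Module) -> None:
    """Recursively wrap every nn.Linear child in torch.compile."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            # torch.compile returns an OptimizedModule; setattr re-registers it
            # under the same child name, replacing the eager Linear.
            setattr(module, name, torch.compile(child))
        else:
            compile_linears(child)


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
    compile_linears(toy)
    out = toy(torch.randn(4, 16))
    print(out.shape)  # torch.Size([4, 8])
```

Because each `Linear` is compiled in isolation, every one becomes its own compiled region instead of a single graph per TransformerBlock, which is the behavior this patch uses to reproduce the issue.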