diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index ed23936b..ac0f6c3b 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -76,6 +76,7 @@ def parallelize_llama(
                 "Please use rmsnorm or layernorm."
             )
         apply_compile(model)
+        print('after compile', model)
 
     if (
         parallel_dims.dp_shard_enabled
@@ -306,11 +307,23 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
-        model.layers.register_module(layer_id, transformer_block)
-
-    logger.info("Compiling each TransformerBlock with torch.compile")
+    # each transformer block
+    if False:
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = torch.compile(transformer_block, fullgraph=True)
+            model.layers.register_module(layer_id, transformer_block)
+
+        logger.info("Compiling each TransformerBlock with torch.compile")
+
+    # individual linear
+    if True:
+        for name, child in model.named_children():
+            if isinstance(child, torch.nn.Linear):
+                new_child = torch.compile(child)
+                setattr(model, name, new_child)
+            else:
+                apply_compile(child)
+        logger.info("Compiling each linear with torch.compile")
 
 
 def apply_fsdp(
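
For context, the patch switches `apply_compile` from compiling each `TransformerBlock` to compiling each `nn.Linear` individually, and the added `print('after compile', model)` dumps the wrapped model so the compiled submodules are visible. Below is a minimal standalone sketch (not part of the patch) of the new traversal: it recursively walks the module tree and swaps every `nn.Linear` for its `torch.compile`-wrapped counterpart. `compile_linears` and `ToyBlock` are hypothetical names used only for illustration.

```python
# Minimal sketch, assuming PyTorch >= 2.0; illustrative only, not part of the patch.
# Mirrors the new apply_compile branch: recursively replace each nn.Linear
# with its torch.compile()-wrapped version via setattr on the parent module.
import torch
import torch.nn as nn


def compile_linears(model: nn.Module) -> None:
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            # torch.compile on a module returns an OptimizedModule wrapper;
            # setattr swaps it into the parent in place.
            setattr(model, name, torch.compile(child))
        else:
            compile_linears(child)


class ToyBlock(nn.Module):  # hypothetical stand-in for a TransformerBlock
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.w2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


if __name__ == "__main__":
    model = nn.Sequential(ToyBlock(), ToyBlock())
    compile_linears(model)
    print(model)  # each Linear now appears as an OptimizedModule
    print(model(torch.randn(2, 16)).shape)
```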