Tensor parallel Llama3 tutorial illustrating use of torch.distributed and nccl ops
apbose committed Jan 17, 2025
1 parent 543bc9b commit 0313372
Showing 2 changed files with 67 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -68,6 +68,7 @@ Tutorials
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`pre_allocated_output_example`
* :ref:`tensor_parallel_llama`

.. toctree::
:caption: Tutorials
@@ -87,6 +88,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example
tutorials/_rendered_examples/dynamo/pre_allocated_output_example
tutorials/_rendered_examples/dynamo/tensor_parallel_llama

Dynamo Frontend
----------------
68 changes: 65 additions & 3 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -1,11 +1,27 @@
# Taken and modified from the PyTorch Lightning tutorial:
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
"""
.. _tensor_parallel_llama:
Torch distributed example for llama3-7B model
======================================================
As model sizes are increasing, large models with billions of parameters are trained with many GPUs, where regular data parallel training is no longer possible. In this example, we illustrate the Llama3-7B model inference using Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called Tensor Parallelism. We make use of Pytorch Distributed Tensor Parallelism Module. Please refer to these tutorials- https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import logging
import os
import time

import torch

# %%
# PyTorch's tensor parallel APIs offer a set of module-level primitives (``ParallelStyle``) to
# configure the sharding of tensors in each layer of the model.
# ``ParallelTransformer`` creates the ``parallelize_plan`` for the attention and feed-forward
# layers of the model.
from llama3_model import ModelArgs, ParallelTransformer

from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
@@ -14,11 +30,24 @@
    checkpoint_wrapper,
)

# %%
# Initialize the distributed environment
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Depending on the input/output sharding layouts of the DTensors specified above, communication
# operations (e.g. allreduce, allgather, reduce_scatter) are required to transform the DTensor
# layouts. These collectives are provided by NCCL.
# The API below does the following:
# * Initializes the communicators and the distributed environment
# * Sets the path to the TRT-LLM plugin .so, which is required for the NCCL ops in the
#   Torch-TensorRT backend. Note that in a Python 3.10 environment, ``import tensorrt_llm``
#   should be enough
# * Initializes the logger. For example, with 2 GPUs the log files are
#   ``./tensor_parallel_llama3_0.log`` and ``./tensor_parallel_llama3_1.log``
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_llama3"
)
# This import must come after the TRT-LLM plugin .so path has been set up by the call above
import tensorrt_llm
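
# %%
# For illustration, a minimal sketch of what such an initialization helper might look like.
# The actual contents of ``tensor_parallel_initialize_dist.initialize_distributed_env`` may
# differ, and the TRT-LLM plugin .so path handling is intentionally omitted here; the helper
# name below is hypothetical.
#
# .. code-block:: python
#
#     import logging
#
#     import torch
#     import torch.distributed as dist
#     from torch.distributed.device_mesh import init_device_mesh
#
#     def initialize_distributed_env_sketch(logger_file_prefix):
#         # One process per GPU; NCCL provides the GPU collectives
#         dist.init_process_group(backend="nccl")
#         world_size = dist.get_world_size()
#         rank = dist.get_rank()
#         torch.cuda.set_device(rank)
#         # A 1-D device mesh over all ranks for tensor parallelism
#         device_mesh = init_device_mesh("cuda", (world_size,))
#         # Per-rank log file, e.g. ./tensor_parallel_llama3_0.log
#         logging.basicConfig(
#             filename=f"{logger_file_prefix}_{rank}.log", level=logging.INFO
#         )
#         return device_mesh, world_size, rank, logging.getLogger()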

# %%
# Model initialization with torch distributed parallel plan
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
@@ -36,7 +65,38 @@
)

with torch.no_grad():
    # The parallelize_plan created by ParallelTransformer is:
    # plan = {
    #     "attention": PrepareModuleInput(
    #         input_layouts=(Shard(1), None),
    #         desired_input_layouts=(Replicate(), None),
    #     ),
    #     "attention.wq": ColwiseParallel(),
    #     "attention.wk": ColwiseParallel(),
    #     "attention.wv": ColwiseParallel(),
    #     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    #     "attention_norm": SequenceParallel(),
    #     "feed_forward": PrepareModuleInput(
    #         input_layouts=(Shard(1),),
    #         desired_input_layouts=(Replicate(),),
    #     ),
    #     "feed_forward.w1": ColwiseParallel(),
    #     "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    #     "feed_forward.w3": ColwiseParallel(),
    #     "ffn_norm": SequenceParallel(),
    # }

    model = ParallelTransformer(model_args, device_mesh)
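
# %%
# Internally, a plan like the one above is applied to each transformer block with
# ``parallelize_module`` from ``torch.distributed.tensor.parallel``. The snippet below is a
# minimal sketch of that pattern, not the actual ``ParallelTransformer`` implementation; the
# ``model.layers`` attribute and the two-entry plan are assumptions for illustration.
#
# .. code-block:: python
#
#     from torch.distributed.tensor.parallel import (
#         ColwiseParallel,
#         RowwiseParallel,
#         parallelize_module,
#     )
#
#     # Shard the first feed-forward projection by columns and the second by rows,
#     # so the forward pass needs only a single allreduce per feed-forward block.
#     sketch_plan = {
#         "feed_forward.w1": ColwiseParallel(),
#         "feed_forward.w2": RowwiseParallel(),
#     }
#     for transformer_block in model.layers:
#         parallelize_module(transformer_block, device_mesh, sketch_plan)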

# %%
# Model inference with Torch-TensorRT backend
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# When we compile the distributed model with the Torch-TensorRT backend, the PyTorch distributed
# libraries create the sharded model on multiple GPUs, and the communication operations ensure
# proper data exchange between them. In the plan above,
# ``ColwiseParallel`` and ``RowwiseParallel`` shard the attention and feed-forward projections in
# column or row fashion, ``SequenceParallel`` performs sharded computation of the normalization
# layers, and ``PrepareModuleInput`` redistributes the module inputs with the proper communication
# operations. A representative compile call is sketched after the eager baseline below.

torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
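
# %%
# The sharded module can now be handed to ``torch.compile`` with the Torch-TensorRT backend so
# that each rank lowers its shard to TensorRT while keeping the NCCL communication ops. The
# options below are an illustrative, minimal configuration and are not necessarily the exact
# settings used for the timings reported by this example.
import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile

model = torch.compile(
    model,
    backend="torch_tensorrt",
    dynamic=False,
    options={
        # Precisions TensorRT is allowed to use when building the engines (assumed values)
        "enabled_precisions": {torch.float32, torch.float16},
        # Run the generated TensorRT engines through the Python runtime
        "use_python_runtime": True,
    },
)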
@@ -62,9 +122,11 @@
    output = model(inp)
    end = time.time()
    if i == 0:
        # Log the compilation time
        logger.info(f"Compilation time is {end - start}")
        assert (
            python_result - output
        ).std() < 0.01, "Compilation result is not correct."
    elif _rank == 0:
        # Log the inference time
        logger.info(f"Inference time is {end - start}")
