Update
[ghstack-poisoned]
wconstab committed May 22, 2024
2 parents 8717192 + 7ede6a3 commit 9884f42
Showing 6 changed files with 83 additions and 53 deletions.
51 changes: 24 additions & 27 deletions .github/workflows/integration_test_periodic.yaml
@@ -14,30 +14,27 @@ defaults:
    shell: bash -l -eo pipefail {0}

jobs:
  unit_tests_4gpu:
    runs-on: linux.g5.12xlarge.nvidia.gpu
    strategy:
      matrix:
        python-version: ['3.10']
    steps:
      - name: Check out repo
        uses: actions/checkout@v3
      - name: Setup conda env
        uses: conda-incubator/setup-miniconda@v2
        with:
          auto-update-conda: true
          miniconda-version: "latest"
          activate-environment: test
          python-version: ${{ matrix.python-version }}
      - name: Update pip
        run: python -m pip install --upgrade pip
      - name: Install dependencies
        run: |
          pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
          pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
          python -m pip install -r requirements.txt
          python -m pip install -r dev-requirements.txt
      - name: Run test_runner.py
        run: python ./test_runner.py
      - name: Upload Coverage to Codecov
        uses: codecov/codecov-action@v3
  build-test:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      runner: linux.g5.12xlarge.nvidia.gpu
      gpu-arch-type: cuda
      gpu-arch-version: "12.1"
      # This image is faster to clone than the default (1m25s vs 2m37s), but it
      # lacks the C compiler (CC) needed by triton.
      docker-image: torchtitan-ubuntu-20.04-clang12
      repository: pytorch/torchtitan
      upload-artifact: outputs
      script: |
        set -eux
        # The generic Linux job chooses the base env, not the one set up by the image
        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
        conda activate "${CONDA_ENV}"
        pip config --user set global.progress_bar off
        python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
        python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
        mkdir artifacts-to-be-uploaded
        python ./test_runner.py artifacts-to-be-uploaded
1 change: 1 addition & 0 deletions README.md
@@ -37,6 +37,7 @@ We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


### Coming soon

1. Async checkpointing
2. FP8 support
3. Context Parallel
11 changes: 10 additions & 1 deletion test_runner.py
@@ -127,12 +127,21 @@ def build_test_list(args):
        OverrideDefinitions(
            [
                [
                    "--training.compile",
                    "--training.compile --model.norm_type=rmsnorm",
                    f"--job.dump_folder {args.output_dir}/1d_compile/",
                ],
            ],
            "1D compile",
        ),
        OverrideDefinitions(
            [
                [
                    "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
                    f"--job.dump_folder {args.output_dir}/2d_compile/",
                ],
            ],
            "2D compile",
        ),
        OverrideDefinitions(
            [
                [
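For context on the new "2D compile" entry: each OverrideDefinitions bundles the flag strings for one integration-test run. A hypothetical sketch of how such an entry can expand into launch commands (the dataclass shape and run script here are illustrative assumptions, not torchtitan's exact internals):

```python
# Hypothetical sketch of expanding an override list into launch commands.
# The real OverrideDefinitions in test_runner.py may differ; the flag strings
# below come from the "2D compile" entry in the diff above.
from dataclasses import dataclass


@dataclass
class OverrideDefinitions:
    override_args: list[list[str]]  # one inner list of flag strings per run
    test_descr: str = "default"


def build_commands(defn: OverrideDefinitions, base_cmd: str) -> list[str]:
    return [f"{base_cmd} {' '.join(args)}" for args in defn.override_args]


two_d_compile = OverrideDefinitions(
    [
        [
            "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
            "--job.dump_folder ./outputs/2d_compile/",
        ],
    ],
    "2D compile",
)
print(build_commands(two_d_compile, "./run_llama_train.sh"))
```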
9 changes: 8 additions & 1 deletion torchtitan/datasets/hf_datasets.py
@@ -10,7 +10,14 @@
import torch
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

try:
    from torchdata.stateful_dataloader import StatefulDataLoader
except ImportError as e:
    raise ImportError(
        "Please install the latest torchdata nightly to use StatefulDataLoader via: "
        "pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"
    ) from e

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
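The guarded import above is needed because StatefulDataLoader extends the stock DataLoader with state_dict()/load_state_dict(), which is what makes data loading resumable across training checkpoints. A minimal usage sketch, assuming a recent torchdata nightly (range(...) stands in for a real dataset):

```python
# Minimal sketch of resumable iteration with StatefulDataLoader.
# Assumes a recent torchdata nightly; range(...) is a toy map-style dataset.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(range(100), batch_size=4)
it = iter(loader)
next(it)  # consume the first batch

state = loader.state_dict()  # capture the mid-epoch position

resumed = StatefulDataLoader(range(100), batch_size=4)
resumed.load_state_dict(state)  # resume where the first loader stopped
print(next(iter(resumed)))  # yields the second batch, not the first
```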
50 changes: 40 additions & 10 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -209,7 +209,11 @@ def pipeline_llama_manual(
    batch_size = job_config.training.batch_size
    local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
    layers_io_shape = (batch_size, local_seq_len, model_config.dim)
    output_layer_shape = (batch_size, local_seq_len, model_config.vocab_size)
    output_layer_shape = (
        batch_size,
        job_config.training.seq_len,
        model_config.vocab_size,
    )
    if pp_rank == 0:
        # first layer
        input = torch.randint(
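The reshaped output_layer_shape above captures why the output layer differs from the intermediate layers: under loss parallelism the logits are sharded along the vocab dimension, so the sequence dimension stays full-length instead of being divided by the TP degree. A quick worked example with illustrative numbers:

```python
# Worked shape example for the fix above (all numbers are illustrative).
batch_size, seq_len, tp_degree = 8, 2048, 4
dim, vocab_size = 4096, 32000

local_seq_len = seq_len // tp_degree  # sequence-parallel shard per TP rank

layers_io_shape = (batch_size, local_seq_len, dim)      # (8, 512, 4096)
output_layer_shape = (batch_size, seq_len, vocab_size)  # (8, 2048, 32000)
```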
@@ -318,7 +322,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            ),
            "output": col_parallel_strategy(
                input_layouts=Shard(1),
                output_layouts=(Shard(-1) if loss_parallel else Replicate()),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
            "norm": SequenceParallel(),
@@ -360,20 +364,49 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

logger.info("Applied Tensor Parallelism to the model")

# apply AC + torch.compile
ac_config = job_config.activation_checkpoint
enable_compile = job_config.training.compile
for layer_id, transformer_block in model.layers.items():
if ac_config.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(transformer_block, ac_config)
if enable_compile:
# turn on per-transformer block compile after AC wrapping and before FSDP
# TODO: dynamic shape have some issues so we turn it off for now.
# TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
# compile time.
# torch._dynamo.config.inline_inbuilt_nn_modules = True
transformer_block = torch.compile(transformer_block, dynamic=False)
model.layers[layer_id] = transformer_block

if ac_config.mode in ("full", "selective"):
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
if (
enable_compile
and ac_config.mode == "selective"
and ac_config.selective_ac_option == "op"
):
# some temp flags for torch.compile enablement + SAC
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
True
)
if enable_compile:
if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
)
logger.info("Compiled each TransformerBlock with torch.compile")

    # apply DP (FSDP2)
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )
        ac_mode = job_config.activation_checkpoint.mode
        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
        for layer_id, transformer_block in model.layers.items():
            if job_config.activation_checkpoint.mode in ("full", "selective"):
                transformer_block = checkpoint_wrapper(
                    transformer_block, job_config.activation_checkpoint
                )
            # As an optimization, do not reshard after forward for the last
            # transformer block, since FSDP would prefetch it immediately.
            # When using Pipeline Parallelism, ZeRO-2 is generally best,
            # to avoid repeated resharding
@@ -387,12 +420,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                reshard_after_forward=reshard_after_forward,
            )
            model.layers[layer_id] = transformer_block

        model = fully_shard(
            model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
        )
        if ac_mode in ("full", "selective"):
            logger.info(f"Applied {ac_mode} activation checkpointing to the model")
        logger.info("Applied FSDP to the model")

    return model
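Taken together, this file now fixes the composition order for every transformer block: activation checkpointing wraps first, torch.compile wraps the checkpointed block, and FSDP2 shards the compiled block. A condensed sketch of that ordering, under the assumption of an initialized process group and device mesh (torchtitan's checkpoint_wrapper takes an AC config; the stock torch wrapper is used here to keep the sketch self-contained):

```python
# Condensed sketch of the per-block wrapping order this diff establishes:
# 1) activation checkpointing, 2) torch.compile, 3) FSDP2 sharding.
# Assumes torch.distributed is initialized; fsdp_kwargs carries mesh/mp_policy.
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def wrap_block(
    block: torch.nn.Module, use_ac: bool, use_compile: bool, fsdp_kwargs: dict
) -> torch.nn.Module:
    if use_ac:
        block = checkpoint_wrapper(block)  # AC first, so compile sees the wrapper
    if use_compile:
        block = torch.compile(block, dynamic=False)  # compile after AC, before FSDP
    fully_shard(block, **fsdp_kwargs)  # FSDP2 shards the (compiled) block in place
    return block
```

The root module is then sharded last, mirroring the fully_shard(model, ...) call in the diff above.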
14 changes: 0 additions & 14 deletions train.py
@@ -245,20 +245,6 @@ def loss_fn(pred, labels):

    metric_logger = build_metric_logger(job_config)

    # torch.compile model for improved performance
    if job_config.training.compile:
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.activation_checkpoint.selective_ac_option == "op"
        ):
            torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
                True
            )
        logger.info("Compiling model with torch.compile")
        # Dynamic shapes have issues with distributed, so turn dynamic off, as
        # Transformer training is static-shape.
        # TODO: resolve the dynamic shape issue and restore defaults
        model = torch.compile(model, dynamic=False)

    train_state = TrainState()

    # train loop
