Skip to content

Commit

Permalink
Merge branch 'dev' into kill_generation_length
Browse files Browse the repository at this point in the history
  • Loading branch information
maxisawesome authored Feb 22, 2024
2 parents c6ade3c + d3987a0 commit d66f7c2
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 12 deletions.
2 changes: 1 addition & 1 deletion composer/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

"""The Composer Version."""

__version__ = '0.19.1'
__version__ = '0.20.0'
1 change: 1 addition & 0 deletions composer/devices/device_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DeviceTPU(Device):
More details.
"""

dist_backend = 'xla'
name = 'tpu'

def __init__(self):
Expand Down
5 changes: 5 additions & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2567,6 +2567,11 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
microbatch_loss.backward(create_graph=self._backwards_create_graph)

if self.state.device.dist_backend == 'xla':
# For xla devices, the program between any pair of mark_steps() calls is compiled. With out this, the
# microbatching loop is unrolled, drastically increasing compile time.
xm.mark_step()

self.engine.run_event(Event.AFTER_BACKWARD)

# Use microbatch outputs to update training metrics
Expand Down
17 changes: 15 additions & 2 deletions composer/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,20 @@
import logging
import os
import pickle
import sys
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast

import torch
import torch.distributed as dist
import torch.utils.data
from packaging import version

from composer.utils.device import get_device, is_hpu_installed
from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed

if is_tpu_installed():
import torch_xla

if TYPE_CHECKING:
from composer.devices import Device
Expand Down Expand Up @@ -534,7 +539,15 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):

dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())

if dist_env_vars_match_defaults:
if device_obj.dist_backend == 'xla':
if not 'torch_xla' in sys.modules:
raise RuntimeError('PyTorch XLA package not found. In order to use XLA based devices '
'PyTorch XLA must be installed.')
if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
# XLA initialization requires the init_method to be set
dist.init_process_group(device_obj.dist_backend, init_method='xla://')
elif dist_env_vars_match_defaults:
# Fill in the remaining single-rank variables
os.environ.update(dist_env_var_defaults)
dist.init_process_group(device_obj.dist_backend, store=dist.HashStore(), world_size=1, rank=0)
Expand Down
4 changes: 2 additions & 2 deletions docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models. They are built on top of the
<!-- BEGIN_COMPOSER_BUILD_MATRIX -->
| Composer Version | CUDA Support | Docker Tag |
|--------------------|----------------|----------------------------------------------------------------|
| 0.19.1 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.19.1` |
| 0.19.1 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.19.1_cpu` |
| 0.20.0 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.20.0` |
| 0.20.0 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.20.0_cpu` |
<!-- END_COMPOSER_BUILD_MATRIX -->

**Note**: For a lightweight installation, we recommended using a [MosaicML PyTorch Image](#pytorch-images) and manually
Expand Down
12 changes: 6 additions & 6 deletions docker/build_matrix.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@
TORCHVISION_VERSION: 0.18.0
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.19.1
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.20.0
CUDA_VERSION: 12.1.0
IMAGE_NAME: composer-0-19-1
IMAGE_NAME: composer-0-20-0
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
Expand All @@ -269,23 +269,23 @@
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.1.2
TAGS:
- mosaicml/composer:0.19.1
- mosaicml/composer:0.20.0
- mosaicml/composer:latest
TARGET: composer_stage
TORCHVISION_VERSION: 0.16.2
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: ubuntu:20.04
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.19.1
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.20.0
CUDA_VERSION: ''
IMAGE_NAME: composer-0-19-1-cpu
IMAGE_NAME: composer-0-20-0-cpu
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.1.2
TAGS:
- mosaicml/composer:0.19.1_cpu
- mosaicml/composer:0.20.0_cpu
- mosaicml/composer:latest_cpu
TARGET: composer_stage
TORCHVISION_VERSION: 0.16.2
2 changes: 1 addition & 1 deletion docker/generate_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def _main():
composer_entries = []

# The `GIT_COMMIT` is a placeholder and Jenkins will substitute it with the actual git commit for the `composer_staging` images
composer_versions = ['0.19.1'] # Only build images for the latest composer version
composer_versions = ['0.20.0'] # Only build images for the latest composer version
composer_python_versions = [PRODUCTION_PYTHON_VERSION] # just build composer against the latest

for product in itertools.product(composer_python_versions, composer_versions, cuda_options):
Expand Down

0 comments on commit d66f7c2

Please sign in to comment.