Skip to content

Commit

Permalink
Update on "Add Sequence Parallelism to llama"
Browse files Browse the repository at this point in the history
For some reason torch.compile does not work, although eager-mode sequence
parallelism does, so for now it is not turned on by default

[ghstack-poisoned]
  • Loading branch information
wanchaol committed Feb 7, 2024
2 parents 09e7447 + 549e197 commit 52c9091
Show file tree
Hide file tree
Showing 14 changed files with 351 additions and 168 deletions.
27 changes: 27 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[flake8]
# Suggested config from pytorch that we can adapt
select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
# N812 ignored because import torch.nn.functional as F is PyTorch convention
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
# E731 allow usage of assigning lambda expressions
ignore =
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,
optional-ascii-coding = True
exclude =
./.git,
./docs
./build
./scripts,
./venv,
*.pyi
.pre-commit-config.yaml
*.md
.flake8
39 changes: 39 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Lint

on:
pull_request:

concurrency:
group: lint-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
cancel-in-progress: true

defaults:
run:
shell: bash -l -eo pipefail {0}

jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11']
steps:
- name: Check out repo
uses: actions/checkout@v3
- name: Setup python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install lint utilities
run: |
python -m pip install pre-commit
pre-commit install-hooks
- id: file_changes
uses: trilom/file-changes-action@v1.2.4
with:
prNumber: ${{ github.event.number }}
output: ' '
- name: Lint modified files
run: pre-commit run --files ${{ steps.file_changes.outputs.files }}
51 changes: 51 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
exclude: 'build'

default_language_version:
python: python3

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 6306a48f7dae5861702d573c9c247e4e9498e867
hooks:
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- id: no-commit-to-branch
args: ['--branch=main']
- id: check-added-large-files
args: ['--maxkb=500']
- id: end-of-file-fixer
exclude: '^(.*\.svg)$'

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.4
hooks:
- id: insert-license
files: \.py$
args:
- --license-filepath
- docs/license_header.txt

- repo: https://github.com/pycqa/flake8
rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
hooks:
- id: flake8
additional_dependencies:
- flake8-bugbear == 22.4.25
- pep8-naming == 0.12.1
- torchfix
args: ['--config=.flake8']

- repo: https://github.com/omnilib/ufmt
rev: v2.3.0
hooks:
- id: ufmt
additional_dependencies:
- black == 22.12.0
- usort == 1.0.5

- repo: https://github.com/jsh9/pydoclint
rev: d88180a8632bb1602a4d81344085cf320f288c5a
hooks:
- id: pydoclint
args: [--config=pyproject.toml]
3 changes: 3 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pytest
pytest-cov
pre-commit
2 changes: 2 additions & 0 deletions docs/license_header.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Copyright (c) Meta Platforms, Inc. and affiliates.
This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[tool.pydoclint]
style = 'google'
check-return-types = 'False'

[tool.pytest.ini_options]
addopts = ["--showlocals"] # show local variables in tracebacks
6 changes: 5 additions & 1 deletion torchtrain/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from torchtrain.datasets.alpaca import build_alpaca_data_loader
from torchtrain.datasets.tokenizer import create_tokenizer
from torchtrain.datasets.pad_batch_sequence import pad_batch_to_longest_seq
from torchtrain.datasets.tokenizer import create_tokenizer

__all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]

dataloader_fn = {
"alpaca": build_alpaca_data_loader,
Expand Down
27 changes: 8 additions & 19 deletions torchtrain/datasets/alpaca.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import List, Tuple
from typing import List

import torch

from datasets import load_dataset
from torch.utils.data import IterableDataset, DataLoader, DistributedSampler
from torch.utils.data import DataLoader, IterableDataset

from torchtrain.datasets.tokenizer import TokenizerIf

from datasets import load_dataset


class AlpacaDataset(IterableDataset):
"""PyTorch Representation of the Alpaca Dataset from Hugging Face.
Expand All @@ -37,11 +34,7 @@ class AlpacaDataset(IterableDataset):
Batch size: 8
"""

def __init__(self,
tokenizer: TokenizerIf,
seq_len: int = 2048,
**kwargs
) -> None:
def __init__(self, tokenizer: TokenizerIf, seq_len: int = 2048, **kwargs) -> None:
self._data = load_dataset("tatsu-lab/alpaca", split="train")
self._tokenizer = tokenizer
self.data_iterator = iter(self._data)
Expand All @@ -52,7 +45,7 @@ def __len__(self):
return len(self._data)

def __iter__(self):
max_buffer_token_len = (1 + self.seq_len)
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

for sample in self.data_iterator:
Expand All @@ -71,11 +64,7 @@ def __iter__(self):


def build_alpaca_data_loader(
tokenizer: TokenizerIf,
batch_size: int,
seq_len: int,
world_size,
rank
tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
):
alpaca_ds = AlpacaDataset(tokenizer=tokenizer, seq_len=seq_len)
# TODO: sampler can't work with iterable dataset, figure out a way
Expand Down
35 changes: 27 additions & 8 deletions torchtrain/datasets/download_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Optional

Expand All @@ -11,20 +12,38 @@

def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
from huggingface_hub import hf_hub_download

os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
try:
hf_hub_download(repo_id, "tokenizer.model", local_dir=f"torchtrain/datasets/tokenizer/", local_dir_use_symlinks=False, token=hf_token)
hf_hub_download(
repo_id,
"tokenizer.model",
local_dir="torchtrain/datasets/tokenizer/",
local_dir_use_symlinks=False,
token=hf_token,
)
except HTTPError as e:
if e.response.status_code == 401:
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
print(
"You need to pass a valid `--hf_token=...` to download private checkpoints."
)
else:
raise e

if __name__ == '__main__':

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Download tokenizer from HuggingFace.')
parser.add_argument('--repo_id', type=str, default="meta-llama/llama-2-70b", help='Repository ID to download from.')
parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')

parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
parser.add_argument(
"--repo_id",
type=str,
default="meta-llama/llama-2-70b",
help="Repository ID to download from.",
)
parser.add_argument(
"--hf_token", type=str, default=None, help="HuggingFace API token."
)

args = parser.parse_args()
hf_download(args.repo_id, args.hf_token)
14 changes: 13 additions & 1 deletion torchtrain/models/llama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from torchtrain.models.llama.model import ModelArgs, Transformer

__all__ = ["Transformer"]

llama_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
"7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
"13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
"70B": ModelArgs(dim=8192, n_layers=80, n_heads=64, n_kv_heads=8, ffn_dim_multiplier=1.3, multiple_of=4096),
"70B": ModelArgs(
dim=8192,
n_layers=80,
n_heads=64,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=4096,
),
}
Loading

0 comments on commit 52c9091

Please sign in to comment.