enable data loading for data parallel training #49

Merged 2 commits on Feb 8, 2024
torchtrain/datasets/alpaca.py (28 changes: 17 additions & 11 deletions)
@@ -17,6 +17,8 @@ class AlpacaDataset(IterableDataset):
     Args:
         tokenizer (Tokenizer): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
         seq_len (int): max sequence length
+        world_size (int): number of data parallel processes participating in training
+        rank (int): rank of the current data parallel process

     Data input format:
     {
@@ -34,11 +36,20 @@ class AlpacaDataset(IterableDataset):
     Batch size: 8
     """

-    def __init__(self, tokenizer: TokenizerIf, seq_len: int = 2048, **kwargs) -> None:
+    def __init__(
+        self,
+        tokenizer: TokenizerIf,
+        seq_len: int = 2048,
+        world_size: int = 1,
+        rank: int = 0,
+        **kwargs
+    ) -> None:
         self._data = load_dataset("tatsu-lab/alpaca", split="train")
         self._tokenizer = tokenizer
         self.data_iterator = iter(self._data)
         self.seq_len = seq_len
+        self.world_size = world_size
+        self.rank = rank
         self.response_tag = "\n\n### Response:\n"

     def __len__(self):
@@ -48,7 +59,10 @@ def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []

-        for sample in self.data_iterator:
+        for idx, sample in enumerate(self.data_iterator):
+            # select samples to pack in a round-robin fashion
+            if idx % self.world_size != self.rank:
+                continue
             sample_text = sample["text"]
             sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
             all_tokens.extend(sample_tokens)
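
(Illustrative sketch, not part of the diff.) The modulo check above implements round-robin sharding: rank r keeps samples r, r + world_size, r + 2*world_size, and so on, so the data parallel ranks consume disjoint, interleaved subsets. A self-contained example with hypothetical sample values:

# Round-robin sharding as in __iter__ above; sample strings are hypothetical.
samples = ["s0", "s1", "s2", "s3", "s4", "s5"]
world_size = 2
for rank in range(world_size):
    shard = [s for idx, s in enumerate(samples) if idx % world_size == rank]
    print(rank, shard)
# prints:
# 0 ['s0', 's2', 's4']
# 1 ['s1', 's3', 's5']
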
@@ -66,14 +80,6 @@
 def build_alpaca_data_loader(
     tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
 ):
-    alpaca_ds = AlpacaDataset(tokenizer=tokenizer, seq_len=seq_len)
-    # TOOD: sampler can't work with iterable dataset, figure out a way
-    # to sample in a distributed manner
-    # dist_sampler = DistributedSampler(
-    #     alpaca_ds,
-    #     world_size,
-    #     rank,
-    #     shuffle=True,
-    # )
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)

     return DataLoader(alpaca_ds, batch_size=batch_size)
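
(Illustrative usage sketch, not part of the diff.) Assuming pure data parallelism, the global rank and world size reported by torch.distributed are also the data parallel rank and world size, so a trainer can wire the loader up as below. The helper name make_dp_data_loader and the tokenizer argument are hypothetical; torch.distributed must already be initialized (e.g. by torchrun), and the defaults mirror the seq_len of 2048 and batch size of 8 mentioned above.

import torch.distributed as dist

from torchtrain.datasets.alpaca import build_alpaca_data_loader


def make_dp_data_loader(tokenizer, batch_size: int = 8, seq_len: int = 2048):
    # Each data parallel rank builds its own loader; AlpacaDataset then
    # selects that rank's shard via the round-robin check in __iter__.
    return build_alpaca_data_loader(
        tokenizer=tokenizer,
        batch_size=batch_size,
        seq_len=seq_len,
        world_size=dist.get_world_size(),
        rank=dist.get_rank(),
    )
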
torchtrain/parallelisms/__init__.py (1 change: 1 addition & 0 deletions)
@@ -46,6 +46,7 @@ def build_mesh(self, device_type):
             if d > 1:
                 dims.append(d)
                 names.append(name)
+        names = tuple(names)
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         return init_device_mesh(device_type, dims, mesh_dim_names=names)

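(Context note with a minimal illustration, not part of the diff.) Building names as a tuple before passing it to init_device_mesh keeps mesh_dim_names consistent with the tuple comparison introduced in parallelize_llama.py below: in Python a list never compares equal to a tuple, even with identical elements.

# Pure-Python illustration of the list-vs-tuple distinction.
names = ["dp"]
assert names != ("dp",)          # list == tuple is always False
assert tuple(names) == ("dp",)   # after conversion the comparison holds
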
torchtrain/parallelisms/parallelize_llama.py (2 changes: 1 addition & 1 deletion)
@@ -156,7 +156,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):

     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-        assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
+        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         fsdp_config = {
             "mixed_precision": MixedPrecision(
                 param_dtype=torch.bfloat16,
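
(Illustrative sketch, not part of the diff.) The updated assert compares against the tuple ("dp",) because slicing a multi-dimensional DeviceMesh by name returns a sub-mesh whose mesh_dim_names is a one-element tuple. The sketch below assumes a 4-process job (e.g. torchrun --nproc_per_node=4) and a hypothetical second mesh dimension named "sp"; only the dp slicing mirrors the code above.

from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 2 (dp) x 2 (sp) mesh; requires an initialized 4-process job.
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "sp"))
dp_mesh = world_mesh["dp"]
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names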