Add Seq Packing in NeMo / Neva2 #11633

Status: Open

Wants to merge 50 commits into base: main from yuya/neva2_seq_packing.
Changes shown below are from 15 of the 50 commits.

Commits (50):
15439bf  api updates and fixes (yaoyu-33, Dec 12, 2024)
6bfd873  Apply isort and black reformatting (yaoyu-33, Dec 12, 2024)
773b4c9  fix (yaoyu-33, Dec 12, 2024)
3a1a017  fix arg (yaoyu-33, Dec 12, 2024)
e3e87b7  update seq packing in mock ds (yaoyu-33, Dec 16, 2024)
4ee633c  Merge branch 'main' into yuya/neva2_seq_packing (yaoyu-33, Dec 16, 2024)
ecc813d  Merge branch 'main' into yuya/neva2_seq_packing (yaoyu-33, Dec 17, 2024)
c10157c  save (yaoyu-33, Dec 17, 2024)
84eb7cc  update preprocess_data (yaoyu-33, Dec 17, 2024)
3bf6442  update seq packing (yaoyu-33, Dec 17, 2024)
c8a26af  Apply isort and black reformatting (yaoyu-33, Dec 17, 2024)
48b5261  fix sp (yaoyu-33, Dec 17, 2024)
365c051  Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy… (yaoyu-33, Dec 17, 2024)
7da82ed  save (yaoyu-33, Dec 18, 2024)
4127c40  Merge branch 'main' into yuya/neva2_seq_packing (yaoyu-33, Dec 18, 2024)
c5d26c3  fix seq packing (yaoyu-33, Dec 18, 2024)
ecd461f  add truncation and padding (yaoyu-33, Dec 19, 2024)
5e0a168  Apply isort and black reformatting (yaoyu-33, Dec 19, 2024)
9240a79  Fix issues (yaoyu-33, Dec 19, 2024)
4808999  change LLaVATemplateConfig variables to class variables (yashaswikarnati, Dec 19, 2024)
c4d92f9  change to use field with default attributes (yashaswikarnati, Dec 19, 2024)
ad44132  Apply isort and black reformatting (yashaswikarnati, Dec 19, 2024)
7db8e52  Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy… (yaoyu-33, Dec 19, 2024)
e705afe  Apply isort and black reformatting (yaoyu-33, Dec 19, 2024)
f0a9cb5  Merge remote-tracking branch 'origin/yash/fix_template_dataclass' int… (yaoyu-33, Dec 19, 2024)
7415036  Add seq packing option in energon (yaoyu-33, Dec 31, 2024)
af1f32a  Fix energon conversation (yaoyu-33, Dec 31, 2024)
568f9aa  add energon option in neva training script (yaoyu-33, Dec 31, 2024)
01fd6cf  Apply isort and black reformatting (yaoyu-33, Dec 31, 2024)
094ef9a  add ci test for packed seq (yaoyu-33, Jan 3, 2025)
626bbc3  fix mock dataset seq packing (yaoyu-33, Jan 7, 2025)
18aa644  Apply isort and black reformatting (yaoyu-33, Jan 7, 2025)
b4f7e8b  fix mock dataset seq packing (yaoyu-33, Jan 7, 2025)
2ccea79  Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy… (yaoyu-33, Jan 7, 2025)
a2a4000  Apply isort and black reformatting (yaoyu-33, Jan 7, 2025)
0599b5a  Merge branch 'main' into yuya/neva2_seq_packing (yaoyu-33, Jan 7, 2025)
90778e1  fix lint and update seq pack func (yaoyu-33, Jan 7, 2025)
38a6c49  Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy… (yaoyu-33, Jan 7, 2025)
f0ec5f1  fix energon module (yaoyu-33, Jan 7, 2025)
ff45f7e  Apply isort and black reformatting (yaoyu-33, Jan 7, 2025)
38e42a2  fix comments (yaoyu-33, Jan 8, 2025)
eadc665  Apply isort and black reformatting (yaoyu-33, Jan 8, 2025)
846252f  address lightning issues (yaoyu-33, Jan 8, 2025)
d70e432  Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy… (yaoyu-33, Jan 8, 2025)
a2290de  Apply isort and black reformatting (yaoyu-33, Jan 8, 2025)
3fdfe3e  Update sequence_packing.py (yaoyu-33, Jan 9, 2025)
d24cd3b  Merge branch 'main' into yuya/neva2_seq_packing (yaoyu-33, Jan 13, 2025)
a68e41f  update energon requirements (yaoyu-33, Jan 13, 2025)
c04f1ed  Fix for energon update (yaoyu-33, Jan 13, 2025)
41ab130  fix for test (yaoyu-33, Jan 14, 2025)
4 changes: 2 additions & 2 deletions nemo/collections/vlm/neva/data/config.py
@@ -31,15 +31,15 @@ class DataConfig:
@dataclass
class ImageDataConfig(DataConfig):
media_type: str = "image"
media_token: MultiModalToken = ImageToken
media_token: MultiModalToken = ImageToken()
image_folder: Optional[str] = None
image_process_mode: str = 'pad'


@dataclass
class VideoDataConfig(DataConfig):
media_type: str = "video"
media_token: MultiModalToken = VideoToken
media_token: MultiModalToken = VideoToken()
splice_single_frame: Optional[str] = None
# 'first', 'middle', 'last' will represent video as first / middle / last frame only, all other frames discarded.
num_frames: int = 8 # Selects the number of frames to use from the video
122 changes: 71 additions & 51 deletions nemo/collections/vlm/neva/data/lazy.py
@@ -251,7 +251,6 @@
data_config,
tokenizer,
image_processor,
sequence_length=None,
):
super().__init__()
if data_path is not None:
@@ -269,8 +268,6 @@
self.tokenizer = self.tokenizer.tokenizer

self.image_processor = image_processor
self.sequence_length = sequence_length

self.conv_template = data_config.conv_template
self.conv = supported_conv_templates[self.conv_template]
self.image_process_mode = data_config.image_process_mode
@@ -381,6 +378,7 @@
data_config,
tokenizer,
image_processor,
packed_sequence=False,
):

if data_path.endswith(".json"):
@@ -414,29 +412,11 @@

else:
raise ValueError(f"Formatting of {data_path} is not supported in Neva.")
self.packed_sequence = packed_sequence

def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
data_config = self.data_config
packed_sequence = "cu_seqlens" in instances[0]
max_len = max(instance['tokens'].shape[0] for instance in instances)
for instance in instances:
pad_len = max_len - instance['tokens'].shape[0]
instance['tokens'] = F.pad(instance['tokens'], (0, pad_len), 'constant', 0)
instance['labels'] = F.pad(instance['labels'], (0, pad_len), 'constant', IGNORE_INDEX)
if packed_sequence and instance["cu_seqlens"][-1] != max_len:
instance["cu_seqlens"] = torch.cat((instance["cu_seqlens"], torch.IntTensor([max_len])), 0)

if packed_sequence:
max_len_cu = max(instance['cu_seqlens'].shape[0] for instance in instances)
max_len_image = max(instance['image'].shape[0] for instance in instances)
for instance in instances:
pad_len_cu = max_len_cu - instance['cu_seqlens'].shape[0]
instance['cu_seqlens'] = F.pad(instance['cu_seqlens'], (0, pad_len_cu), 'constant', max_len)

x = instance['image']
num_pad = max_len_image - x.shape[0]
pad_tensor = torch.zeros(num_pad, *x.shape[1:], dtype=x.dtype, device=x.device)
instance['image'] = torch.cat((x, pad_tensor), dim=0)
packed_sequence = self.packed_sequence

media_type = data_config.media_type
if media_type == 'image':
@@ -447,33 +427,75 @@
else:
raise ValueError(f"Unsupported media type {media_type}")

batch = default_collate(instances)
tokenizer = self.tokenizer
if packed_sequence:
from megatron.core.packed_seq_params import PackedSeqParams

tokens = batch['tokens']
labels = batch['labels']
media_token_id = self.data_config.media_token.token_index

if packed_sequence:
cu_seqlens = batch["cu_seqlens"]
tokens = []
labels = []
position_ids = []
for cu_seqlen in cu_seqlens:
position_ids.append([])
for ind in range(0, len(cu_seqlen) - 1):
seqlen = cu_seqlen[ind + 1] - cu_seqlen[ind]
position_ids[-1].extend(list(range(seqlen)))
position_ids = torch.LongTensor(position_ids)
loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=tokens.device)
attention_mask = torch.ones(tokens.size(), dtype=torch.long, device=tokens.device)
else:
seqlens = []
cu_seqlens = [0]
cu_seqlens_padded = [0]
for instance in instances:
# Assume 1 tile per image
num_image_embeddings_per_tile = 576
num_images = torch.sum(instance['tokens'] == media_token_id)
seqlen = len(instance['tokens']) + (num_image_embeddings_per_tile - 1) * num_images
seqlen_padded = (seqlen - 1) // 8 * 8 + 8
pad_len = seqlen_padded - seqlen
if pad_len > 0:
instance['tokens'] = F.pad(instance['tokens'], (0, pad_len), 'constant', 0)
instance['labels'] = F.pad(instance['labels'], (0, pad_len), 'constant', IGNORE_INDEX)
tokens.append(instance['tokens'])
labels.append(instance['labels'])
position_ids.append(
torch.arange(len(instance['tokens']), dtype=torch.int, device=instance['tokens'].device)
)
seqlens.append(seqlen)
cu_seqlens.append(cu_seqlens[-1] + seqlen)
cu_seqlens_padded.append(cu_seqlens_padded[-1] + seqlen_padded)
tokens = torch.cat(tokens, dim=0).unsqueeze(0)
labels = torch.cat(labels, dim=0).unsqueeze(0)
position_ids = torch.cat(position_ids, dim=0).unsqueeze(0)
loss_mask = torch.ones_like(labels, dtype=torch.float, device=labels.device)
loss_mask[labels < 0] = 0.0
attention_mask = None

cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_padded = torch.IntTensor(cu_seqlens_padded)
cu_seqlens_padded = None
qkv_format = 'thd'
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
max_seqlen_q=max(seqlens),
max_seqlen_kv=max(seqlens),
qkv_format=qkv_format,
)
else: # regular dataset
max_len = max(instance['tokens'].shape[0] for instance in instances)
for instance in instances:
pad_len = max_len - instance['tokens'].shape[0]
instance['tokens'] = F.pad(instance['tokens'], (0, pad_len), 'constant', 0)
instance['labels'] = F.pad(instance['labels'], (0, pad_len), 'constant', IGNORE_INDEX)

batch = default_collate(instances)
tokenizer = self.tokenizer

tokens = batch['tokens']
labels = batch['labels']
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=tokenizer.eos_token_id,
eod_mask_loss=data_config.eod_mask_loss,
reset_attention_mask=data_config.reset_attention_mask,
reset_position_ids=data_config.reset_position_ids,
)

loss_mask[labels < 0] = 0.0
loss_mask[labels < 0] = 0.0

batch = {
'tokens': tokens,
@@ -484,7 +506,7 @@
'media': media,
}
if packed_sequence:
batch["cu_seqlens"] = cu_seqlens
batch["packed_seq_params"] = packed_seq_params
return batch
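Below is a minimal, standalone sketch (not part of the PR) of the sequence-length bookkeeping done in the packed branch of collate_fn above. The 576 embeddings per image tile and the round-up-to-a-multiple-of-8 padding mirror the constants in the diff; the sample tensors and the -200 media-token id are invented for illustration.

# Illustrative only: how cu_seqlens / cu_seqlens_padded grow for two samples
# whose token streams contain one media token each.
import torch

num_image_embeddings_per_tile = 576  # assumes 1 tile per image, as in the diff
media_token_id = -200                # hypothetical media token id
samples = [
    torch.tensor([101, 7592, media_token_id, 102]),        # 4 tokens, 1 image
    torch.tensor([101, media_token_id, 2023, 2003, 102]),  # 5 tokens, 1 image
]

seqlens, cu_seqlens, cu_seqlens_padded = [], [0], [0]
for tokens in samples:
    num_images = int(torch.sum(tokens == media_token_id))
    seqlen = len(tokens) + (num_image_embeddings_per_tile - 1) * num_images
    seqlen_padded = (seqlen - 1) // 8 * 8 + 8  # round up to a multiple of 8
    seqlens.append(seqlen)
    cu_seqlens.append(cu_seqlens[-1] + seqlen)
    cu_seqlens_padded.append(cu_seqlens_padded[-1] + seqlen_padded)

print(cu_seqlens)         # [0, 579, 1159]
print(cu_seqlens_padded)  # [0, 584, 1168]
print(max(seqlens))       # 580, used as max_seqlen_q / max_seqlen_kv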


@@ -506,7 +528,7 @@
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
use_packed_sequence: bool = False,
packed_sequence: bool = False,
seed: int = 1234,
) -> None:
super().__init__()
@@ -534,7 +556,7 @@
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.seed = seed
self.use_packed_sequence = use_packed_sequence
self.packed_sequence = packed_sequence
self.init_global_step = 0

if tokenizer is None or image_processor is None:
@@ -556,14 +578,12 @@

def setup(self, stage: str = "") -> None:
assert len(self.paths) == 1, "not yet support blend dataset in Neva 2.0!"
if self.use_packed_sequence:
pass # TODO
else:
# TODO:
# rng = torch.Generator().manual_seed(self.seed)
# train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=rng)
self._train_ds = NevaDataset(self.paths[0], self.data_config, self.tokenizer, self.image_processor)
self._validation_ds = NevaDataset(self.paths[0], self.data_config, self.tokenizer, self.image_processor)
self._train_ds = NevaDataset(
self.paths[0], self.data_config, self.tokenizer, self.image_processor, packed_sequence=self.packed_sequence
)
self._validation_ds = NevaDataset(
self.paths[0], self.data_config, self.tokenizer, self.image_processor, packed_sequence=self.packed_sequence
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
return self._create_dataloader(self._train_ds)
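As a usage note, enabling packing from a training script would look roughly like this. It is a sketch: only the packed_sequence flag is new in this PR; the class name, import paths, argument names, and dataset path are assumptions based on the files touched above.

# Sketch: turn on sequence packing in the lazy Neva data module.
# Import paths and argument names besides `packed_sequence` are assumptions.
from nemo.collections.vlm.neva.data.config import ImageDataConfig
from nemo.collections.vlm.neva.data.lazy import NevaLazyDataModule

data_module = NevaLazyDataModule(
    paths=["/data/llava_instruct_150k.json"],          # hypothetical dataset path
    data_config=ImageDataConfig(image_folder="/data/images"),
    tokenizer=None,          # the constructor appears to fall back to llava-hf/llava-1.5-7b-hf
    image_processor=None,
    packed_sequence=True,    # new flag added by this PR
)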
50 changes: 47 additions & 3 deletions nemo/collections/vlm/neva/data/mock.py
@@ -42,6 +42,7 @@ def __init__(
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
packed_sequence: bool = False,
):
super().__init__()
self.seq_length = seq_length
@@ -54,6 +55,7 @@ def __init__(
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.packed_sequence = packed_sequence

if tokenizer is None or image_processor is None:
logging.warning(f"Processor or tokenizer are not provided! Fall back to `llava-hf/llava-1.5-7b-hf`.")
@@ -73,13 +75,28 @@ def __init__(

def setup(self, stage: str = "") -> None:
self._train_ds = _MockNevaDataset(
self.tokenizer, self.image_processor, "train", self.num_train_samples, self.seq_length
self.tokenizer,
self.image_processor,
"train",
self.num_train_samples,
self.seq_length,
packed_sequence=self.packed_sequence,
)
self._validation_ds = _MockNevaDataset(
self.tokenizer, self.image_processor, "valid", self.num_val_samples, self.seq_length
self.tokenizer,
self.image_processor,
"valid",
self.num_val_samples,
self.seq_length,
packed_sequence=self.packed_sequence,
)
self._test_ds = _MockNevaDataset(
self.tokenizer, self.image_processor, "test", self.num_test_samples, self.seq_length
self.tokenizer,
self.image_processor,
"test",
self.num_test_samples,
self.seq_length,
packed_sequence=self.packed_sequence,
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
@@ -117,6 +134,7 @@ def __init__(
num_samples: int,
seq_length: int,
seed: int = 42,
packed_sequence: bool = False,
) -> None:
super().__init__()
self.name = name
@@ -129,6 +147,7 @@ def __init__(

self.length = num_samples
self.seed = seed
self.packed_sequence = packed_sequence

self.loss_mask = torch.ones(self.seq_length, dtype=torch.float)
self.position_ids = torch.arange(self.seq_length, dtype=torch.int64)
@@ -164,6 +183,31 @@ def _collate_fn(self, batch):
"""
collated_batch = data.dataloader.default_collate(batch)
collated_batch["attention_mask"] = None
if self.packed_sequence:
from megatron.core.packed_seq_params import PackedSeqParams

tokens = collated_batch["tokens"]
batch_size = tokens.shape[0]
valid_seqlen = self.seq_length
cu_seqlens = torch.arange(
0, (batch_size + 1) * (valid_seqlen), step=(valid_seqlen), dtype=torch.int32, device=tokens.device
)
cu_seqlens_padded = None
qkv_format = 'thd'
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
max_seqlen_q=valid_seqlen,
max_seqlen_kv=valid_seqlen,
qkv_format=qkv_format,
)
collated_batch["packed_seq_params"] = packed_seq_params

for key in ["tokens", "labels", "loss_mask", "position_ids"]:
collated_batch[key] = collated_batch[key].reshape(1, -1)

return collated_batch

def collate_fn(self, batch):
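For the mock data path, the packed parameters are trivial because every mock sample has the same length. A small sketch (toy batch_size and seq_length, not the module defaults) of what the packed branch of _collate_fn above produces:

# Sketch: cu_seqlens for a mock packed batch where all sequences share seq_length.
import torch

batch_size, seq_length = 2, 4  # toy values for illustration
tokens = torch.zeros(batch_size, seq_length, dtype=torch.long)

cu_seqlens = torch.arange(0, (batch_size + 1) * seq_length, step=seq_length, dtype=torch.int32)
print(cu_seqlens)                      # tensor([0, 4, 8], dtype=torch.int32)

packed_tokens = tokens.reshape(1, -1)  # matches the reshape(1, -1) in the collate above
print(packed_tokens.shape)             # torch.Size([1, 8])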
22 changes: 16 additions & 6 deletions nemo/collections/vlm/neva/model/base.py
@@ -121,14 +121,19 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
)
)

packed_seq_params = _batch.get("packed_seq_params", None)
_batch = {
key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None
for key, val in _batch.items()
}
# slice batch along sequence dimension for context parallelism
output = get_batch_on_this_context_parallel_rank(_batch)
if packed_seq_params is not None:
for attr in ["cu_seqlens_q", "cu_seqlens_kv", "cu_seqlens_q_padded", "cu_seqlens_kv_padded"]:
value = getattr(packed_seq_params, attr, None)
if value is not None:
setattr(packed_seq_params, attr, value.cuda(non_blocking=True))
_batch["packed_seq_params"] = packed_seq_params

return output
return _batch


def neva_forward_step(model, batch) -> torch.Tensor:
@@ -596,6 +601,7 @@ def forward(
image_token_index,
num_image_tiles,
attention_mask,
packed_seq_params,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]

output = self.language_model(
@@ -642,6 +648,7 @@ def _preprocess_data(
image_token_index,
num_image_tiles,
attention_mask,
packed_seq_params,
):
"""Preprocess input data before input to language model.

@@ -698,6 +705,8 @@ def _preprocess_data(
labels.shape == loss_mask.shape
), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}"

packed_sequence = packed_seq_params is not None and packed_seq_params.qkv_format == "thd"

# Create indices for new text and label positions.
with torch.no_grad():
image_token_mask = input_ids == image_token_index
@@ -826,16 +835,16 @@
), "unexpected shapes after data preprocessing"

truncate_labels = has_labels and final_labels.shape[1] > self._language_max_sequence_length
if truncate_labels:
if truncate_labels and not packed_sequence:
final_labels = final_labels[:, : self._language_max_sequence_length]
final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length]

if final_embedding is not None:
final_embedding = final_embedding.transpose(1, 0).contiguous()
# Truncate if exceeding the language model's max sequence length.
if final_embedding.shape[0] > self._language_max_sequence_length:
if final_embedding.shape[0] > self._language_max_sequence_length and not packed_sequence:
final_embedding = final_embedding[: self._language_max_sequence_length]
if self.sequence_parallel_lm:
if self.sequence_parallel_lm and not packed_sequence:
# Create an attention mask. This ensures correct computation.
# This is done even when no padding was done as we set mask_type to
# 'padding' or 'padding_causal' when using SP.
@@ -858,6 +867,7 @@

# Attention mask True/False meaning flipped in 1.7.0
attention_mask = attention_mask < 0.5
if self.sequence_parallel_lm:
final_embedding = tensor_parallel.scatter_to_sequence_parallel_region(final_embedding)

return final_embedding, final_labels, final_loss_mask, attention_mask
5 changes: 4 additions & 1 deletion nemo/lightning/megatron_parallel.py
@@ -1711,7 +1711,10 @@ def masked_token_loss(tensor: Tensor, mask: Tensor):
"""
losses = tensor.float()
loss_mask = mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll
num_valid_tokens = loss_mask.sum()
if num_valid_tokens < 0.5: # no valid tokens
Review thread on this line:

Reviewer: Can you explain a bit more when this is the case? Is this only valid for Neva?

Author (yaoyu-33): Not only for Neva, also for SFT. If the system and user prompt are very long and we predict the answer only, then after truncating from the right there might not be any answer/valid tokens left.

Reviewer: Usually we truncate the input/context and keep the answer intact, so that wouldn't happen.

Author (yaoyu-33): Yep, the logic is a bit different for VLM; we don't want to truncate from the left.
num_valid_tokens += 1.0
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens # sequence level nll

return loss
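The case discussed in the review thread above (a micro-batch whose loss mask is entirely zero, e.g. when the answer tokens were truncated away) can be reproduced with a tiny example. This re-implements only the guarded reduction, not the surrounding NeMo call site:

# Sketch: with an all-zero loss mask the unguarded reduction divides by zero and
# yields NaN; the +1.0 guard keeps the loss finite (zero, since no token contributes).
import torch

losses = torch.tensor([[2.3, 1.7, 0.9]])  # per-token NLL, illustrative values
loss_mask = torch.zeros_like(losses)      # every token masked out

num_valid_tokens = loss_mask.view(-1).float().sum()
if num_valid_tokens < 0.5:                # no valid tokens
    num_valid_tokens += 1.0
loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / num_valid_tokens
print(loss)                               # tensor(0.) instead of nan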
