-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Seq Packing in NeMo / Neva2 #11633
Changes from 40 commits
15439bf
6bfd873
773b4c9
3a1a017
e3e87b7
4ee633c
ecc813d
c10157c
84eb7cc
3bf6442
c8a26af
48b5261
365c051
7da82ed
4127c40
c5d26c3
ecd461f
5e0a168
9240a79
4808999
c4d92f9
ad44132
7db8e52
e705afe
f0a9cb5
7415036
af1f32a
568f9aa
01fd6cf
094ef9a
626bbc3
18aa644
b4f7e8b
2ccea79
a2a4000
0599b5a
90778e1
38a6c49
f0ec5f1
ff45f7e
38e42a2
eadc665
846252f
d70e432
a2290de
3fdfe3e
d24cd3b
a68e41f
c04f1ed
41ab130
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,14 +25,21 @@ | |
batch_list, | ||
batch_pad_stack, | ||
) | ||
from megatron.energon.task_encoder.base import stateless | ||
|
||
from nemo.collections.multimodal.data.energon.config import ImageTextRawBatch, ImageTextSample | ||
from nemo.collections.multimodal.data.energon.config import ( | ||
ImageTextRawBatch, | ||
ImageTextSample, | ||
PackedImageTextRawBatch, | ||
PackedImageTextSample, | ||
) | ||
from nemo.collections.multimodal.data.energon.sample_encoder import ( | ||
InterleavedSampleEncoder, | ||
SampleEncoder, | ||
SimilarityInterleavedEncoder, | ||
VQASampleEncoder, | ||
) | ||
from nemo.utils import logging | ||
|
||
|
||
class MultiModalTaskEncoder( | ||
|
@@ -54,7 +61,15 @@ class MultiModalTaskEncoder( | |
for model input. | ||
""" | ||
|
||
def __init__(self, tokenizer, image_processor, multimodal_sample_config): | ||
def __init__( | ||
self, | ||
tokenizer, | ||
image_processor, | ||
multimodal_sample_config, | ||
packed_sequence=False, | ||
packing_seq_length=4096, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you use consistent naming as fine_tuning.py? i.e. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i kept There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed name from |
||
num_image_embeddings_per_tile=576, | ||
): | ||
""" | ||
Initialize the MultiModalTaskEncoder with specific encoders for different sample types. | ||
|
||
|
@@ -64,6 +79,10 @@ def __init__(self, tokenizer, image_processor, multimodal_sample_config): | |
multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object. | ||
""" | ||
self.tokenizer = tokenizer | ||
self.sample_config = multimodal_sample_config | ||
self.packed_sequence = packed_sequence | ||
self.num_image_embeddings_per_tile = num_image_embeddings_per_tile # only used with seq packing | ||
self.packing_seq_length = packing_seq_length | ||
self.encoders: Dict[str, SampleEncoder] = { | ||
VQASample.__name__: VQASampleEncoder( | ||
tokenizer=tokenizer, | ||
|
@@ -92,6 +111,7 @@ def register_encoder(self, sample_type: str, encoder: SampleEncoder) -> None: | |
""" | ||
self.encoders[sample_type] = encoder | ||
|
||
@stateless(restore_seeds=True) | ||
def encode_sample( | ||
self, sample: Union[VQASample, InterleavedSample, SimilarityInterleavedSample, CaptioningSample] | ||
) -> ImageTextSample: | ||
|
@@ -118,7 +138,7 @@ def encode_sample( | |
encoded_sample = encoder.encode(input_sample=sample, output_sample=ImageTextSample()) | ||
return encoded_sample | ||
|
||
def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch: | ||
def batch(self, samples): | ||
yaoyu-33 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Batch a list of encoded samples into a single raw batch. | ||
|
||
|
@@ -131,26 +151,40 @@ def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch: | |
ImageTextRawBatch: The batched data, including images, tokens, labels, and loss masks. | ||
""" | ||
|
||
keys, images, tokens, labels, loss_mask = [], [], [], [], [] | ||
for sample in samples: | ||
keys.append(sample.__key__) | ||
images.append(sample.images) | ||
tokens.append(sample.tokens) | ||
labels.append(sample.labels) | ||
loss_mask.append(sample.loss_mask) | ||
|
||
batch_keys = batch_list(keys) | ||
batch_images = batch_pad_stack(images) | ||
batch_prompt_tokens = batch_pad_stack(tokens) | ||
batch_labels = batch_pad_stack(labels) | ||
batch_loss_mask = batch_pad_stack(loss_mask) | ||
return ImageTextRawBatch( | ||
__keys__=batch_keys, | ||
images=batch_images, | ||
tokens=batch_prompt_tokens, | ||
labels=batch_labels, | ||
loss_mask=batch_loss_mask, | ||
) | ||
if self.packed_sequence: | ||
assert len(samples) == 1, "Must set MBS=1 when using `packed_sequence`." | ||
yaoyu-33 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# The batching are taken care by packing. | ||
sample = samples[0] | ||
return PackedImageTextRawBatch( | ||
__keys__=sample.__key__, | ||
images=sample.images, | ||
tokens=sample.tokens, | ||
labels=sample.labels, | ||
loss_mask=sample.loss_mask, | ||
position_ids=sample.position_ids, | ||
packed_seq_params=sample.packed_seq_params, | ||
) | ||
else: | ||
keys, images, tokens, labels, loss_mask = [], [], [], [], [] | ||
for sample in samples: | ||
keys.append(sample.__key__) | ||
images.append(sample.images) | ||
tokens.append(sample.tokens) | ||
labels.append(sample.labels) | ||
loss_mask.append(sample.loss_mask) | ||
|
||
batch_keys = batch_list(keys) | ||
batch_images = batch_pad_stack(images) | ||
batch_prompt_tokens = batch_pad_stack(tokens) | ||
batch_labels = batch_pad_stack(labels) | ||
batch_loss_mask = batch_pad_stack(loss_mask) | ||
return ImageTextRawBatch( | ||
__keys__=batch_keys, | ||
images=batch_images, | ||
tokens=batch_prompt_tokens, | ||
labels=batch_labels, | ||
loss_mask=batch_loss_mask, | ||
) | ||
|
||
def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: | ||
""" | ||
|
@@ -165,7 +199,7 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: | |
Returns: | ||
dict: A dictionary containing the encoded batch data, ready for model input. | ||
""" | ||
batch_dict = dataclasses.asdict(batch_data) | ||
batch_dict = batch_data.__dict__ | ||
if 'images' in batch_dict: | ||
batch_dict['media'] = batch_dict['images'] | ||
del batch_dict['images'] | ||
|
@@ -177,3 +211,66 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: | |
if 'attention_mask' not in batch_dict: | ||
batch_dict['attention_mask'] = None | ||
return batch_dict | ||
|
||
def select_samples_to_pack(self, samples): | ||
"""Selects which samples will be packed together. | ||
|
||
NOTE: Energon dataloader calls this method internally if packing is used. | ||
Please see https://nvidia.github.io/Megatron-Energon/packing.html | ||
""" | ||
from nemo.collections.vlm.neva.data.sequence_packing import greedy_knapsack, predict_seq_len | ||
|
||
media_token_id = self.sample_config.image_token.token_id | ||
lengths = [ | ||
predict_seq_len( | ||
sample.tokens, | ||
media_token_index=media_token_id, | ||
num_image_embeddings_per_tile=self.num_image_embeddings_per_tile, | ||
) | ||
for sample in samples | ||
] | ||
packed_samples = greedy_knapsack(lengths, samples, self.packing_seq_length) | ||
avg_samples_per_bin = round(len(lengths) / len(packed_samples)) | ||
logging.info( | ||
f"[Seq Packing Info] - Packing seq len: {self.packing_seq_length}, " | ||
f"Buffered samples: {len(lengths)}, Total number of bins: {len(packed_samples)}, " | ||
f"Average samples per bin: {avg_samples_per_bin}" | ||
) | ||
return packed_samples | ||
|
||
@stateless | ||
def pack_selected_samples(self, samples): | ||
""" | ||
Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked. | ||
|
||
NOTE: Energon dataloader calls this method internally if packing is used. | ||
Please see https://nvidia.github.io/Megatron-Energon/packing.html | ||
|
||
Args: | ||
samples: List of ImageTaskSample instances to pack into one sample. | ||
|
||
Returns: | ||
ImageTaskSamplePacked instance. | ||
""" | ||
from nemo.collections.vlm.neva.data.sequence_packing import convert_to_packed | ||
|
||
packed_images = torch.stack([sample.images for sample in samples]) | ||
media_token_id = self.sample_config.image_token.token_id | ||
packed_tokens, packed_labels, packed_position_ids, packed_loss_mask, packed_seq_params = convert_to_packed( | ||
tokens=[sample.tokens for sample in samples], | ||
labels=[sample.labels for sample in samples], | ||
num_image_embeddings_per_tile=self.num_image_embeddings_per_tile, | ||
media_token_index=media_token_id, | ||
ignore_index=self.sample_config.ignore_place_holder, | ||
) | ||
|
||
return PackedImageTextSample( | ||
__key__=",".join([s.__key__ for s in samples]), | ||
__restore_key__=(), # Will be set by energon based on `samples` | ||
tokens=packed_tokens, | ||
labels=packed_labels, | ||
images=packed_images, | ||
position_ids=packed_position_ids, | ||
loss_mask=packed_loss_mask, | ||
packed_seq_params=packed_seq_params, | ||
) |
Check notice
Code scanning / CodeQL
Unnecessary lambda Note
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you resolve this?