Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Seq Packing in NeMo / Neva2 #11633

Merged
merged 50 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
15439bf
api updates and fixes
yaoyu-33 Dec 12, 2024
6bfd873
Apply isort and black reformatting
yaoyu-33 Dec 12, 2024
773b4c9
fix
yaoyu-33 Dec 12, 2024
3a1a017
fix arg
yaoyu-33 Dec 12, 2024
e3e87b7
update seq packing in mock ds
yaoyu-33 Dec 16, 2024
4ee633c
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Dec 16, 2024
ecc813d
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Dec 17, 2024
c10157c
save
yaoyu-33 Dec 17, 2024
84eb7cc
update preprocess_data
yaoyu-33 Dec 17, 2024
3bf6442
update seq packing
yaoyu-33 Dec 17, 2024
c8a26af
Apply isort and black reformatting
yaoyu-33 Dec 17, 2024
48b5261
fix sp
yaoyu-33 Dec 17, 2024
365c051
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Dec 17, 2024
7da82ed
save
yaoyu-33 Dec 18, 2024
4127c40
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Dec 18, 2024
c5d26c3
fix seq packing
yaoyu-33 Dec 18, 2024
ecd461f
add truncation and padding
yaoyu-33 Dec 19, 2024
5e0a168
Apply isort and black reformatting
yaoyu-33 Dec 19, 2024
9240a79
Fix issues
yaoyu-33 Dec 19, 2024
4808999
change LLaVATemplateConfig variables to class variables
yashaswikarnati Dec 19, 2024
c4d92f9
change to use field with default attributes
yashaswikarnati Dec 19, 2024
ad44132
Apply isort and black reformatting
yashaswikarnati Dec 19, 2024
7db8e52
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Dec 19, 2024
e705afe
Apply isort and black reformatting
yaoyu-33 Dec 19, 2024
f0a9cb5
Merge remote-tracking branch 'origin/yash/fix_template_dataclass' int…
yaoyu-33 Dec 19, 2024
7415036
Add seq packing option in energon
yaoyu-33 Dec 31, 2024
af1f32a
Fix energon conversation
yaoyu-33 Dec 31, 2024
568f9aa
add energon option in neva training script
yaoyu-33 Dec 31, 2024
01fd6cf
Apply isort and black reformatting
yaoyu-33 Dec 31, 2024
094ef9a
add ci test for packed seq
yaoyu-33 Jan 3, 2025
626bbc3
fix mock dataset seq packing
yaoyu-33 Jan 7, 2025
18aa644
Apply isort and black reformatting
yaoyu-33 Jan 7, 2025
b4f7e8b
fix mock dataset seq packing
yaoyu-33 Jan 7, 2025
2ccea79
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Jan 7, 2025
a2a4000
Apply isort and black reformatting
yaoyu-33 Jan 7, 2025
0599b5a
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Jan 7, 2025
90778e1
fix lint and update seq pack func
yaoyu-33 Jan 7, 2025
38a6c49
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Jan 7, 2025
f0ec5f1
fix energon module
yaoyu-33 Jan 7, 2025
ff45f7e
Apply isort and black reformatting
yaoyu-33 Jan 7, 2025
38e42a2
fix comments
yaoyu-33 Jan 8, 2025
eadc665
Apply isort and black reformatting
yaoyu-33 Jan 8, 2025
846252f
address lightning issues
yaoyu-33 Jan 8, 2025
d70e432
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Jan 8, 2025
a2290de
Apply isort and black reformatting
yaoyu-33 Jan 8, 2025
3fdfe3e
Update sequence_packing.py
yaoyu-33 Jan 9, 2025
d24cd3b
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Jan 13, 2025
a68e41f
update energon requirements
yaoyu-33 Jan 13, 2025
c04f1ed
Fix for energon update
yaoyu-33 Jan 13, 2025
41ab130
fix for test
yaoyu-33 Jan 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4329,11 +4329,24 @@ jobs:
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/vlm/neva_train.py \
python tests/collections/vlm/test_neva_train.py \
--devices=1 \
--max-steps=5 \
--experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }}

L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/vlm/test_neva_train.py \
--devices=1 \
--max-steps=5 \
--experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }} \
--use_packed_sequence

L2_NeMo_2_MLLAMA_MOCK_TRAINING:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
Expand All @@ -4342,7 +4355,7 @@ jobs:
RUNNER: self-hosted-azure
SCRIPT: |
TRANSFORMERS_OFFLINE=1 \
python tests/collections/vlm/mllama_train.py \
python tests/collections/vlm/test_mllama_train.py \
--devices=1 \
--max-steps=5 \
--experiment-dir=/tmp/nemo2_mllama_results/${{ github.run_id }}
Expand Down Expand Up @@ -5060,6 +5073,7 @@ jobs:
- Speech_Checkpoints_tests
- L2_Stable_Diffusion_Training
- L2_NeMo_2_NEVA_MOCK_TRAINING
- L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING
- L2_NeMo_2_MLLAMA_MOCK_TRAINING
- L2_NeMo_2_GPT_Pretraining_no_transformer_engine
- L2_NeMo_2_GPT_DDP_Param_Parity_check
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/llm/peft/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from pathlib import Path
from typing import Tuple, Union

import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
from lightning.pytorch.trainer.states import TrainerFn
from megatron.core import dist_checkpointing
from pytorch_lightning.trainer.states import TrainerFn
from rich.console import Console

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/multimodal/data/energon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(),
task_encoder: Optional[MultiModalTaskEncoder] = None,
decoder_seq_length: Optional[int] = None,
packing_buffer_size: Optional[int] = None,
) -> None:
"""
Initialize the EnergonMultiModalDataModule.
Expand All @@ -84,6 +85,8 @@ def __init__(
Defaults to MultiModalSampleConfig().
task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
decoder_seq_length (int, optional): The maximum sequence length for the decoder. Used in encoder-decoder models.
packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None.
"""

super().__init__()
Expand Down Expand Up @@ -113,6 +116,7 @@ def __init__(
)
self.train_dataloader_object = None
self.val_dataloader_object = None
self.packing_buffer_size = packing_buffer_size

def io_init(self, **kwargs) -> fdl.Config[Self]:

Expand Down Expand Up @@ -146,6 +150,7 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val
task_encoder=self.task_encoder,
worker_config=worker_config,
max_samples_per_sequence=None,
packing_buffer_size=self.packing_buffer_size,
shuffle_buffer_size=100,
split_part=split,
)
Expand Down
24 changes: 22 additions & 2 deletions nemo/collections/multimodal/data/energon/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import List
from typing import List, Tuple, Union

import torch
from megatron.core.packed_seq_params import PackedSeqParams

from nemo.collections.multimodal.data.energon.conversation import LLaVATemplateConfig


Expand All @@ -34,7 +37,7 @@

@dataclass
class ImageTextSample:
'''Sample type for template formatted raw image text sample'''
"""Sample type for template formatted raw image text sample"""

__key__: str = ''
images: torch.Tensor = field(default_factory=lambda: torch.empty(0))
Expand All @@ -43,6 +46,15 @@
loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))


@dataclass
class PackedImageTextSample(ImageTextSample):
"""Sample type for packed image text sample"""

__restore_key__: Tuple[Union[str, int, tuple], ...] = ()
position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams())

Check notice

Code scanning / CodeQL

Unnecessary lambda Note

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you resolve this?



@dataclass
class ImageTextRawBatch:
"""Sample type for image text raw batch"""
Expand All @@ -56,6 +68,14 @@
loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))


@dataclass
class PackedImageTextRawBatch(ImageTextRawBatch):
"""Sample type for image text raw batch"""

position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams())

Check notice

Code scanning / CodeQL

Unnecessary lambda Note

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.


@dataclass
class MultiModalSampleConfig:
image_token: ImageToken = field(default_factory=ImageToken)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/multimodal/data/energon/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class LLaVATemplateConfig(BaseConversationTemplateConfig):
"""LLava-specific template configuration which extends the base config"""

system: str = field(
default="A chat between a curious user and artificial assistant agent. "
default="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed and polite answers to user's questions."
)
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
Expand Down
170 changes: 143 additions & 27 deletions nemo/collections/multimodal/data/energon/task_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@
batch_list,
batch_pad_stack,
)
from megatron.energon.task_encoder.base import stateless

from nemo.collections.multimodal.data.energon.config import ImageTextRawBatch, ImageTextSample
from nemo.collections.multimodal.data.energon.config import (
ImageTextRawBatch,
ImageTextSample,
PackedImageTextRawBatch,
PackedImageTextSample,
)
from nemo.collections.multimodal.data.energon.sample_encoder import (
InterleavedSampleEncoder,
SampleEncoder,
SimilarityInterleavedEncoder,
VQASampleEncoder,
)
from nemo.utils import logging


class MultiModalTaskEncoder(
Expand All @@ -54,16 +61,34 @@ class MultiModalTaskEncoder(
for model input.
"""

def __init__(self, tokenizer, image_processor, multimodal_sample_config):
def __init__(
self,
tokenizer,
image_processor,
multimodal_sample_config,
packed_sequence=False,
packed_sequence_size=-1,
num_image_embeddings_per_tile=576,
):
"""
Initialize the MultiModalTaskEncoder with specific encoders for different sample types.

Parameters:
tokenizer (Tokenizer): The tokenizer used for processing text across different sample types.
image_processor (ImageProcessor): The image processor used for preprocessing images.
multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object.
tokenizer (Tokenizer): The tokenizer used for processing textual components across sample types.
image_processor (ImageProcessor): The image processor responsible for preprocessing image data.
multimodal_sample_config (MultiModalSampleConfig): Configuration object defining properties and
requirements for multimodal samples.
packed_sequence (bool, optional): Flag indicating whether packed sequences are used. Default is False.
packed_sequence_size (int, optional): The size of packed sequences, used when `packed_sequence` is True.
Default is -1.
num_image_embeddings_per_tile (int, optional): Number of image embeddings per image tile. Determines
the granularity of image features. Default is 576.
"""
self.tokenizer = tokenizer
self.sample_config = multimodal_sample_config
self.packed_sequence = packed_sequence
self.num_image_embeddings_per_tile = num_image_embeddings_per_tile # only used with seq packing
self.packed_sequence_size = packed_sequence_size
self.encoders: Dict[str, SampleEncoder] = {
VQASample.__name__: VQASampleEncoder(
tokenizer=tokenizer,
Expand Down Expand Up @@ -92,6 +117,7 @@ def register_encoder(self, sample_type: str, encoder: SampleEncoder) -> None:
"""
self.encoders[sample_type] = encoder

@stateless
def encode_sample(
self, sample: Union[VQASample, InterleavedSample, SimilarityInterleavedSample, CaptioningSample]
) -> ImageTextSample:
Expand All @@ -118,7 +144,9 @@ def encode_sample(
encoded_sample = encoder.encode(input_sample=sample, output_sample=ImageTextSample())
return encoded_sample

def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch:
def batch(
self, samples: List[Union[ImageTextSample, PackedImageTextSample]]
) -> Union[ImageTextRawBatch, PackedImageTextRawBatch]:
"""
Batch a list of encoded samples into a single raw batch.

Expand All @@ -131,26 +159,51 @@ def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch:
ImageTextRawBatch: The batched data, including images, tokens, labels, and loss masks.
"""

keys, images, tokens, labels, loss_mask = [], [], [], [], []
for sample in samples:
keys.append(sample.__key__)
images.append(sample.images)
tokens.append(sample.tokens)
labels.append(sample.labels)
loss_mask.append(sample.loss_mask)

batch_keys = batch_list(keys)
batch_images = batch_pad_stack(images)
batch_prompt_tokens = batch_pad_stack(tokens)
batch_labels = batch_pad_stack(labels)
batch_loss_mask = batch_pad_stack(loss_mask)
return ImageTextRawBatch(
__keys__=batch_keys,
images=batch_images,
tokens=batch_prompt_tokens,
labels=batch_labels,
loss_mask=batch_loss_mask,
)
if self.packed_sequence:
if len(samples) > 1:
raise ValueError(
"Micro batch size should be 1 when training with packed sequence, but your micro batch size "
f"is {len(samples)}. \nThe following config is equivalent to your current setting for "
f"a packed dataset. Please update your config to the following: \n"
f"Set micro batch size to 1 (currently {len(samples)})\n"
f"Set global batch size to `global_batch_size // {len(samples)}` "
f"Set packed sequence length to `original_sample_seq_len * {len(samples)}` "
f"(currently {self.packed_sequence_size}) \n"
f"For details please visit "
f"https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/packed_sequence.html"
)
# The batching are taken care by packing.
sample = samples[0]
return PackedImageTextRawBatch(
__keys__=sample.__key__,
images=sample.images,
tokens=sample.tokens,
labels=sample.labels,
loss_mask=sample.loss_mask,
position_ids=sample.position_ids,
packed_seq_params=sample.packed_seq_params,
)
else:
keys, images, tokens, labels, loss_mask = [], [], [], [], []
for sample in samples:
keys.append(sample.__key__)
images.append(sample.images)
tokens.append(sample.tokens)
labels.append(sample.labels)
loss_mask.append(sample.loss_mask)

batch_keys = batch_list(keys)
batch_images = batch_pad_stack(images)
batch_prompt_tokens = batch_pad_stack(tokens)
batch_labels = batch_pad_stack(labels)
batch_loss_mask = batch_pad_stack(loss_mask)
return ImageTextRawBatch(
__keys__=batch_keys,
images=batch_images,
tokens=batch_prompt_tokens,
labels=batch_labels,
loss_mask=batch_loss_mask,
)

def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
"""
Expand All @@ -165,7 +218,7 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
Returns:
dict: A dictionary containing the encoded batch data, ready for model input.
"""
batch_dict = dataclasses.asdict(batch_data)
batch_dict = batch_data.__dict__
if 'images' in batch_dict:
batch_dict['media'] = batch_dict['images']
del batch_dict['images']
Expand All @@ -177,3 +230,66 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
if 'attention_mask' not in batch_dict:
batch_dict['attention_mask'] = None
return batch_dict

def select_samples_to_pack(self, samples):
"""Selects which samples will be packed together.

NOTE: Energon dataloader calls this method internally if packing is used.
Please see https://nvidia.github.io/Megatron-Energon/packing.html
"""
from nemo.collections.vlm.neva.data.sequence_packing import greedy_knapsack, predict_seq_len

media_token_id = self.sample_config.image_token.token_id
lengths = [
predict_seq_len(
sample.tokens,
media_token_index=media_token_id,
num_image_embeddings_per_tile=self.num_image_embeddings_per_tile,
)
for sample in samples
]
packed_samples = greedy_knapsack(lengths, samples, self.packed_sequence_size)
avg_samples_per_bin = round(len(lengths) / len(packed_samples))
logging.info(
f"[Seq Packing Info] - Packing seq len: {self.packed_sequence_size}, "
f"Buffered samples: {len(lengths)}, Total number of bins: {len(packed_samples)}, "
f"Average samples per bin: {avg_samples_per_bin}"
)
return packed_samples

@stateless
def pack_selected_samples(self, samples):
"""
Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked.

NOTE: Energon dataloader calls this method internally if packing is used.
Please see https://nvidia.github.io/Megatron-Energon/packing.html

Args:
samples: List of ImageTaskSample instances to pack into one sample.

Returns:
ImageTaskSamplePacked instance.
"""
from nemo.collections.vlm.neva.data.sequence_packing import convert_to_packed

packed_images = torch.stack([sample.images for sample in samples])
media_token_id = self.sample_config.image_token.token_id
packed_tokens, packed_labels, packed_position_ids, packed_loss_mask, packed_seq_params = convert_to_packed(
tokens=[sample.tokens for sample in samples],
labels=[sample.labels for sample in samples],
num_image_embeddings_per_tile=self.num_image_embeddings_per_tile,
media_token_index=media_token_id,
ignore_index=self.sample_config.ignore_place_holder,
)

return PackedImageTextSample(
__key__=",".join([s.__key__ for s in samples]),
__restore_key__=(), # Will be set by energon based on `samples`
tokens=packed_tokens,
labels=packed_labels,
images=packed_images,
position_ids=packed_position_ids,
loss_mask=packed_loss_mask,
packed_seq_params=packed_seq_params,
)
2 changes: 1 addition & 1 deletion nemo/collections/vlm/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import List, Optional, Union

import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams
Expand Down
Loading
Loading