
Commit

WIP changes
ahmeda14960 committed Jan 26, 2025
1 parent d3261bb commit fbcf9de
Showing 3 changed files with 20 additions and 18 deletions.
5 changes: 0 additions & 5 deletions src/levanter/data/packing.py
@@ -35,7 +35,6 @@ class SequencePacker:

     def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int):
         self.Pos = Pos
-        logger.info(f" Pos in packer is {Pos}")
         self._ids: list[int] = []
         self._segment_ids: list[int] = []
         self._loss_mask: list[int] = []
@@ -55,8 +54,6 @@ def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment
             return

         if len(ids) + len(self._ids) > self.Pos.size:
-            logger.info(f"length of new id is ids: {len(ids)}")
-            logger.info(f"length of old list is ids: {len(self._ids)}")
             raise ValueError("Too many tokens")

         if self.num_segments >= self.max_pack_size:
@@ -67,8 +64,6 @@ def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment
             segment_id = self.num_segments

         self.num_segments += 1
-        logger.info(f"segment_id is {segment_id}")
-        logger.info(f"ids total is now {len(self._ids)}")

         self._segment_ids.extend([segment_id] * len(ids))
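For orientation, a minimal usage sketch of the packer these deletions touch. The constructor signature is visible in the hunk above; add_example's trailing parameter is truncated there, so calling it with only ids and a loss mask is an assumption, as is the recover-and-repack pattern.

# Sketch only, not levanter's documented API: the constructor and the visible
# part of add_example are taken from the hunk above; everything else is assumed.
import haliax as hax

from levanter.data.packing import SequencePacker

Pos = hax.Axis("position", 2048)
packer = SequencePacker(Pos, max_pack_size=8, pad_token=0)

for ids in ([1, 2, 3], [4, 5]):
    loss_mask = [1] * len(ids)
    try:
        packer.add_example(ids, loss_mask)
    except ValueError:
        # "Too many tokens": this pack is full, so start a fresh packer
        # (how a finished pack is emitted is not shown in this diff)
        packer = SequencePacker(Pos, max_pack_size=8, pad_token=0)
        packer.add_example(ids, loss_mask)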
17 changes: 6 additions & 11 deletions src/levanter/data/text.py
@@ -193,7 +193,6 @@ async def current_len(self) -> Optional[int]:

     async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
         token_arrays = await self._await_token_cache()
-        # logger.info(f"Time to get token cache: {time.time() - time_in}")
         len = await self.wait_until_len_at_least(max(indices) + 1)
         if len is not None and len < max(indices) + 1:
             raise ValueError("Requested indices beyond the end of the dataset")
@@ -1459,31 +1458,27 @@ def mk_chat_sft_packed_dataset(
     )

     # Convert cached dictionaries to PromptCompletions and pack them
-    def prepare_and_pack(examples: list[dict]) -> list[LmExample]:
+    def prepare_and_pack(examples: list[dict]) -> list:
         completions = []
-        logger.info(f"length of examples is {len(examples)}")
         for idx, ex in enumerate(examples):
             if int(ex["sources_len"]) > Pos.size - 1:
                 # if the prompt itself is larger than our context
                 # length we need to skip this example
-                logger.info(f"Skipping example {idx} because prompt is too long")
                 continue
             if len(ex["input_ids"]) > Pos.size:
-                logger.info(f"Shortening example {idx} from {len(ex['input_ids'])} to {Pos.size}")
                 ex["input_ids"] = ex["input_ids"][:Pos.size]
-                logger.info(f"New length of example is {len(ex['input_ids'])}")
             completions.append(
                 PromptCompletion(
                     ids=ex["input_ids"].tolist(),
                     prompt_length=int(ex["sources_len"])
                 )
             )
-            logger.info(f"\n\n at example {idx}")
-            logger.info(f"Prompt Length: {ex['sources_len']}")
-            #logger.info(f"Prompt: {tokenizer.decode(ex['input_ids'])}")
-            logger.info(f"Total completion length {len(ex['input_ids'])}")
+            # logger.info(f"\n\n at example {idx}")
+            # logger.info(f"Prompt Length: {ex['sources_len']}")
+            # #logger.info(f"Prompt: {tokenizer.decode(ex['input_ids'])}")
+            # logger.info(f"Total completion length {len(ex['input_ids'])}")
         #import os; os._exit(1)
-        logger.info(f"completions: {len(completions)}")
+        # logger.info(f"completions: {len(completions)}")
         iterator = pack_prompt_completions(
             Pos=Pos,
             sequences=completions,
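The rules prepare_and_pack enforces above are: skip any example whose prompt alone exceeds the context window (prompt length greater than Pos.size - 1), and truncate input_ids to Pos.size before building a PromptCompletion. A condensed sketch of just that logic; the helper name is hypothetical, and the PromptCompletion import path is assumed from the packing module this commit touches.

# Hypothetical helper distilling the filter/truncate rules from the diff above.
import numpy as np

from levanter.data.packing import PromptCompletion  # assumed import path

def to_completions(examples: list[dict], pos_size: int) -> list[PromptCompletion]:
    completions = []
    for ex in examples:
        prompt_len = int(ex["sources_len"])
        if prompt_len > pos_size - 1:
            # the prompt alone would fill the context window; skip the example
            continue
        ids = np.asarray(ex["input_ids"])[:pos_size]  # cap total length at Pos.size
        completions.append(PromptCompletion(ids=ids.tolist(), prompt_length=prompt_len))
    return completions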
16 changes: 14 additions & 2 deletions src/levanter/main/sft.py
@@ -104,6 +104,7 @@ def train(config: SFTConfig):
         converter = converter.replaced(tokenizer=tokenizer)

         model_config = converter.default_config
+        logger.info(f"New seq_len is {config.max_seq_len}")
         model_config = dataclasses.replace(converter.default_config, seq_len=config.max_seq_len)
     elif config.trainer.initialize_from is None:
         raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
@@ -150,12 +151,11 @@ def train(config: SFTConfig):
     )

     # if config.enable_packing:
-    logger.info('\n\npacking!!\n\n')
     import haliax
     train_dataset = mk_chat_sft_packed_dataset(
         chat_config,
         tokenizer,
-        haliax.Axis("position", 4096),
+        haliax.Axis("position", 2048),
         max_segments_per_example=8
     )
     # else:
@@ -217,6 +217,18 @@ def train(config: SFTConfig):
         callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
     )

+    # reshuffle the examples before packing!
+
+    # to implement seeking:
+    # check the step number in the trainer state; if it's not zero,
+    # then advance the iterator until we get there, then continue training.
+    # batch size will be baked in from the config
+
+    # change iterate tokenized requests to take a dict rather than a list,
+    # where the first element is the prompt and the second is the response
+
+    # then pass into iterate tokenized requests, go to pack requests,
+    # and then you have the correct loader; just pass it to trainer.train()
     loader = trainer.data_loader(train_dataset, trainer.TrainBatch)

     if config.hf_save_path is not None:
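The seeking plan in the comments above — on resume, read the step count from the trainer state and fast-forward the batch iterator before continuing — could look roughly like the sketch below. The start_step argument is a placeholder for whatever field the trainer state actually exposes; nothing here is a confirmed levanter API.

# Sketch of the seeking idea from the comments above; `start_step` stands in
# for the step recorded in the trainer state, which this diff does not show.
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def seek(batches: Iterable[T], start_step: int) -> Iterator[T]:
    it = iter(batches)
    for _ in range(start_step):
        next(it)  # discard batches the previous run already consumed
    return it

# e.g. train_iter = seek(loader, start_step) if start_step > 0 else iter(loader)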
