From b3718c1b7225cbc526e622146d1fc1f20ae79951 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Mon, 28 Oct 2024 18:57:59 -0700
Subject: [PATCH] precommit

---
 examples/sft/alpaca-llama.yaml |  4 ++--
 examples/sft/dolly-llama.yaml  |  4 ++--
 examples/sft/oasst-llama.yaml  |  8 ++++----
 examples/sft/sft.py            | 15 +++++----------
 src/levanter/data/text.py      |  1 -
 5 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/examples/sft/alpaca-llama.yaml b/examples/sft/alpaca-llama.yaml
index ac2de709d..5817bf0e3 100644
--- a/examples/sft/alpaca-llama.yaml
+++ b/examples/sft/alpaca-llama.yaml
@@ -6,7 +6,7 @@ trainer:
   wandb:
     project: "levanter-sft"
     tags: ["llama2", "alpaca"]
-  num_train_steps: 1218 
+  num_train_steps: 1218
   train_batch_size: 64
   # If using model parallelism
   tensor_parallel_axes: ["mlp", "heads"]
@@ -29,4 +29,4 @@ model_cache_dir: null
 
 hf_save_path: "sft_hf_ckpts"
 hf_upload: false
-hf_save_steps: 1000
\ No newline at end of file
+hf_save_steps: 1000
diff --git a/examples/sft/dolly-llama.yaml b/examples/sft/dolly-llama.yaml
index 9dd68f984..f386c32b7 100644
--- a/examples/sft/dolly-llama.yaml
+++ b/examples/sft/dolly-llama.yaml
@@ -6,7 +6,7 @@ trainer:
   wandb:
     project: "levanter-sft"
     tags: ["llama2", "oasst"]
-  num_train_steps: 1218 
+  num_train_steps: 1218
   train_batch_size: 128
   # If using model parallelism
   tensor_parallel_axes: ["mlp", "heads"]
@@ -29,4 +29,4 @@ model_cache_dir: null
 
 hf_save_path: "sft_hf_ckpts"
 hf_upload: false
-hf_save_steps: 1000
\ No newline at end of file
+hf_save_steps: 1000
diff --git a/examples/sft/oasst-llama.yaml b/examples/sft/oasst-llama.yaml
index 46f89f0ea..48cd6ae2b 100644
--- a/examples/sft/oasst-llama.yaml
+++ b/examples/sft/oasst-llama.yaml
@@ -6,9 +6,9 @@ trainer:
   wandb:
     project: "levanter-sft"
     tags: ["llama2", "oasst"]
-  num_train_steps: 1218 
+  num_train_steps: 1218
   train_batch_size: 128
-  
+
   # If using model parallelism
   tensor_parallel_axes: ["mlp", "heads"]
 
@@ -25,7 +25,7 @@ supervised_data:
   output_field: "response" # adjust based on dataset
   cache_dir: "cache/dolly"
 
-# Model configuration 
+# Model configuration
 max_tune_length: 2048
 trust_remote_code: false
 model_cache_dir: null
@@ -35,4 +35,4 @@ hf_save_path: "sft_hf_ckpts"
 hf_upload: false
 hf_save_steps: 1000
 
-# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml
\ No newline at end of file
+# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml
diff --git a/examples/sft/sft.py b/examples/sft/sft.py
index 8fa81624a..90bd1ab85 100644
--- a/examples/sft/sft.py
+++ b/examples/sft/sft.py
@@ -1,11 +1,8 @@
-import json
 import logging
 import os
 from dataclasses import dataclass
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 
-import fsspec
-import jax
 import jax.random as jrandom
 import transformers
 
@@ -14,13 +11,11 @@
 import levanter
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback
 from levanter.data import PermutationDataset
+from levanter.data.text import EpochDataset, LMSupervisedDatasetConfig, mk_supervised_dataset
 from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
 from levanter.optim import OptimizerConfig
 from levanter.trainer import Trainer, TrainerConfig
-from levanter.utils import fsspec_utils
-from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
 from levanter.utils.py_utils import non_caching_cycle
-from levanter.data.text import mk_supervised_dataset, LMSupervisedDatasetConfig, EpochDataset
 
 
 logger = logging.getLogger(__name__)
@@ -38,7 +33,7 @@
 class TrainArgs:
     trainer: TrainerConfig
     max_tune_length: int = 2048  # maximum length of the input to the model during tuning
-    
+
     # Supervision config
     supervised_data: LMSupervisedDatasetConfig = LMSupervisedDatasetConfig()
     input_field: str = "instruction"  # field name for input in dataset
@@ -81,7 +76,7 @@ def train(config: TrainArgs):
 
     # modify converter to use our tokenizer
     converter = converter.replaced(tokenizer=tokenizer)
-    
+
     # Configure supervised dataset
     supervised_config = config.supervised_data
 
@@ -151,4 +146,4 @@ def add_special_tokens(tokenizer, use_unk_instead_of_adding=False):
 
 
 if __name__ == "__main__":
-    levanter.config.main(train)()
\ No newline at end of file
+    levanter.config.main(train)()
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index b9690e8b7..0181889d9 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -725,7 +725,6 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
     input_field = config.input_field
     output_field = config.output_field
 
-    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}
 
     # Use the same preprocessing as before