precommit
ahmeda14960 committed Oct 29, 2024
1 parent 313a3f4 commit b3718c1
Showing 5 changed files with 13 additions and 19 deletions.
4 changes: 2 additions & 2 deletions examples/sft/alpaca-llama.yaml
@@ -6,7 +6,7 @@ trainer:
wandb:
project: "levanter-sft"
tags: ["llama2", "alpaca"]
-num_train_steps: 1218
+num_train_steps: 1218
train_batch_size: 64
# If using model parallelism
tensor_parallel_axes: ["mlp", "heads"]
@@ -29,4 +29,4 @@ model_cache_dir: null

hf_save_path: "sft_hf_ckpts"
hf_upload: false
-hf_save_steps: 1000
+hf_save_steps: 1000
4 changes: 2 additions & 2 deletions examples/sft/dolly-llama.yaml
@@ -6,7 +6,7 @@ trainer:
wandb:
project: "levanter-sft"
tags: ["llama2", "oasst"]
-num_train_steps: 1218
+num_train_steps: 1218
train_batch_size: 128
# If using model parallelism
tensor_parallel_axes: ["mlp", "heads"]
@@ -29,4 +29,4 @@ model_cache_dir: null

hf_save_path: "sft_hf_ckpts"
hf_upload: false
-hf_save_steps: 1000
+hf_save_steps: 1000
8 changes: 4 additions & 4 deletions examples/sft/oasst-llama.yaml
@@ -6,9 +6,9 @@ trainer:
wandb:
project: "levanter-sft"
tags: ["llama2", "oasst"]
-num_train_steps: 1218
+num_train_steps: 1218
train_batch_size: 128

# If using model parallelism
tensor_parallel_axes: ["mlp", "heads"]

@@ -25,7 +25,7 @@ supervised_data:
output_field: "response" # adjust based on dataset
cache_dir: "cache/dolly"

-# Model configuration
+# Model configuration
max_tune_length: 2048
trust_remote_code: false
model_cache_dir: null
@@ -35,4 +35,4 @@ hf_save_path: "sft_hf_ckpts"
hf_upload: false
hf_save_steps: 1000

-# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml
+# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml
15 changes: 5 additions & 10 deletions examples/sft/sft.py
@@ -1,11 +1,8 @@
import json
import logging
import os
from dataclasses import dataclass
-from typing import Dict, Optional, Union
+from typing import Optional, Union

import fsspec
import jax
import jax.random as jrandom
import transformers

@@ -14,13 +11,11 @@
import levanter
from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
+from levanter.data.text import EpochDataset, LMSupervisedDatasetConfig, mk_supervised_dataset
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.optim import OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig
from levanter.utils import fsspec_utils
from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
from levanter.utils.py_utils import non_caching_cycle
-from levanter.data.text import mk_supervised_dataset, LMSupervisedDatasetConfig, EpochDataset


logger = logging.getLogger(__name__)
@@ -38,7 +33,7 @@ class TrainArgs:
trainer: TrainerConfig

max_tune_length: int = 2048 # maximum length of the input to the model during tuning

# Supervision config
supervised_data: LMSupervisedDatasetConfig = LMSupervisedDatasetConfig()
input_field: str = "instruction" # field name for input in dataset
@@ -81,7 +76,7 @@ def train(config: TrainArgs):

# modify converter to use our tokenizer
converter = converter.replaced(tokenizer=tokenizer)

# Configure supervised dataset
supervised_config = config.supervised_data

@@ -151,4 +146,4 @@ def add_special_tokens(tokenizer, use_unk_instead_of_adding=False):


if __name__ == "__main__":
-levanter.config.main(train)()
+levanter.config.main(train)()
1 change: 0 additions & 1 deletion src/levanter/data/text.py
@@ -725,7 +725,6 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
input_field = config.input_field
output_field = config.output_field


output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}

# Use the same preprocessing as before
