Skip to content

Commit

Permalink
deps: cap transformers at 4.40.2 (#218)
Browse files Browse the repository at this point in the history
* deps: pin transformers below v4.41d

Signed-off-by: Anh-Uong <[email protected]>

* remove deprecated requirements.yaml

Signed-off-by: Anh-Uong <[email protected]>

* update unit tests with old evaluation_strategy flag

Signed-off-by: Anh-Uong <[email protected]>

* set transformers upper bound to 4.40.2

- update eval flag in docs

Signed-off-by: Anh-Uong <[email protected]>

---------

Signed-off-by: Anh-Uong <[email protected]>
  • Loading branch information
anhuong authored Jun 27, 2024
1 parent 0949699 commit 3f05c67
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/prompt_tuning_twitter_complaints/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ tuning/sft_trainer.py \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--eval_strategy "no" \
--evaluation_strategy "no" \
--save_strategy "epoch" \
--learning_rate 1e-5 \
--weight_decay 0. \
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers=[
dependencies = [
"numpy>=1.26.4,<2.0",
"accelerate>=0.20.3,<0.40",
"transformers>=4.34.1,<5.0,!=4.38.2",
"transformers>=4.34.1,<=4.40.2,!=4.38.2",
"torch>=2.2.0,<3.0",
"sentencepiece>=0.1.99,<0.3",
"tokenizers>=0.13.3,<1.0",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def test_run_causallm_pt_with_validation():
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.eval_strategy = "epoch"
train_args.evaluation_strategy = "epoch"
data_args = copy.deepcopy(DATA_ARGS)
data_args.validation_data_path = TWITTER_COMPLAINTS_DATA

Expand All @@ -317,7 +317,7 @@ def test_run_causallm_pt_with_validation_data_formatting():
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.eval_strategy = "epoch"
train_args.evaluation_strategy = "epoch"
data_args = copy.deepcopy(DATA_ARGS)
data_args.validation_data_path = TWITTER_COMPLAINTS_DATA
data_args.dataset_text_field = None
Expand Down

0 comments on commit 3f05c67

Please sign in to comment.