Skip to content

Commit

Permalink
chore: Update training scripts with configurable parser
Browse files Browse the repository at this point in the history
  • Loading branch information
upskyy committed Aug 8, 2024
1 parent 25f3706 commit b96405c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
8 changes: 5 additions & 3 deletions training_multi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
parser.add_argument("--nli_batch_size", type=int, default=64)
parser.add_argument("--sts_batch_size", type=int, default=8)
parser.add_argument("--num_epochs", type=int, default=10)
parser.add_argument("--eval_steps", type=int, default=1000)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--output_dir", type=str, default="output")
parser.add_argument("--output_prefix", type=str, default="kor_multi_")
parser.add_argument("--seed", type=int, default=42)
Expand Down Expand Up @@ -97,14 +99,14 @@
train_objectives=train_objectives,
evaluator=dev_evaluator,
epochs=args.num_epochs,
optimizer_params={"lr": 2e-5},
evaluation_steps=1000,
optimizer_params={"lr": args.learning_rate},
evaluation_steps=args.eval_steps,
warmup_steps=warmup_steps,
output_path=model_save_path,
)

# Load the stored model and evaluate its performance on STS benchmark dataset
model = SentenceTransformer(model_save_path)
model = SentenceTransformer(model_save_path, trust_remote_code=True)
logging.info("Read KorSTS benchmark test dataset")

test_file = os.path.join(sts_dataset_path, "sts-test.tsv")
Expand Down
6 changes: 4 additions & 2 deletions training_nli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--eval_steps", type=int, default=1000)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--output_dir", type=str, default="output")
parser.add_argument("--output_prefix", type=str, default="kor_nli_")
parser.add_argument("--seed", type=int, default=777)
Expand Down Expand Up @@ -83,8 +85,8 @@
train_objectives=[(train_dataloader, train_loss)],
evaluator=dev_evaluator,
epochs=args.num_epochs,
optimizer_params={"lr": 2e-5},
evaluation_steps=1000,
optimizer_params={"lr": args.learning_rate},
evaluation_steps=args.eval_steps,
warmup_steps=warmup_steps,
output_path=model_save_path,
)
Expand Down
6 changes: 4 additions & 2 deletions training_sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--num_epochs", type=int, default=5)
parser.add_argument("--eval_steps", type=int, default=1000)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--output_dir", type=str, default="output")
parser.add_argument("--output_prefix", type=str, default="kor_sts_")
parser.add_argument("--seed", type=int, default=777)
Expand Down Expand Up @@ -84,8 +86,8 @@
train_objectives=[(train_dataloader, train_loss)],
evaluator=dev_evaluator,
epochs=args.num_epochs,
optimizer_params={"lr": 2e-5},
evaluation_steps=1000,
optimizer_params={"lr": args.learning_rate},
evaluation_steps=args.eval_steps,
warmup_steps=warmup_steps,
output_path=model_save_path,
)
Expand Down

0 comments on commit b96405c

Please sign in to comment.