Skip to content

Commit

Permalink
[wip] Add FastChat dialogue formatting to HelpSteer2 dev-set loader
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Sep 20, 2024
1 parent e43fa9a commit c7f40f8
Showing 1 changed file with 113 additions and 5 deletions.
118 changes: 113 additions & 5 deletions scripts/eval_helpsteer2_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
import sys
import hashlib
from pathlib import Path
from typing import Any

import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from fastchat.conversation import get_conv_template
from datasets import load_dataset, Dataset
from fastchat.conversation import get_conv_template, Conversation
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from src.rm_inference import RewardBenchPipeline
Expand Down Expand Up @@ -123,6 +124,8 @@ def main():
dataset_path="nvidia/Helpsteer2",
split="validation",
weights="llama",
conv=conv,
keep_columns=["text_chosen", "text_rejected", "id"],
)

# copy id for saving, then remove
Expand Down Expand Up @@ -222,9 +225,11 @@ def check_tokenizer_chat_template(tokenizer):


def load_helpsteer2_dataset(
conv: Conversation,
dataset_path: str = "nvidia/Helpsteer2",
split: str = "validation",
weights: str = "llama",
keep_columns: list[str] = ["text_chosen", "text_rejected", "id"],
):
if weights == "llama":
wt_vec = {
Expand All @@ -246,8 +251,8 @@ def load_helpsteer2_dataset(
raise ValueError("Unknown weights. Please pass either 'llama' or 'nemotron'")

# Binarize the dataset
dataset = load_dataset(dataset_path, split=split)
df = dataset.to_pandas()
init_dataset = load_dataset(dataset_path, split=split)
df = init_dataset.to_pandas()

def _compute_rating(row, wt_vec):
return sum(row[col] * wt_vec[col] for col in wt_vec)
Expand Down Expand Up @@ -280,7 +285,110 @@ def _compute_rating(row, wt_vec):
lambda x: hashlib.md5(x.encode("utf-8")).hexdigest()
)

raw_dataset = Dataset.from_pandas(df_binary)
logging.info("*** Preparing dataset with FastChat ***")
dataset = raw_dataset.map(
prepare_dialogue,
fn_kwargs={"dialogue_template": conv},
num_proc=8,
load_from_cache_file=False,
)

all_cols = dataset.column_names
dataset = dataset.remove_columns([c for c in all_cols if c not in keep_columns])
breakpoint()


def prepare_dialogue(
    example: dict[str, Any],
    dialogue_template: "Conversation",
    ift: bool = False,
) -> dict[str, Any]:
    """Render a preference (or IFT) example into dialogue strings.

    Mutates ``dialogue_template.messages`` while rendering and writes the
    rendered text back onto ``example``:

    - preference mode (default): requires ``chosen``/``rejected`` keys and
      adds ``text_chosen``/``text_rejected``; ``example["prompt"]`` is
      replaced by the rendered prompt-only string (needed for DPO).
    - ``ift=True``: requires an ``input`` key and adds ``text``.

    Raises:
        ValueError: if ``ift`` is False and ``chosen``/``rejected`` are
            missing from ``example``.
    """
    if all(k in example for k in ("chosen", "rejected")):
        # multi turn
        if isinstance(example["prompt"], list) and len(example["prompt"]) > 0:
            # Alternate user/assistant turns purely by position; the
            # per-message "role" field is ignored. NOTE(review): assumes the
            # upstream data is strictly user-first alternating — confirm.
            dialogue_template.messages = []
            for i, line in enumerate(example["prompt"]):
                role = (
                    dialogue_template.roles[0]
                    if (i + 1) % 2 == 1
                    else dialogue_template.roles[1]
                )
                dialogue_template.messages.append([role, line["content"]])
            # The last prompt message must be a user turn so chosen/rejected
            # can follow as the assistant reply.
            assert dialogue_template.messages[-1][0] == dialogue_template.roles[0]

            # Prompt-only rendering, needed for DPO.
            temp_prompt = dialogue_template.get_prompt()

            # Append chosen, render; then swap in rejected and render again.
            dialogue_template.messages.append(
                [dialogue_template.roles[1], example["chosen"]]
            )
            example["text_chosen"] = dialogue_template.get_prompt()

            dialogue_template.messages[-1] = [
                dialogue_template.roles[1],
                example["rejected"],
            ]
            example["text_rejected"] = dialogue_template.get_prompt()

            example["prompt"] = temp_prompt

        # single turn
        else:
            # Prompt may arrive as a one-element list; unwrap it.
            if isinstance(example["prompt"], list):
                example["prompt"] = example["prompt"][0]
            dialogue_template.messages = [
                [dialogue_template.roles[0], example["prompt"]],
            ]
            temp_prompt = dialogue_template.get_prompt()

            dialogue_template.messages = [
                [dialogue_template.roles[0], example["prompt"]],
                [dialogue_template.roles[1], example["chosen"]],
            ]
            example["text_chosen"] = dialogue_template.get_prompt()
            dialogue_template.messages = [
                [dialogue_template.roles[0], example["prompt"]],
                [dialogue_template.roles[1], example["rejected"]],
            ]
            example["text_rejected"] = dialogue_template.get_prompt()

            example["prompt"] = temp_prompt
    elif ift:
        if isinstance(example["prompt"], list):
            example["prompt"] = example["prompt"][0]

        dialogue_template.messages = [
            [dialogue_template.roles[0], example["prompt"]],
        ]
        temp_prompt = dialogue_template.get_prompt()
        dialogue_template.messages = [
            [dialogue_template.roles[0], example["prompt"]],
            [dialogue_template.roles[1], example["input"]],
        ]
        example["text"] = dialogue_template.get_prompt()
        example["prompt"] = temp_prompt  # needed for DPO

    else:
        # Bug fix: the two literals previously concatenated without a space
        # ("task!Require"), producing a garbled error message.
        raise ValueError(
            "Could not format example as dialogue for `rm` task! "
            f"Require `[chosen, rejected]` keys but found {list(example.keys())}"
        )
    return example


if __name__ == "__main__":
    # main()  # TODO(wip): re-enable the full eval entry point.
    # Removed the stray bare `load_helpsteer2_dataset()` call: `conv` is a
    # required parameter, so that call raised TypeError before the real one.
    # Smoke-test the HelpSteer2 loading + FastChat formatting path.
    chat_template = "tulu"
    conv = get_conv_template(chat_template)
    dataset = load_helpsteer2_dataset(
        dataset_path="nvidia/Helpsteer2",
        split="validation",
        weights="llama",
        conv=conv,
        keep_columns=["text_chosen", "text_rejected", "id"],
    )

0 comments on commit c7f40f8

Please sign in to comment.