
Summarization TL;DR #27

Draft: wants to merge 62 commits into base: main

Commits (62)
845ea72
test
vwxyzjn Sep 22, 2023
bc7d54c
quick change
vwxyzjn Sep 22, 2023
d1c5f86
quick change
vwxyzjn Sep 23, 2023
760d792
quick fix
vwxyzjn Sep 23, 2023
7a67477
push changes
vwxyzjn Sep 25, 2023
f12f228
push changes
vwxyzjn Sep 25, 2023
301d008
push change
vwxyzjn Sep 25, 2023
6b2d2ec
quick change
vwxyzjn Sep 25, 2023
e843166
normalize based on reference response
vwxyzjn Sep 25, 2023
1deff4e
quick push
vwxyzjn Sep 25, 2023
8d89722
fix
vwxyzjn Sep 27, 2023
7866eb2
push changes so far
vwxyzjn Oct 3, 2023
a2187cf
update
vwxyzjn Oct 4, 2023
09a705e
test
vwxyzjn Oct 5, 2023
2bcf98d
add eos token in reference response
vwxyzjn Oct 5, 2023
da5dc33
push changes
vwxyzjn Oct 25, 2023
332da0d
actually kind of work
vwxyzjn Oct 25, 2023
b6d4984
push changes
vwxyzjn Oct 28, 2023
3b7cbb7
pre-commit
vwxyzjn Oct 28, 2023
58db7b1
pre-commit
vwxyzjn Oct 28, 2023
c918340
remove unnecessary stuff
vwxyzjn Oct 28, 2023
ca76b55
dropout proper setting
vwxyzjn Oct 29, 2023
89f208f
make it work with gpt-large
vwxyzjn Oct 29, 2023
9302a5b
push changes
vwxyzjn Nov 14, 2023
94acf38
SFT seemed to work finally https://wandb.ai/costa-huang/tldr_summariz…
vwxyzjn Nov 17, 2023
b7f6876
reducing lr helps; refactor loop logic
vwxyzjn Nov 22, 2023
5622b63
push changes
vwxyzjn Nov 23, 2023
b8c5ffc
push changes
vwxyzjn Nov 28, 2023
ef6d2b1
deal with >48 tokens in rm dataset
vwxyzjn Nov 28, 2023
46e00ca
cache a debugging 25 token generated script
vwxyzjn Nov 30, 2023
c4ebe5e
seems successful!
vwxyzjn Dec 3, 2023
11ed546
seems to work ok with 1B models
vwxyzjn Dec 5, 2023
277fd53
minor change
vwxyzjn Dec 6, 2023
37a2963
remove files
vwxyzjn Dec 8, 2023
f5dc7cd
seems successful
vwxyzjn Dec 8, 2023
a46b8a6
push changes
vwxyzjn Dec 12, 2023
601c755
update sft stuff
vwxyzjn Dec 18, 2023
0f4dc10
update dependencies
vwxyzjn Dec 18, 2023
2a12638
quick push
vwxyzjn Dec 18, 2023
7e1336f
rename
vwxyzjn Dec 18, 2023
89ea1c5
quick change
vwxyzjn Dec 19, 2023
fceceaf
precommit
vwxyzjn Dec 19, 2023
b84c237
rename
vwxyzjn Dec 19, 2023
f686c51
push
vwxyzjn Dec 20, 2023
22aa0d1
change lr scheduler stuff
vwxyzjn Dec 21, 2023
0d4ddfa
support offload / 6.9b model
vwxyzjn Dec 26, 2023
0472f4d
sft / reward without padding
vwxyzjn Dec 26, 2023
451ec85
update benchmark.py
vwxyzjn Dec 26, 2023
2cbb1f7
precommit
vwxyzjn Dec 26, 2023
82ea918
test
vwxyzjn Dec 29, 2023
f97df9f
various fix; ppo repeat shuffle
vwxyzjn Dec 29, 2023
0efacc4
push changes
vwxyzjn Jan 1, 2024
537abaf
a regression in transformers https://github.com/huggingface/transform…
vwxyzjn Jan 3, 2024
55e3a50
various adjustments
vwxyzjn Jan 3, 2024
e8105dd
work with 6.9B
vwxyzjn Jan 3, 2024
3f6d045
prettier
vwxyzjn Jan 3, 2024
d12e443
push changes
vwxyzjn Jan 6, 2024
6f6490f
handles cnndm correctly
vwxyzjn Jan 11, 2024
2166b4f
quick change
vwxyzjn Jan 11, 2024
90c75f0
small refactor; remove unused var
vwxyzjn Jan 14, 2024
1d5b7db
evaluate test splits as well
vwxyzjn Jan 14, 2024
828db07
add DPO
vwxyzjn Jan 14, 2024
Files changed
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -35,5 +35,5 @@ repos:
hooks:
- id: codespell
args:
- --ignore-words-list=nd,reacher,thist,ths,magent,ba
- --ignore-words-list=nd,reacher,thist,ths,magent,ba,rouge,hist
- --skip=docs/css/termynal.css,docs/js/termynal.js
96 changes: 54 additions & 42 deletions benchmark/benchmark.py
@@ -1,66 +1,77 @@
import argparse
import math
import os
import shlex
import subprocess
import uuid
from distutils.util import strtobool
from dataclasses import dataclass
from typing import Optional

import requests


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--command", type=str, default="",
help="the command to run")
parser.add_argument("--num-seeds", type=int, default=3,
help="the number of random seeds")
parser.add_argument("--start-seed", type=int, default=1,
help="the number of the starting seed")
parser.add_argument("--workers", type=int, default=0,
help="the number of workers to run benchmark experimenets")
parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible")
parser.add_argument("--slurm-template-path", type=str, default=None,
help="the path to the slurm template file (see docs for more details)")
parser.add_argument("--slurm-gpus-per-task", type=int, default=1,
help="the number of gpus per task to use for slurm jobs")
parser.add_argument("--slurm-total-cpus", type=int, default=50,
help="the number of gpus per task to use for slurm jobs")
parser.add_argument("--slurm-ntasks", type=int, default=1,
help="the number of tasks to use for slurm jobs")
parser.add_argument("--slurm-nodes", type=int, default=None,
help="the number of nodes to use for slurm jobs")
args = parser.parse_args()
# fmt: on
return args
import tyro


@dataclass
class Args:
command: str
"""the command to run"""
num_seeds: int = 3
"""the number of random seeds"""
start_seed: int = 1
"""the number of the starting seed"""
workers: int = 0
"""the number of workers to run benchmark experimenets"""
auto_tag: bool = True
"""if toggled, the runs will be tagged with git tags, commit, and pull request number if possible"""
slurm_template_path: Optional[str] = None
"""the path to the slurm template file (see docs for more details)"""
slurm_gpus_per_task: Optional[int] = None
"""the number of gpus per task to use for slurm jobs"""
slurm_total_cpus: Optional[int] = None
"""the number of gpus per task to use for slurm jobs"""
slurm_ntasks: Optional[int] = None
"""the number of tasks to use for slurm jobs"""
slurm_nodes: Optional[int] = None
"""the number of nodes to use for slurm jobs"""


def run_experiment(command: str):
command_list = shlex.split(command)
print(f"running {command}")
fd = subprocess.Popen(command_list)
return_code = fd.wait()
assert return_code == 0

# Use subprocess.PIPE to capture the output
fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, errors = fd.communicate()

return_code = fd.returncode
assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}"

# Convert bytes to string and strip leading/trailing whitespaces
return output.decode("utf-8").strip()


def autotag() -> str:
wandb_tag = ""
print("autotag feature is enabled")
git_tag = ""
try:
git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip()
wandb_tag = f"{git_tag}"
print(f"identified git tag: {git_tag}")
except subprocess.CalledProcessError:
return wandb_tag
except subprocess.CalledProcessError as e:
print(e)
if len(git_tag) == 0:
try:
count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip())
hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
git_tag = f"no-tag-{count}-g{hash}"
print(f"identified git tag: {git_tag}")
except subprocess.CalledProcessError as e:
print(e)
wandb_tag = git_tag

git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
try:
# try finding the pull request number on github
prs = requests.get(
f"https://api.github.com/search/issues?q=repo:vwxyzjn/lm-human-preference-details+is:pr+{git_commit}"
)
prs = requests.get(f"https://api.github.com/search/issues?q=repo:vwxyzjn/cleanrl+is:pr+{git_commit}")
if prs.status_code == 200:
prs = prs.json()
if len(prs["items"]) > 0:
@@ -75,7 +86,7 @@ def autotag() -> str:


if __name__ == "__main__":
args = parse_args()
args = tyro.cli(Args)
if args.auto_tag:
existing_wandb_tag = os.environ.get("WANDB_TAGS", "")
wandb_tag = autotag()
@@ -84,7 +95,7 @@ def autotag() -> str:
os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag])
else:
os.environ["WANDB_TAGS"] = wandb_tag

print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", ""))
commands = []
for seed in range(0, args.num_seeds):
commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])]
@@ -133,4 +144,5 @@ def autotag() -> str:
slurm_path = os.path.join("slurm", f"{filename}.slurm")
print(f"saving command in {slurm_path}")
if args.workers > 0:
run_experiment(f"sbatch {slurm_path}")
job_id = run_experiment(f"sbatch --parsable {slurm_path}")
print(f"Job ID: {job_id}")
4 changes: 2 additions & 2 deletions benchmark/trl.slurm_template
@@ -1,11 +1,11 @@
#!/bin/bash
#SBATCH --partition=production-cluster
#SBATCH --partition=hopper-prod
#SBATCH --gpus-per-task={{gpus_per_task}}
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
#SBATCH --ntasks={{ntasks}}
#SBATCH --output=slurm/logs/%x_%j.out
#SBATCH --array={{array}}
#SBATCH --exclude=ip-26-0-149-199
##SBATCH --exclude=ip-26-0-149-199

{{nodes}}

219 changes: 219 additions & 0 deletions lm_human_preference_details/data.py
@@ -64,6 +64,21 @@ def tldr_generator(mode, seed=0, shuffle=False):
yield text


# TL;DR filtered dataset, modified from
# https://github.com/openai/summarize-from-feedback/tree/700967448d10004279f138666442bf1497d0e705#reddit-tldr-dataset
def tldr_filtered_generator(split, seed=0, shuffle=False):
assert split in ["test", "train", "valid"]

data = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered")[split]
if shuffle:
random.seed(seed)
data = data.shuffle(seed=seed)

for item in data:
yield dict(reference=item["summary"], **{k: v for (k, v) in item.items() if k != "summary"})
# yield f"SUBREDDIT: r/{item['subreddit']}\n\nTITLE: {item['title']}\n\nPOST: {item['post']}\n\nTL;DR:"


# for testing only
def dummy_generator(mode, seed=0, shuffle=False):
while True:
@@ -74,5 +89,209 @@ def dummy_generator(mode, seed=0, shuffle=False):
"books": books_generator,
"cnndm": cnndm_generator,
"tldr": tldr_generator,
"tldr_3_filtered": tldr_filtered_generator,
"dummy": dummy_generator,
}


from dataclasses import dataclass, field
from typing import Dict, List, NewType, Optional, Union

import numpy as np
import torch


def first_true_indices(bools, dtype=torch.long):
"""
Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving
the position of the first True in each "row".

Returns the length of the rows (bools.size(-1)) if no element is True in a given row.
"""
row_len = bools.size(-1)
zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
return torch.min(zero_or_index, dim=-1).values
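
# Worked example (illustration only, not in the original diff):
#   bools = torch.tensor([[False, True, True], [False, False, False]])
#   first_true_indices(bools)  # -> tensor([1, 3]); 3 is the row length, since row 2 has no True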


def to_numpy(x):
if isinstance(x, torch.Tensor):
return x.cpu().detach().numpy()
if isinstance(x, np.ndarray):
return x
if isinstance(x, float):
return np.array(x)
raise ValueError(f"Unexpected type {type(x)}")


# from summarize_from_feedback.utils import hyperparams
PADDING_TOKEN = -1


@dataclass
class TaskQueryHParams:
length: int = None
dataset: str = None
format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily
truncate_field: Optional[str] = None
truncate_text: Optional[str] = None
padding: Optional[str] = None # defaults to repeated spaces
pad_side: Optional[str] = None


@dataclass
class TaskResponseHParams:
ref_format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily
length: int = None
# Truncate response at the first occurrence of this token when sampling.
truncate_token: Optional[int] = None


@dataclass
class TaskHParams:
query: TaskQueryHParams = field(default_factory=TaskQueryHParams)
response: TaskResponseHParams = field(default_factory=TaskResponseHParams)
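
# Illustrative configuration sketch (values are assumptions for a TL;DR-style setup, not
# read from this diff; 50256 is the gpt2 eos token id):
#   task_hparams = TaskHParams(
#       query=TaskQueryHParams(length=512, dataset="tldr_3_filtered", pad_side="left"),
#       response=TaskResponseHParams(length=48, truncate_token=50256),
#   )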


# Has endoftext potentially, random stuff after
SampledTokens = NewType("SampledTokens", torch.LongTensor)
SampledTokenList = NewType("SampledTokenList", List[int])
# Has only the actual sample + padding tokens
ProcessedTokens = NewType("ProcessedTokens", torch.LongTensor)
ProcessedTokenList = NewType("ProcessedTokenList", List[int])


class ResponseEncoder:
def __init__(self, H: TaskResponseHParams, encoder, padding_token=PADDING_TOKEN):
self.H = H
self.encoder = encoder
self.padding_token = padding_token

def process_responses(self, unprocessed_tokens: SampledTokens) -> ProcessedTokens:
assert unprocessed_tokens.size(-1) == self.H.length
if self.H.truncate_token is not None:
assert self.padding_token is not None
trunc_idxs = first_true_indices(unprocessed_tokens == self.H.truncate_token).unsqueeze(-1)
new_size = [1] * (len(unprocessed_tokens.size()) - 1) + [self.H.length]
idxs = torch.arange(self.H.length, device=unprocessed_tokens.device).view(*new_size)
return torch.masked_fill(unprocessed_tokens, idxs > trunc_idxs, self.padding_token)
else:
return unprocessed_tokens

def encode_response(self, text: str, allow_truncate: bool = False) -> ProcessedTokenList:
tokens = self.encoder.encode(text)
if allow_truncate:
tokens = tokens[: self.H.length - (0 if self.H.truncate_token is None else 1)]
if self.H.truncate_token is not None:
tokens = tokens + [self.H.truncate_token]
if self.padding_token is None:
assert len(tokens) == self.H.length
return tokens
assert len(tokens) <= self.H.length, f"Response too long (limit {self.H.length}): {text}"
return tokens + [self.padding_token] * (self.H.length - len(tokens))

def decode_response(self, processed_response_tokens: ProcessedTokenList) -> str:
tokens = [x for x in processed_response_tokens if x != self.padding_token]
if self.H.truncate_token is not None:
if tokens[-1] == self.H.truncate_token:
tokens = tokens[:-1]
else:
assert len(tokens) == self.H.length
return self.encoder.decode(tokens)

def decode_responses(self, processed_response_tokens: Union[ProcessedTokens, np.ndarray]): # -> array of array of ... str:
def _decode_responses_list(l):
if isinstance(l[0], (int, np.int64)):
return self.decode_response(l)
return [_decode_responses_list(ll) for ll in l]

return _decode_responses_list(to_numpy(processed_response_tokens))
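
# Usage sketch (hypothetical, assuming a Hugging Face tokenizer with .encode/.decode):
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   hp = TaskResponseHParams(length=48, truncate_token=tok.eos_token_id)
#   enc = ResponseEncoder(hp, tok)
#   ids = enc.encode_response("a short tl;dr", allow_truncate=True)  # eos-terminated, padded to 48
#   enc.decode_response(ids)  # -> "a short tl;dr"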


def _ensure_length(toks, l, pad_sequence=None, pad_side=None, truncate_side=None):
assert pad_side in (None, "left", "right")
assert truncate_side in (None, "left", "right")
if len(toks) < l:
assert pad_sequence is not None
pad_amt = l - len(toks)
assert len(pad_sequence) >= pad_amt, f"{len(pad_sequence)} < {pad_amt}"
if pad_side is None:
assert len(toks) == l, f"Needed to pad! {len(toks)} < {l}"
return toks
elif pad_side == "left":
return pad_sequence[-pad_amt:] + toks
else:
assert pad_side == "right"
return toks + pad_sequence[:pad_amt]
if truncate_side is None:
assert len(toks) == l, f"Needed to truncate! {len(toks)} > {l}"
return toks
elif truncate_side == "left":
return toks[-l:]
else:
assert truncate_side == "right"
return toks[:l]
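
# Behavior sketch (illustrative values only):
#   _ensure_length([1, 2, 3], 5, pad_sequence=[0, 0, 0, 0, 0], pad_side="left")  # -> [0, 0, 1, 2, 3]
#   _ensure_length([1, 2, 3, 4, 5, 6], 5, truncate_side="left")                  # -> [2, 3, 4, 5, 6]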


def _get_query_padding_for_task(encoder, hparams: TaskQueryHParams):
if hparams.padding is not None:
return encoder.encode(hparams.padding)
return encoder.encode(" ") * hparams.length


def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHParams, pad_sequence=None):
if pad_sequence is None:
pad_sequence = _get_query_padding_for_task(encoder, hparams)
if isinstance(query_info, str):
query_info = dict(query=query_info)
else:
# copy to avoid mutating input
query_info = dict(**query_info)

format_str = hparams.format_str or "{query}"
# breakpoint()
query_tokens = encoder.encode(format_str.format(**query_info))
truncate_field = hparams.truncate_field or "query"

if truncate_field not in query_info:
raise ValueError(f"Could not truncate field {truncate_field}, found fields: {query_info.keys()}!")
while len(query_tokens) > hparams.length:
if not len(query_info[truncate_field]):
raise ValueError("Could not truncate enough!")

i = -1 # default to just remove one character
if hparams.truncate_text:
try:
i = query_info[truncate_field].rindex(hparams.truncate_text)
except ValueError:
pass
query_info[truncate_field] = query_info[truncate_field][:i]
query_tokens = encoder.encode(format_str.format(**query_info))

return dict(query_token=_ensure_length(query_tokens, hparams.length, pad_side=hparams.pad_side, pad_sequence=pad_sequence))
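
# Note on the truncation loop above (explanatory sketch): with truncate_field="post" and
# truncate_text="\n", an over-length post is repeatedly cut back to its last newline until
# the formatted query fits within hparams.length tokens; if no newline is found, one
# character is dropped per iteration. The __main__ block below exercises this path.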


if __name__ == "__main__":
gen = tldr_filtered_generator("train")
for i in range(10):
d = next(gen)
from transformers import AutoTokenizer

encoder = AutoTokenizer.from_pretrained("gpt2")

q = process_query(
d,
encoder=encoder,
hparams=TaskQueryHParams(
length=512,
dataset="tldr_3_filtered",
format_str="SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:",
truncate_field="post",
truncate_text="\n",
padding=None,
pad_side="left",
),
)
print("===start")
print(d, len(q["query_token"]))
print("===", encoder.decode(q["tokens"]))
print("===end")