From 845ea721223e148191dbadb94b162ff2622acca9 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 22 Sep 2023 10:39:44 -0400
Subject: [PATCH 01/62] test

---
 lm_human_preference_details/data.py              | 219 +++
 .../train_policy_accelerate_summarize.py         | 860 ++++++++++++++++++
 .../train_reward_accelerate_summarize.py         | 817 +++++++++++++++++
 3 files changed, 1896 insertions(+)
 create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize.py
 create mode 100644 lm_human_preference_details/train_reward_accelerate_summarize.py

diff --git a/lm_human_preference_details/data.py b/lm_human_preference_details/data.py
index 5044ae3..08fc5bd 100644
--- a/lm_human_preference_details/data.py
+++ b/lm_human_preference_details/data.py
@@ -64,6 +64,21 @@ def tldr_generator(mode, seed=0, shuffle=False):
         yield text


+# TL;DR filtered dataset, modified from
+# https://github.com/openai/summarize-from-feedback/tree/700967448d10004279f138666442bf1497d0e705#reddit-tldr-dataset
+def tldr_filtered_generator(split, seed=0, shuffle=False):
+    assert split in ["test", "train", "valid"]
+
+    data = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered")[split]
+    if shuffle:
+        random.seed(seed)
+        data = data.shuffle(seed=seed)
+
+    for item in data:
+        yield dict(reference=item["summary"], **{k: v for (k, v) in item.items() if k != "summary"})
+        # yield f"SUBREDDIT: r/{item['subreddit']}\n\nTITLE: {item['title']}\n\nPOST: {item['post']}\n\nTL;DR:"
+
+
 # for testing only
 def dummy_generator(mode, seed=0, shuffle=False):
     while True:
@@ -74,5 +89,209 @@ def dummy_generator(mode, seed=0, shuffle=False):
     "books": books_generator,
     "cnndm": cnndm_generator,
     "tldr": tldr_generator,
+    "tldr_3_filtered": tldr_filtered_generator,
     "dummy": dummy_generator,
 }
+
+
+from dataclasses import dataclass, field
+from typing import Dict, List, NewType, Optional, Union
+
+import numpy as np
+import torch
+
+
+def first_true_indices(bools, dtype=torch.long):
+    """
+    Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving
+    the position of the first True in each "row".
+
+    Returns the length of the rows (bools.size(-1)) if no element is True in a given row.
+    """
+    row_len = bools.size(-1)
+    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
+    return torch.min(zero_or_index, dim=-1).values
+
+
+def to_numpy(x):
+    if isinstance(x, torch.Tensor):
+        return x.cpu().detach().numpy()
+    if isinstance(x, np.ndarray):
+        return x
+    if isinstance(x, float):
+        return np.array(x)
+    raise ValueError(f"Unexpected type {type(x)}")
+
+
+# from summarize_from_feedback.utils import hyperparams
+PADDING_TOKEN = -1
+
+
+@dataclass
+class TaskQueryHParams:
+    length: int = None
+    dataset: str = None
+    format_str: Optional[str] = None  # if underlying dataset yields dicts, can format arbitrarily
+    truncate_field: Optional[str] = None
+    truncate_text: Optional[str] = None
+    padding: Optional[str] = None  # defaults to repeated spaces
+    pad_side: Optional[str] = None
+
+
+@dataclass
+class TaskResponseHParams:
+    ref_format_str: Optional[str] = None  # if underlying dataset yields dicts, can format arbitrarily
+    length: int = None
+    # Truncate response at the first occurrence of this token when sampling.
+ truncate_token: Optional[int] = None + + +@dataclass +class TaskHParams: + query: TaskQueryHParams = field(default_factory=TaskQueryHParams) + response: TaskResponseHParams = field(default_factory=TaskResponseHParams) + + +# Has endoftext potentially, random stuff after +SampledTokens = NewType("SampledTokens", torch.LongTensor) +SampledTokenList = NewType("SampledTokenList", List[int]) +# Has only the actual sample + padding tokens +ProcessedTokens = NewType("ProcessedTokens", torch.LongTensor) +ProcessedTokenList = NewType("ProcessedTokenList", List[int]) + + +class ResponseEncoder: + def __init__(self, H: TaskResponseHParams, encoder, padding_token=PADDING_TOKEN): + self.H = H + self.encoder = encoder + self.padding_token = padding_token + + def process_responses(self, unprocessed_tokens: SampledTokens) -> ProcessedTokens: + assert unprocessed_tokens.size(-1) == self.H.length + if self.H.truncate_token is not None: + assert self.padding_token is not None + trunc_idxs = first_true_indices(unprocessed_tokens == self.H.truncate_token).unsqueeze(-1) + new_size = [1] * (len(unprocessed_tokens.size()) - 1) + [self.H.length] + idxs = torch.arange(self.H.length, device=unprocessed_tokens.device).view(*new_size) + return torch.masked_fill(unprocessed_tokens, idxs > trunc_idxs, self.padding_token) + else: + return unprocessed_tokens + + def encode_response(self, text: str, allow_truncate: bool = False) -> ProcessedTokenList: + tokens = self.encoder.encode(text) + if allow_truncate: + tokens = tokens[: self.H.length - (0 if self.H.truncate_token is None else 1)] + if self.H.truncate_token is not None: + tokens = tokens + [self.H.truncate_token] + if self.padding_token is None: + assert len(tokens) == self.H.length + return tokens + assert len(tokens) <= self.H.length, f"Response too long (limit {self.H.length}): {text}" + return tokens + [self.padding_token] * (self.H.length - len(tokens)) + + def decode_response(self, processed_response_tokens: ProcessedTokenList) -> str: + tokens = [x for x in processed_response_tokens if x != self.padding_token] + if self.H.truncate_token is not None: + if tokens[-1] == self.H.truncate_token: + tokens = tokens[:-1] + else: + assert len(tokens) == self.H.length + return self.encoder.decode(tokens) + + def decode_responses(self, processed_response_tokens: Union[ProcessedTokens, np.ndarray]): # -> array of array of ... str: + def _decode_responses_list(l): + if isinstance(l[0], (int, np.int64)): + return self.decode_response(l) + return [_decode_responses_list(ll) for ll in l] + + return _decode_responses_list(to_numpy(processed_response_tokens)) + + +def _ensure_length(toks, l, pad_sequence=None, pad_side=None, truncate_side=None): + assert pad_side in (None, "left", "right") + assert truncate_side in (None, "left", "right") + if len(toks) < l: + assert pad_sequence is not None + pad_amt = l - len(toks) + assert len(pad_sequence) >= pad_amt, f"{len(pad_sequence)} < {pad_amt}" + if pad_side is None: + assert len(toks) == l, f"Needed to pad! {len(toks)} < {l}" + return toks + elif pad_side == "left": + return pad_sequence[-pad_amt:] + toks + else: + assert pad_side == "right" + return toks + pad_sequence[:pad_amt] + if truncate_side is None: + assert len(toks) == l, f"Needed to truncate! 
{len(toks)} > {l}" + return toks + elif truncate_side == "left": + return toks[-l:] + else: + assert truncate_side == "right" + return toks[:l] + + +def _get_query_padding_for_task(encoder, hparams: TaskQueryHParams): + if hparams.padding is not None: + return encoder.encode(hparams.padding) + return encoder.encode(" ") * hparams.length + + +def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHParams, pad_sequence=None): + if pad_sequence is None: + pad_sequence = _get_query_padding_for_task(encoder, hparams) + if isinstance(query_info, str): + query_info = dict(query=query_info) + else: + # copy to avoid mutating input + query_info = dict(**query_info) + + format_str = hparams.format_str or "{query}" + # breakpoint() + query_tokens = encoder.encode(format_str.format(**query_info)) + truncate_field = hparams.truncate_field or "query" + + if truncate_field not in query_info: + raise ValueError(f"Could not truncate field {truncate_field}, found fields: {query_info.keys()}!") + while len(query_tokens) > hparams.length: + if not len(query_info[truncate_field]): + raise ValueError("Could not truncate enough!") + + i = -1 # default to just remove one character + if hparams.truncate_text: + try: + i = query_info[truncate_field].rindex(hparams.truncate_text) + except ValueError: + pass + query_info[truncate_field] = query_info[truncate_field][:i] + query_tokens = encoder.encode(format_str.format(**query_info)) + + return dict(query_token=_ensure_length(query_tokens, hparams.length, pad_side=hparams.pad_side, pad_sequence=pad_sequence)) + + +if __name__ == "__main__": + gen = tldr_filtered_generator("train") + for i in range(10): + d = next(gen) + from transformers import AutoTokenizer + + encoder = AutoTokenizer.from_pretrained("gpt2") + + q = process_query( + d, + encoder=encoder, + hparams=TaskQueryHParams( + length=512, + dataset="tldr_3_filtered", + format_str="SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:", + truncate_field="post", + truncate_text="\n", + padding=None, + pad_side="left", + ), + ) + print("===start") + print(d, len(q)) + print("===", encoder.decode(q["tokens"])) + print("===end") diff --git a/lm_human_preference_details/train_policy_accelerate_summarize.py b/lm_human_preference_details/train_policy_accelerate_summarize.py new file mode 100644 index 0000000..1dbd02e --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize.py @@ -0,0 +1,860 @@ +import functools +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + adaptive_kl: Optional[AdaptiveKLParams] = 
field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an 
(N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = 
group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + last_reward_latents = reward_latents + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + # shape: [batch_size, 1] + reward = self.reward_gain * reward + self.reward_bias + return output, reward + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # 
position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, args): + attention_mask = query_responses != args.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + if args.rewards.trained_model: + 
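# `rewards.trained_model` defaults to "models/reward.pt", which matches the `save_path`
# default of train_reward_accelerate_summarize.py, so this file is expected to hold a plain
# `state_dict()` of AutoModelForCausalLMWithRewardHead (the LM backbone plus the scalar
# reward head and the `reward_gain` / `reward_bias` scaling parameters).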
reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLMWithScalarHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = load_dataset("bookcorpus", split="train") + dataset = dataset.shuffle(seed=local_seed) + + def process_query_data(x, base_model: str, response_length: int): # added args so it's hashable + tokenizer = AutoTokenizer.from_pretrained(base_model) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + return { + "query_token": tokenizer( + x["text"], padding="max_length", max_length=response_length, truncation=True, return_tensors="pt" + )["input_ids"], + } + + dataset.set_transform( + functools.partial(process_query_data, base_model=args.base_model, response_length=args.task.response_length) + ) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
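# (Sampling note: with `top_k=0` and `top_p=1.0` the GenerationConfig below applies no
# top-k / nucleus truncation, so tokens are drawn from the full temperature-scaled
# distribution; `min_new_tokens == max_new_tokens == response_length` is meant to force a
# fixed-length sample.)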
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkls_stats = torch.zeros(stats_shape, device=device) + clipfracs_stats = torch.zeros(stats_shape, device=device) + pg_losses_stats = torch.zeros(stats_shape, device=device) + vf_losses_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropies_stats = torch.zeros(stats_shape, device=device) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(policy).lm_backbone, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output, full_values = forward(policy, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output, _ = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding( + postprocessed_query_responses, tokenizer.pad_token_id + ) + scores = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + scores = scores[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + + # # 3. filter response. Ensure that the sample contains truncate_token + # # responses not passing that filter will receive a low (fixed) score + # # only query humans on responses that pass that filter + # matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + # filter_mask = torch.any(matches_token, dim=-1) + # scores = torch.where( + # filter_mask, + # scores, + # torch.full_like(scores, args.task.penalty_reward_value), + # ) + # del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) + all_postprocessed_query_responses = tokenizer.batch_decode( + postprocessed_query_responses, skip_special_tokens=True + ) + all_postprocessed_responses = [ + x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) + ] + + kl_sum = kl.sum(axis=1) + all_df = pd.DataFrame( + { + "query": all_decode_queries, + "response": all_postprocessed_responses, + "score": scores.float().cpu().numpy(), + "kl": kl_sum.float().cpu().numpy(), + "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table("stuff", all_df[:4], console) + except Exception as e: + print(e) + del ( + all_decode_queries, + all_postprocessed_query_responses, + all_postprocessed_responses, + kl_sum, + all_df, + ) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + 
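# Bookkeeping for the statistics stored above:
#   - approxkl = 0.5 * E[(logp_new - logp_old)^2], a second-order estimate of the KL
#     between the sampling policy and the updated policy,
#   - pg_clipfrac / vf_clipfrac are the fractions of samples where the clipped policy
#     surrogate or the clipped value loss was the larger (binding) term,
#   - entropy is the exact categorical entropy computed from the logits as
#     logsumexp(logits) - sum(softmax(logits) * logits).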
gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) + writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + 
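# "ppo/val/num_eos_tokens" above is the total count of EOS tokens in the sampled responses
# for this update; since `task.truncate_token` defaults to the GPT-2 EOS id (50256), it is a
# rough signal of how often the policy terminates its summaries on its own.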
writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if accelerator.is_main_process and args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(policy.state_dict(), args.save_path) + + if args.upload_model: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py new file mode 100644 index 0000000..572ea5b --- /dev/null +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -0,0 +1,817 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import transformers +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import DistributedDataParallelKwargs, broadcast +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + + +@dataclass +class LabelHParams: + type: str = None + num_train: int = 64832 + num_labels: int = 2 + source: str = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: 
bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + label_dataset: str = "sentiment/offline_5k.json" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + local_batch_size: int = 4 + """per rank batch size""" + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + lr: float = 0.00005 + """the learning rate""" + eps: float = 1e-5 + """the epsilon for AdamW""" + local_rollout_batch_size: int = 512 + """per rank rollout batch size""" + rollout_batch_size: tyro.conf.Suppress[int] = None + """rollout batch size""" + world_size: tyro.conf.Suppress[int] = None + """the number of processes to use""" + batch_size: tyro.conf.Suppress[int] = None + """the batch size across all ranks""" + local_normalize_samples: int = 256 + """Samples used to estimate reward mean and std""" + normalize_samples: tyro.conf.Suppress[int] = None + """Samples used to estimate reward mean and std across all ranks""" + debug_normalize: int = 0 + """Samples used to check that normalization worked""" + normalize_before: bool = True + """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" + normalize_after: bool = True + """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/reward.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + labels: LabelHParams = field(default_factory=LabelHParams) + + +OPENAI_PAD_TOKEN_ID = 50259 + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + 
exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + last_reward_latents = reward_latents + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + # shape: [batch_size, 1] + reward = self.reward_gain * reward + self.reward_bias + return output, reward + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, args, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != args.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
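# Padding positions in the (left-padded) queries are excluded via `attention_mask` and their
# input ids zeroed above; explicit `position_ids` are deliberately not passed here (the
# commented-out version below reportedly makes generation collapse).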
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, args): + attention_mask = query_responses != args.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def normalize( + args, + accelerator, + device, + lm_backbone, + reward_model, + iter_dataloader, + generation_config, +): + with torch.no_grad(): + # reset reward scales + accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) + accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) + # number of minibatches for computing the normalization statistics + n_batches = ceil_div(args.local_normalize_samples, args.local_rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(data["query_token"], args.pad_token_id).to(device) + query_responses = generate(lm_backbone, queries, args, generation_config) + sample_queries_responses.append(query_responses) + + # compute reward statistics + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, args)[1]) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + accelerator.unwrap_model(reward_model).reward_gain.data = gain + accelerator.unwrap_model(reward_model).reward_bias.data = bias + + # validate normalization + n_batches = ceil_div(args.local_normalize_samples, args.local_rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(data["query_token"], args.pad_token_id).to(device) + query_responses = generate(lm_backbone, queries, args, generation_config) + sample_queries_responses.append(query_responses) + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, args)[1]) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator( + kwargs_handlers=[ + DistributedDataParallelKwargs( + broadcast_buffers=False, + ) + ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + args.world_size = accelerator.num_processes + args.batch_size = int(args.local_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + patch_h = TaskQueryHParams( 
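# Collect the query_* fields of TaskHParams into the TaskQueryHParams shape that
# data.process_query expects (the "patch" dataclass above mirrors the one in data.py).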
+ length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + args.pad_token_id = tokenizer.pad_token_id + untrained_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + untrained_model.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + untrained_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + reward_model.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # make sure the `lm_head` or `embed_out` does not require gradients, otherwise + # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 + if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): + reward_model.lm_backbone.embed_out.requires_grad_(False) + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) + else: + optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + dataset = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered", split="train") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + } + + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # 
deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + untrained_model, *_ = deepspeed.initialize(model=untrained_model, config=eval_ds_config) + untrained_model.eval() + else: + untrained_model = untrained_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + if args.normalize_before: + print("===Normalize reward model *before* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + args, + accelerator, + device, + untrained_model.lm_backbone, + reward_model, + iter_dataloader, + generation_config, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` + label = load_dataset("openai/summarize_from_feedback", "comparisons", split="train") + test_label = load_dataset("openai/summarize_from_feedback", "comparisons", split="validation") + print("Num labels found in source:", len(label)) + print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + + def process_response_data(x): + return { + **process_query(x["info"], encoder=tokenizer, hparams=patch_h), + "response0_token": tokenizer.encode( + x["summaries"][0]["text"], padding="max_length", max_length=args.task.response_length, truncation=True + ), + "response1_token": tokenizer.encode( + x["summaries"][1]["text"], padding="max_length", max_length=args.task.response_length, truncation=True + ), + } + + label = label.map(process_response_data) + test_label = test_label.map(process_response_data) + # tokenizer.encode(label[0]["summaries"][0]["text"]) + + print("===training reward model===") + all_inds = np.random.permutation(args.labels.num_train) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + global_step = 0 + for start in range(0, args.labels.num_train, args.batch_size): + # linear rate annealing + lr = (1 - start / args.labels.num_train) * args.lr + optimizer.param_groups[0]["lr"] = lr + + global_step += 1 + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + gradient_accumulation_step = 0 + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + with 
accelerator.accumulate(reward_model): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + reward = get_reward(reward_model, query_responses, args)[1] + last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + predicted_rewards.append(reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1)) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + reward_preferred = predicted_rewards.gather(1, mb_best.view(-1, 1)).view(-1) + reward_rejected = predicted_rewards.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + loss = -nn.functional.logsigmoid(reward_preferred - reward_rejected).mean() + # loss = torch.nn.functional.cross_entropy(predicted_rewards, mb_best) + accelerator.backward(loss) + optimizer.step() # accelerate handles gradient accumulation automatically + optimizer.zero_grad() + losses[gradient_accumulation_step] = loss + accuracies[gradient_accumulation_step] = accuracy + gradient_accumulation_step += 1 + + train_accuracy = accelerator.gather(accuracies).mean().item() + writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/accuracy", train_accuracy, global_step) + writer.add_scalar("train/lr", lr, global_step) + print("train/accuracy", train_accuracy) + + if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + with torch.no_grad(): + # eval on test_label, some duplicate code (I don't want to make the training loop into a function...) 
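The micro-batch step above is the core of the reward-model update: each (query, candidate summary) pair is reduced to one scalar reward, and the human choice is fit with a pairwise logistic loss, -logsigmoid(reward_preferred - reward_rejected). Below is a stripped-down sketch of that objective on stand-in reward tensors rather than the script's model and label data; the two-candidate assumption (hence `1 - choice`) mirrors the script's use of `1 - mb_best`.

import torch
import torch.nn.functional as F

def pairwise_reward_loss(predicted_rewards: torch.Tensor, choice: torch.Tensor):
    """predicted_rewards: (batch, num_candidates) scalar reward per candidate.
    choice: (batch,) index of the human-preferred candidate (0 or 1)."""
    reward_preferred = predicted_rewards.gather(1, choice.view(-1, 1)).view(-1)
    reward_rejected = predicted_rewards.gather(1, (1 - choice).view(-1, 1)).view(-1)
    loss = -F.logsigmoid(reward_preferred - reward_rejected).mean()
    accuracy = (predicted_rewards.argmax(dim=1) == choice).float().mean()
    return loss, accuracy

# toy check: 4 comparisons, 2 candidate summaries each, random stand-in rewards
rewards = torch.randn(4, 2)
choice = torch.randint(0, 2, (4,))
loss, accuracy = pairwise_reward_loss(rewards, choice)
print(loss.item(), accuracy.item())
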
+ test_accuracies = [] + len_labels = (len(test_label) // args.batch_size) * args.batch_size # in case the last batch is not full + new_all_inds = np.arange(len_labels) + for start in range(0, len_labels, args.batch_size): + end = start + args.batch_size + b_inds_all = new_all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) + ] + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + reward = get_reward(reward_model, query_responses, args)[1] + last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + predicted_rewards.append(reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1)) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + reward_preferred = predicted_rewards.gather(1, mb_best.view(-1, 1)).view(-1) + reward_rejected = predicted_rewards.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + writer.add_scalar("test/accuracy", test_accuracy, global_step) + if accelerator.is_main_process: + print("test/accuracy", test_accuracy, global_step) + + # the part below is testing out some generations and KLs, not presented in the original code + data = next(iter_dataloader) + queries = data["query_token"].to(device) + context_length = queries.shape[1] + queries = right_padding_to_left_padding(data["query_token"], args.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(reward_model).lm_backbone, + queries, + args, + generation_config, + ) + responses = query_responses[:, context_length:] + + output, reward = get_reward(reward_model, query_responses, args) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + output, _ = get_reward(untrained_model, query_responses, args) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + kl = logprobs - ref_logprobs + kl_sum = kl.sum(axis=1) + all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) + all_query_responses = tokenizer.batch_decode(query_responses, skip_special_tokens=True) + all_responses = [x[len(y) :] for x, y in 
zip(all_query_responses, all_decode_queries)] + all_df = pd.DataFrame( + { + "query": all_decode_queries, + "response": all_responses, + "kl": kl_sum.float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=global_step) + print_rich_table(f"Sample Output at Step {global_step}", all_df[:4], console) + del ( + query_responses, + all_decode_queries, + all_query_responses, + all_responses, + kl_sum, + all_df, + ) + writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) + + torch.cuda.empty_cache() + if args.normalize_after: + print("===Normalize reward model *after* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + args, + accelerator, + device, + untrained_model.lm_backbone, + reward_model, + iter_dataloader, + generation_config, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + + if accelerator.is_main_process and args.track: + wandb.finish() + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) From bc7d54cdad7d8c3867d81e17bd47b908d0a18d81 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 22 Sep 2023 17:21:02 +0000 Subject: [PATCH 02/62] quick change --- .../train_both_accelerate_summarize.py | 32 +++++++++++++++++++ .../train_policy_accelerate_summarize.py | 7 ++-- .../train_reward_accelerate_summarize.py | 12 +++---- 3 files changed, 39 insertions(+), 12 deletions(-) create mode 100644 lm_human_preference_details/train_both_accelerate_summarize.py diff --git a/lm_human_preference_details/train_both_accelerate_summarize.py b/lm_human_preference_details/train_both_accelerate_summarize.py new file mode 100644 index 0000000..ff1dc4e --- /dev/null +++ b/lm_human_preference_details/train_both_accelerate_summarize.py @@ -0,0 +1,32 @@ +import os +import time +from dataclasses import dataclass, field + +import tyro +from train_policy_accelerate_summarize import Args as ArgsPolicy +from train_policy_accelerate_summarize import train as train_policy +from train_reward_accelerate_summarize import Args as ArgsReward +from train_reward_accelerate_summarize import train as train_reward + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + reward: ArgsReward = field(default_factory=ArgsReward) + policy: ArgsPolicy = field(default_factory=ArgsPolicy) + + +if __name__ == "__main__": + args = tyro.cli(Args) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + args.reward.seed = args.seed + args.policy.seed = args.seed + args.reward.save_path = f"models/{run_name}/reward.pt" + args.policy.save_path = f"models/{run_name}/policy.pt" + args.policy.rewards.trained_model = args.reward.save_path + args.policy.rewards.label_dataset = args.reward.label_dataset + train_reward(args.reward) + train_policy(args.policy) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize.py b/lm_human_preference_details/train_policy_accelerate_summarize.py index 1dbd02e..5a1d23f 100644 
--- a/lm_human_preference_details/train_policy_accelerate_summarize.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize.py @@ -440,10 +440,7 @@ def forward(policy, query_responses, tokenizer): ) -# def train(args: Args): -if __name__ == "__main__": - args = tyro.cli(Args) - +def train(args: Args): accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) args.ppo.world_size = accelerator.num_processes args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) @@ -516,7 +513,7 @@ def forward(policy, query_responses, tokenizer): optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) else: optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = load_dataset("bookcorpus", split="train") + dataset = load_dataset(args.task.query_dataset, split="train") dataset = dataset.shuffle(seed=local_seed) def process_query_data(x, base_model: str, response_length: int): # added args so it's hashable diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index 572ea5b..d826508 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -94,7 +94,7 @@ class Args: """the name of the pretrained model to use""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - label_dataset: str = "sentiment/offline_5k.json" + label_dataset: str = "openai/summarize_from_feedback" """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" local_batch_size: int = 4 """per rank batch size""" @@ -124,7 +124,7 @@ class Args: """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. 
(For comparisons, just use mean 0, var 1.)""" normalize_after: bool = True """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 10 + print_sample_output_freq: int = 20 """How often to print sample output""" save_path: str = "models/reward.pt" """Where to save the model""" @@ -445,9 +445,7 @@ def normalize( print(f"after mean: {mean}, after std: {std}") -# def train(args: Args): -if __name__ == "__main__": - args = tyro.cli(Args) +def train(args: Args): accelerator = Accelerator( kwargs_handlers=[ DistributedDataParallelKwargs( @@ -604,8 +602,8 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ) # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` - label = load_dataset("openai/summarize_from_feedback", "comparisons", split="train") - test_label = load_dataset("openai/summarize_from_feedback", "comparisons", split="validation") + label = load_dataset(args.label_dataset, "comparisons", split="train") + test_label = load_dataset(args.label_dataset, "comparisons", split="validation") print("Num labels found in source:", len(label)) print("training on", args.labels.num_train, "in batches of", args.local_batch_size) From d1c5f862df90ec62e4514f1d4673ce872200acc9 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 23 Sep 2023 14:38:46 +0000 Subject: [PATCH 03/62] quick change --- .../train_policy_accelerate_summarize.py | 49 ++++++++++--------- .../train_reward_accelerate_summarize.py | 11 +++-- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize.py b/lm_human_preference_details/train_policy_accelerate_summarize.py index 5a1d23f..cfad5f1 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize.py @@ -29,6 +29,7 @@ from torch.utils.tensorboard import SummaryWriter from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from lm_human_preference_details.data import process_query @dataclass class AdaptiveKLParams: @@ -75,7 +76,7 @@ class PpoHParams: class TaskHParams: # Query params query_length: int = 512 - query_dataset: str = "tldr_3_filtered" + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" query_truncate_field: Optional[str] = "post" @@ -447,6 +448,15 @@ def train(args: Args): args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) if args.ppo.whiten_rewards: assert ( args.ppo.local_mini_batch_size >= 8 @@ -514,20 +524,15 @@ def train(args: Args): else: optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) dataset = load_dataset(args.task.query_dataset, split="train") - dataset = dataset.shuffle(seed=local_seed) - def process_query_data(x, base_model: 
str, response_length: int): # added args so it's hashable - tokenizer = AutoTokenizer.from_pretrained(base_model) - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + def process_query_data(x): return { - "query_token": tokenizer( - x["text"], padding="max_length", max_length=response_length, truncation=True, return_tensors="pt" - )["input_ids"], + **process_query(x, encoder=tokenizer, hparams=patch_h), } - dataset.set_transform( - functools.partial(process_query_data, base_model=args.base_model, response_length=args.task.response_length) - ) + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token"]) + dataset = dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) if args.deepspeed: @@ -653,17 +658,17 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ) scores = scores[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) - # # 3. filter response. Ensure that the sample contains truncate_token - # # responses not passing that filter will receive a low (fixed) score - # # only query humans on responses that pass that filter - # matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - # filter_mask = torch.any(matches_token, dim=-1) - # scores = torch.where( - # filter_mask, - # scores, - # torch.full_like(scores, args.task.penalty_reward_value), - # ) - # del matches_token, filter_mask + # 3. filter response. Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask torch.cuda.empty_cache() # 4. compute rewards diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index d826508..c29e562 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -124,7 +124,7 @@ class Args: """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" normalize_after: bool = True """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 20 + print_sample_output_freq: int = 60 """How often to print sample output""" save_path: str = "models/reward.pt" """Where to save the model""" @@ -685,7 +685,8 @@ def process_response_data(x): with torch.no_grad(): # eval on test_label, some duplicate code (I don't want to make the training loop into a function...) 
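The response filter re-enabled in the policy script earlier in this patch works as follows: a sampled summary keeps its model score only if it emits the truncate token (50256, the EOS id set in TaskHParams) at or after index `truncate_after`; otherwise it receives the fixed `penalty_reward_value`. A self-contained sketch of that masking step on toy tensors (the defaults below echo the hparams, the toy call shortens `truncate_after` for illustration):

import torch

def apply_response_filter(postprocessed_responses, scores, truncate_token=50256,
                          truncate_after=16, penalty_reward_value=-1.0):
    # keep the reward only when the truncate (EOS) token appears at or after
    # index `truncate_after`; otherwise substitute a fixed penalty
    matches_token = postprocessed_responses[:, truncate_after:] == truncate_token
    filter_mask = torch.any(matches_token, dim=-1)
    return torch.where(filter_mask, scores, torch.full_like(scores, penalty_reward_value))

# toy example with truncate_after=1: the second response never emits the token
responses = torch.tensor([[11, 50256, 50256, 50256], [11, 12, 13, 14]])
scores = torch.tensor([0.8, 0.5])
print(apply_response_filter(responses, scores, truncate_after=1))  # tensor([ 0.8000, -1.0000])
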
test_accuracies = [] - len_labels = (len(test_label) // args.batch_size) * args.batch_size # in case the last batch is not full + eval_len = 200 # len(test_label) + len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full new_all_inds = np.arange(len_labels) for start in range(0, len_labels, args.batch_size): end = start + args.batch_size @@ -766,7 +767,11 @@ def process_response_data(x): ) if accelerator.is_main_process and args.track: wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=global_step) - print_rich_table(f"Sample Output at Step {global_step}", all_df[:4], console) + try: + print_rich_table(f"Sample Output at Step {global_step}", all_df[:4], console) + except Exception as e: + print(e) + pass del ( query_responses, all_decode_queries, From 760d7922ed0df5d49edd9123266760fc2cd41287 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 23 Sep 2023 10:41:34 -0400 Subject: [PATCH 04/62] quick fix --- .../train_policy_accelerate_summarize.py | 2 +- .../train_reward_accelerate_summarize.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize.py b/lm_human_preference_details/train_policy_accelerate_summarize.py index cfad5f1..2d8b456 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize.py @@ -1,4 +1,3 @@ -import functools import os import random import time @@ -31,6 +30,7 @@ from lm_human_preference_details.data import process_query + @dataclass class AdaptiveKLParams: target: float = 6.0 diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index c29e562..4c29fa8 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -45,7 +45,7 @@ class LabelHParams: class TaskHParams: # Query params query_length: int = 512 - query_dataset: str = "tldr_3_filtered" + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" query_truncate_field: Optional[str] = "post" @@ -527,7 +527,7 @@ def train(args: Args): optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) else: optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - dataset = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered", split="train") + dataset = load_dataset(args.task.query_dataset, split="train") def process_query_data(x): return { @@ -685,7 +685,7 @@ def process_response_data(x): with torch.no_grad(): # eval on test_label, some duplicate code (I don't want to make the training loop into a function...) 
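Both the training loop and this evaluation block reduce the reward model's per-token output to a single scalar per sequence by reading it off at the last non-padding position. A standalone sketch of that reduction, reusing the `first_true_indices` helper added to data.py (toy tensors with pad id 0; `clamp` stands in for the script's `torch.max` against a zero tensor):

import torch

def first_true_indices(bools, dtype=torch.long):
    # position of the first True per row; returns the row length if none is True
    row_len = bools.size(-1)
    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
    return torch.min(zero_or_index, dim=-1).values

def scalar_reward(reward, query_responses, pad_token_id):
    # reward: (batch, seq, 1) per-position head output, as in `reward[:, :, 0]` above
    last_response_indices = first_true_indices(query_responses == pad_token_id) - 1
    last_response_indices = torch.clamp(last_response_indices, min=0)
    return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1)

# toy check: pad id 0; row 0 ends at position 2, row 1 uses the full length
query_responses = torch.tensor([[5, 6, 7, 0, 0], [5, 6, 7, 8, 9]])
reward = torch.arange(10, dtype=torch.float32).reshape(2, 5, 1)
print(scalar_reward(reward, query_responses, pad_token_id=0))  # tensor([2., 9.])
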
test_accuracies = [] - eval_len = 200 # len(test_label) + eval_len = 200 # len(test_label) len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full new_all_inds = np.arange(len_labels) for start in range(0, len_labels, args.batch_size): @@ -771,7 +771,6 @@ def process_response_data(x): print_rich_table(f"Sample Output at Step {global_step}", all_df[:4], console) except Exception as e: print(e) - pass del ( query_responses, all_decode_queries, From 7a67477a4a167f6ad7d34465803e7a65f32762ee Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 25 Sep 2023 12:56:52 -0400 Subject: [PATCH 05/62] push changes --- .../train_sft_accelerate_summarize.py | 503 ++++++++++++++++++ 1 file changed, 503 insertions(+) create mode 100644 lm_human_preference_details/train_sft_accelerate_summarize.py diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py new file mode 100644 index 0000000..bd332fe --- /dev/null +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -0,0 +1,503 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = 
"\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # 
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + labels=input_ids, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see 
https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = load_dataset(args.task.query_dataset, split="train") + test_dataset = load_dataset(args.task.query_dataset, split="test") + + def process_query_data1(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + x["summary"], padding="max_length", max_length=args.task.response_length, truncation=True + ), + } + + dataset = dataset.map(process_query_data1) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + iter_dataloader = iter(dataloader) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkls_stats = torch.zeros(stats_shape, device=device) + clipfracs_stats = torch.zeros(stats_shape, device=device) + pg_losses_stats = torch.zeros(stats_shape, device=device) + vf_losses_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropies_stats = torch.zeros(stats_shape, device=device) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_responses), dim=1) + query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) + with accelerator.accumulate(policy): + output = forward(policy, query_responses, tokenizer) + loss = output.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + writer.add_scalar("loss", loss.item(), update) + + # save model + if accelerator.is_main_process and args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(policy.state_dict(), args.save_path) + + if args.upload_model: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) From f12f228ca4fb46bf1e49ad8a6458644d8ade5560 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 25 Sep 2023 12:57:15 -0400 Subject: [PATCH 06/62] push changes --- 
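The SFT recipe that this patch keeps refining is plain teacher forcing on the reference summaries: concatenate query and reference response, move the padding to the left, and let the causal LM compute its own cross-entropy by passing labels=input_ids. Below is a minimal single-step sketch, assuming "gpt2" as the base model and a toy batch; the real script builds its batches from the TL;DR dataset, runs under accelerate, and defaults to a TensorFlow-style Adam, while plain torch.optim.Adam is used here for brevity.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # pad token used for masking only
policy = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.Adam(policy.parameters(), lr=6.35e-5, eps=1e-5)

def right_padding_to_left_padding(tokens, pad_id):
    # same helper as in the script: move each row's padding from the right to the left
    return torch.tensor(
        [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens],
        device=tokens.device,
    )

def sft_step(query_responses, pad_id):
    # mirrors the script's `forward`: mask pads, exclusive-cumsum position ids,
    # and let the causal LM compute cross-entropy by passing labels=input_ids
    attention_mask = query_responses != pad_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    output = policy(input_ids=input_ids, attention_mask=attention_mask,
                    position_ids=position_ids, labels=input_ids, return_dict=True)
    output.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return output.loss.item()

# toy right-padded batch standing in for query + reference response tokens
batch = tokenizer(["SUBREDDIT: r/test\n\nTITLE: t\n\nPOST: hello\n\nTL;DR: hi"],
                  padding="max_length", max_length=32, return_tensors="pt")
query_responses = right_padding_to_left_padding(batch["input_ids"], tokenizer.pad_token_id)
print(sft_step(query_responses, tokenizer.pad_token_id))
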
.../train_sft_accelerate_summarize.py | 76 ++++++++++++++----- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index bd332fe..e0c8566 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -48,11 +48,11 @@ class RewardHParams: @dataclass class PpoHParams: total_episodes: int = 1000000 - local_batch_size: int = 64 + local_batch_size: int = 8 local_mini_batch_size: tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 + gradient_accumulation_steps: int = 32 """gradient accumulation steps""" local_micro_batch_size: tyro.conf.Suppress[int] = None """per rank micro batch size""" @@ -134,7 +134,7 @@ class Args: """the name of the pretrained model to use""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 10 + print_sample_output_freq: int = 40 """How often to print sample output""" save_path: str = "models/policy.pt" """Where to save the model""" @@ -362,9 +362,9 @@ def forward(policy, query_responses, tokenizer): accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) args.ppo.world_size = accelerator.num_processes args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + # args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + # args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + # args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) patch_h = TaskQueryHParams( length=args.task.query_length, dataset=args.task.query_dataset, @@ -374,10 +374,10 @@ def forward(policy, query_responses, tokenizer): padding=args.task.query_padding, pad_side=args.task.query_pad_side, ) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # if args.ppo.whiten_rewards: + # assert ( + # args.ppo.local_mini_batch_size >= 8 + # ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size @@ -444,6 +444,9 @@ def process_query_data1(x): dataset = dataset.map(process_query_data1) dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) dataset = dataset.shuffle(seed=local_seed) + test_dataset = test_dataset.map(process_query_data1) + test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) + test_dataset = test_dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) iter_dataloader = iter(dataloader) @@ -467,24 
+470,55 @@ def process_query_data1(x): vf_losses_stats = torch.zeros(stats_shape, device=device) vf_clipfrac_stats = torch.zeros(stats_shape, device=device) entropies_stats = torch.zeros(stats_shape, device=device) + test_data = test_dataset[0:10] + test_data = {k: v.to(device) for k, v in test_data.items()} for update in range(1, args.ppo.num_updates + 1): global_step += 1 * args.ppo.batch_size frac = 1.0 - (update - 1.0) / args.ppo.num_updates lrnow = frac * args.ppo.lr optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) - with torch.no_grad(): - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_responses), dim=1) - query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) - with accelerator.accumulate(policy): - output = forward(policy, query_responses, tokenizer) - loss = output.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_responses), dim=1) + query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) + with accelerator.accumulate(policy): + output = forward(policy, query_responses, tokenizer) + loss = output.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if (update - 1) % args.ppo.gradient_accumulation_steps: writer.add_scalar("loss", loss.item(), update) + if (update - 1) % args.print_sample_output_freq * args.ppo.gradient_accumulation_steps == 0: + with torch.no_grad(): + test_queries = test_data["query_token"] + test_reference_responses = test_data["reference_response"] + test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) + generated_responses = generate(policy, test_queries, tokenizer, generation_config) + + try: + all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) + all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) + all_decode_test_reference_responses = tokenizer.batch_decode(test_reference_responses, skip_special_tokens=True) + all_decode_test_responses = [ + x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) + ] + + all_df = pd.DataFrame( + { + "query": all_decode_test_queries, + "response": all_decode_test_responses, + "reference": all_decode_test_reference_responses, + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table("stuff", all_df[:4], console) + except Exception as e: + print(e) + + # save model if accelerator.is_main_process and args.save_path: From 301d008eb3c64da6685c2a5d4d56a21ab358e6d7 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 25 Sep 2023 12:57:34 -0400 Subject: [PATCH 07/62] push change --- .../train_sft_accelerate_summarize.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index e0c8566..d0fdc77 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -8,12 +8,9 @@ import numpy as np import pandas as pd import torch -import torch.nn 
as nn -import torch.nn.functional as F import torch.optim as optim import tyro from accelerate import Accelerator -from accelerate.state import AcceleratorState from datasets import load_dataset from rich.console import Console from rich.pretty import pprint @@ -303,8 +300,6 @@ def step(self, closure=None): return loss - - def right_padding_to_left_padding(tokens, pad_id): """Convert from right padding to left padding.""" assert tokens.ndim == 2 @@ -342,7 +337,6 @@ def generate(lm_backbone, queries, tokenizer, generation_config): return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - def forward(policy, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum @@ -420,9 +414,7 @@ def forward(policy, query_responses, tokenizer): # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - policy.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically # see https://github.com/pytorch/pytorch/issues/104857 for more details @@ -488,7 +480,7 @@ def process_query_data1(x): accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - if (update - 1) % args.ppo.gradient_accumulation_steps: + if (update - 1) % args.ppo.gradient_accumulation_steps: writer.add_scalar("loss", loss.item(), update) if (update - 1) % args.print_sample_output_freq * args.ppo.gradient_accumulation_steps == 0: with torch.no_grad(): @@ -500,7 +492,9 @@ def process_query_data1(x): try: all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) - all_decode_test_reference_responses = tokenizer.batch_decode(test_reference_responses, skip_special_tokens=True) + all_decode_test_reference_responses = tokenizer.batch_decode( + test_reference_responses, skip_special_tokens=True + ) all_decode_test_responses = [ x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) ] @@ -518,8 +512,6 @@ def process_query_data1(x): except Exception as e: print(e) - - # save model if accelerator.is_main_process and args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) From 6b2d2ecbd2765777d3f4262489595ee08d36774b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 25 Sep 2023 16:54:08 -0400 Subject: [PATCH 08/62] quick change --- .../train_sft_accelerate_summarize.py | 109 ++++++------------ 1 file changed, 36 insertions(+), 73 deletions(-) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index d0fdc77..50a8684 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -29,44 +29,19 @@ @dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class 
RewardHParams: - kl_coef: float = 0.15 - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 8 - local_mini_batch_size: tyro.conf.Suppress[int] = None +class SFTHParams: + gradient_accumulation_steps: int = 32 + local_micro_batch_size: int = 8 + noptepochs: int = 1 + lr: float = 0.00001 + eps: float = 1e-5 + total_episodes: tyro.conf.Suppress[int] = None + local_batch_size:tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 32 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" world_size: tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True @dataclass @@ -138,8 +113,7 @@ class Args: use_tensorflow_adam: bool = True """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) + sft: SFTHParams = field(default_factory=SFTHParams) def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: @@ -350,15 +324,11 @@ def forward(policy, query_responses, tokenizer): ) -# def train(args: Args): -if __name__ == "__main__": - args = tyro.cli(Args) - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - # args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - # args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - # args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) + args.sft.world_size = accelerator.num_processes + args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps + args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) patch_h = TaskQueryHParams( length=args.task.query_length, dataset=args.task.query_dataset, @@ -368,13 +338,10 @@ def forward(policy, query_responses, tokenizer): padding=args.task.query_padding, pad_side=args.task.query_pad_side, ) - # if args.ppo.whiten_rewards: - # assert ( - # args.ppo.local_mini_batch_size >= 8 - # ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + dataset = load_dataset(args.task.query_dataset, split="train") + test_dataset = load_dataset(args.task.query_dataset, split="test") + 
args.sft.total_episodes = len(dataset) + args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -419,11 +386,9 @@ def forward(policy, query_responses, tokenizer): # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically # see https://github.com/pytorch/pytorch/issues/104857 for more details if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) else: - optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = load_dataset(args.task.query_dataset, split="train") - test_dataset = load_dataset(args.task.query_dataset, split="test") + optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) def process_query_data1(x): return { @@ -439,7 +404,7 @@ def process_query_data1(x): test_dataset = test_dataset.map(process_query_data1) test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) test_dataset = test_dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + dataloader = DataLoader(dataset, batch_size=args.sft.local_batch_size) policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) iter_dataloader = iter(dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated @@ -455,19 +420,15 @@ def process_query_data1(x): print("===training policy===") global_step = 0 - stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) - approxkls_stats = torch.zeros(stats_shape, device=device) - clipfracs_stats = torch.zeros(stats_shape, device=device) - pg_losses_stats = torch.zeros(stats_shape, device=device) - vf_losses_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropies_stats = torch.zeros(stats_shape, device=device) + loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) test_data = test_dataset[0:10] test_data = {k: v.to(device) for k, v in test_data.items()} - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr + gradient_accumulation_idx = 0 + for update in range(1, args.sft.num_updates + 1): + print(update, global_step) + global_step += 1 * args.sft.batch_size + frac = 1.0 - (update - 1.0) / args.sft.num_updates + lrnow = frac * args.sft.lr optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) queries = data["query_token"].to(device) @@ -480,9 +441,12 @@ def process_query_data1(x): accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - if (update - 1) % args.ppo.gradient_accumulation_steps: - writer.add_scalar("loss", loss.item(), update) - if (update - 1) % args.print_sample_output_freq * args.ppo.gradient_accumulation_steps == 0: + loss_stats[gradient_accumulation_idx] = loss + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps + if (update - 1) % args.sft.gradient_accumulation_steps: + writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) + writer.add_scalar("lr", lrnow, update) + if 
(update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: with torch.no_grad(): test_queries = test_data["query_token"] test_reference_responses = test_data["reference_response"] @@ -508,7 +472,7 @@ def process_query_data1(x): ) if accelerator.is_main_process and args.track: wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table("stuff", all_df[:4], console) + print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) except Exception as e: print(e) @@ -520,10 +484,9 @@ def process_query_data1(x): if args.upload_model: repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) tokenizer.save_pretrained(repo_id, push_to_hub=True) - if __name__ == "__main__": args = tyro.cli(Args) train(args) From e8431668efbcd62482c99557fd63a152fe654169 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 25 Sep 2023 21:39:20 +0000 Subject: [PATCH 09/62] normalize based on reference response --- .../train_reward_accelerate_summarize.py | 15 ++++++++++----- .../train_sft_accelerate_summarize.py | 6 +++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index 4c29fa8..43ed638 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -406,8 +406,9 @@ def normalize( for _ in range(n_batches): data = next(iter_dataloader) queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(data["query_token"], args.pad_token_id).to(device) - query_responses = generate(lm_backbone, queries, args, generation_config) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id).to(device) sample_queries_responses.append(query_responses) # compute reward statistics @@ -433,8 +434,9 @@ def normalize( for _ in range(n_batches): data = next(iter_dataloader) queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(data["query_token"], args.pad_token_id).to(device) - query_responses = generate(lm_backbone, queries, args, generation_config) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id).to(device) sample_queries_responses.append(query_responses) rewards = [] for query_responses in sample_queries_responses: @@ -532,10 +534,13 @@ def train(args: Args): def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + x["summary"], padding="max_length", max_length=args.task.response_length, truncation=True + ), } dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token"]) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) dataset = dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, 
batch_size=args.local_rollout_batch_size) reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 50a8684..e56af3b 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -30,10 +30,10 @@ @dataclass class SFTHParams: - gradient_accumulation_steps: int = 32 + gradient_accumulation_steps: int = 2 local_micro_batch_size: int = 8 noptepochs: int = 1 - lr: float = 0.00001 + lr: float = 6.35e-5 eps: float = 1e-5 total_episodes: tyro.conf.Suppress[int] = None local_batch_size:tyro.conf.Suppress[int] = None @@ -106,7 +106,7 @@ class Args: """the name of the pretrained model to use""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 40 + print_sample_output_freq: int = 80 """How often to print sample output""" save_path: str = "models/policy.pt" """Where to save the model""" From 1deff4ea22f28fc4b560c860eba80640cf816125 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 25 Sep 2023 17:40:10 -0400 Subject: [PATCH 10/62] quick push --- .../train_sft_accelerate_summarize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index e56af3b..2ccd6ff 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -390,7 +390,7 @@ def train(args: Args): else: optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - def process_query_data1(x): + def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( @@ -398,10 +398,10 @@ def process_query_data1(x): ), } - dataset = dataset.map(process_query_data1) + dataset = dataset.map(process_query_data) dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) dataset = dataset.shuffle(seed=local_seed) - test_dataset = test_dataset.map(process_query_data1) + test_dataset = test_dataset.map(process_query_data) test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) test_dataset = test_dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.sft.local_batch_size) From 8d89722ab17e35aaa3426ced0d68a26729376686 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 27 Sep 2023 13:47:03 +0000 Subject: [PATCH 11/62] fix --- .../train_policy_accelerate_summarize.py | 18 ++++-- .../train_reward_accelerate_summarize.py | 57 +++++++++---------- .../train_sft_accelerate_summarize.py | 21 +++---- 3 files changed, 52 insertions(+), 44 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize.py b/lm_human_preference_details/train_policy_accelerate_summarize.py index 2d8b456..0ffcd00 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize.py @@ -40,6 +40,7 @@ class AdaptiveKLParams: @dataclass class RewardHParams: kl_coef: float = 0.15 + use_adaptive_kl: bool = True adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "models/reward.pt" label_dataset: 
tyro.conf.Suppress[Optional[str]] = None @@ -136,6 +137,8 @@ class Args: """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 10 """How often to print sample output""" + sft_model_path: str = "models/sft_policy.pt" + """Where to load the SFT model""" save_path: str = "models/policy.pt" """Where to save the model""" use_tensorflow_adam: bool = True @@ -513,6 +516,10 @@ def train(args: Args): AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) ) policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + if args.sft_model_path: + policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") policy.lm_backbone.generation_config.eos_token_id = ( None # disable `pad_token_id` and `eos_token_id` because we just want to ) @@ -574,7 +581,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, - temperature=args.task.temperature, + temperature=(args.task.temperature + 1e-7), top_k=0.0, top_p=1.0, do_sample=True, @@ -610,7 +617,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, full_values = forward(policy, query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + logits /= (args.task.temperature + 1e-7) all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -618,7 +625,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output, _ = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature + ref_logits /= (args.task.temperature + 1e-7) ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -752,7 +759,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(policy, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + logits /= (args.task.temperature + 1e-7) new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) @@ -842,7 +849,8 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) writer.add_scalar("ppo/lr", lrnow, update) writer.add_scalar("ppo/episode", global_step, update) - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores # save model diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index 43ed638..fabae52 100644 --- 
a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -124,7 +124,7 @@ class Args: """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" normalize_after: bool = True """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 60 + print_sample_output_freq: int = 120 """How often to print sample output""" save_path: str = "models/reward.pt" """Where to save the model""" @@ -393,23 +393,24 @@ def normalize( device, lm_backbone, reward_model, - iter_dataloader, - generation_config, + dataloader, + validation_dataloader, ): + idx = 0 with torch.no_grad(): # reset reward scales accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) # number of minibatches for computing the normalization statistics - n_batches = ceil_div(args.local_normalize_samples, args.local_rollout_batch_size) sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) + for data in dataloader: + idx += len(data["query_token"]) queries = data["query_token"].to(device) reference_response = data["reference_response"].to(device) query_responses = torch.cat((queries, reference_response), dim=1) query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id).to(device) sample_queries_responses.append(query_responses) + accelerator.print(f"====number of samples per device: {idx}") # compute reward statistics rewards = [] @@ -429,10 +430,8 @@ def normalize( accelerator.unwrap_model(reward_model).reward_bias.data = bias # validate normalization - n_batches = ceil_div(args.local_normalize_samples, args.local_rollout_batch_size) sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) + for data in validation_dataloader: queries = data["query_token"].to(device) reference_response = data["reference_response"].to(device) query_responses = torch.cat((queries, reference_response), dim=1) @@ -530,6 +529,7 @@ def train(args: Args): else: optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") def process_query_data(x): return { @@ -543,6 +543,10 @@ def process_query_data(x): dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) dataset = dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + validation_dataset = validation_dataset.map(process_query_data) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.shuffle(seed=local_seed) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) if args.deepspeed: import deepspeed @@ -569,11 +573,7 @@ def process_query_data(x): else: untrained_model = untrained_model.to(device) - def repeat_generator(): # TODO: ideally we shuffle the dataloader as well - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) + iter_dataloader = 
iter(dataloader) generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, @@ -597,8 +597,8 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well device, untrained_model.lm_backbone, reward_model, - iter_dataloader, - generation_config, + dataloader, + validation_dataloader, ) print( "after normalization. " @@ -608,9 +608,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` label = load_dataset(args.label_dataset, "comparisons", split="train") - test_label = load_dataset(args.label_dataset, "comparisons", split="validation") - print("Num labels found in source:", len(label)) - print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") + accelerator.print("Num labels found in source:", len(label)) + accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) def process_response_data(x): return { @@ -624,10 +624,10 @@ def process_response_data(x): } label = label.map(process_response_data) - test_label = test_label.map(process_response_data) + validation_label = validation_label.map(process_response_data) # tokenizer.encode(label[0]["summaries"][0]["text"]) - print("===training reward model===") + accelerator.print("===training reward model===") all_inds = np.random.permutation(args.labels.num_train) # ensure that all processes have the same shuffled indices all_inds = broadcast(torch.tensor(all_inds, device=device), 0) @@ -684,13 +684,13 @@ def process_response_data(x): writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) writer.add_scalar("train/accuracy", train_accuracy, global_step) writer.add_scalar("train/lr", lr, global_step) - print("train/accuracy", train_accuracy) + accelerator.print("train/accuracy", train_accuracy) if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: with torch.no_grad(): - # eval on test_label, some duplicate code (I don't want to make the training loop into a function...) + # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
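The train/accuracy and test/accuracy numbers logged in this loop come from two-way human comparisons, where "choice" marks the preferred summary. The loss itself sits in an unchanged part of the file; below is a minimal, self-contained sketch of the preference objective typically used for such comparison data, with `rewards` standing in for the reward-head scalars of the two candidate responses (the values are hypothetical):

import torch
import torch.nn.functional as F

# Hypothetical reward-head outputs for two candidate summaries per query.
rewards = torch.tensor([[0.3, 1.2], [0.9, -0.4]])   # shape: (batch, num_labels=2)
mb_best = torch.tensor([1, 0])                       # index of the human-preferred summary
loss = F.cross_entropy(rewards, mb_best)             # with 2 labels this equals -log sigmoid(r_chosen - r_rejected)
accuracy = (rewards.argmax(dim=-1) == mb_best).float().mean()
print(loss.item(), accuracy.item())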
test_accuracies = [] - eval_len = 200 # len(test_label) + eval_len = len(validation_label) len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full new_all_inds = np.arange(len_labels) for start in range(0, len_labels, args.batch_size): @@ -700,7 +700,7 @@ def process_response_data(x): for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): micro_batch_end = micro_batch_start + args.local_micro_batch_size micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] + mb_data = validation_label[micro_batch_inds] mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ @@ -726,8 +726,7 @@ def process_response_data(x): test_accuracies.append(accuracy) test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() writer.add_scalar("test/accuracy", test_accuracy, global_step) - if accelerator.is_main_process: - print("test/accuracy", test_accuracy, global_step) + accelerator.print("test/accuracy", test_accuracy, global_step) # the part below is testing out some generations and KLs, not presented in the original code data = next(iter_dataloader) @@ -801,8 +800,8 @@ def process_response_data(x): device, untrained_model.lm_backbone, reward_model, - iter_dataloader, - generation_config, + dataloader, + validation_dataloader, ) print( "after normalization. " diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 2ccd6ff..d0abedd 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -30,8 +30,8 @@ @dataclass class SFTHParams: - gradient_accumulation_steps: int = 2 - local_micro_batch_size: int = 8 + gradient_accumulation_steps: int = 16 + local_micro_batch_size: int = 1 noptepochs: int = 1 lr: float = 6.35e-5 eps: float = 1e-5 @@ -108,7 +108,7 @@ class Args: """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 80 """How often to print sample output""" - save_path: str = "models/policy.pt" + save_path: str = "models/sft_policy.pt" """Where to save the model""" use_tensorflow_adam: bool = True """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" @@ -394,7 +394,8 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - x["summary"], padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response ), } @@ -404,7 +405,7 @@ def process_query_data(x): test_dataset = test_dataset.map(process_query_data) test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) test_dataset = test_dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.sft.local_batch_size) + dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) iter_dataloader = iter(dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated @@ -420,13 +421,13 @@ def process_query_data(x): 
print("===training policy===") global_step = 0 - loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) test_data = test_dataset[0:10] test_data = {k: v.to(device) for k, v in test_data.items()} + loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) gradient_accumulation_idx = 0 for update in range(1, args.sft.num_updates + 1): - print(update, global_step) global_step += 1 * args.sft.batch_size + accelerator.print(f"update {update}, global_step {global_step}") frac = 1.0 - (update - 1.0) / args.sft.num_updates lrnow = frac * args.sft.lr optimizer.param_groups[0]["lr"] = lrnow @@ -443,7 +444,7 @@ def process_query_data(x): optimizer.zero_grad() loss_stats[gradient_accumulation_idx] = loss gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps - if (update - 1) % args.sft.gradient_accumulation_steps: + if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) writer.add_scalar("lr", lrnow, update) if (update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: @@ -451,7 +452,7 @@ def process_query_data(x): test_queries = test_data["query_token"] test_reference_responses = test_data["reference_response"] test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) - generated_responses = generate(policy, test_queries, tokenizer, generation_config) + generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) try: all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) @@ -479,7 +480,7 @@ def process_query_data(x): # save model if accelerator.is_main_process and args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(policy.state_dict(), args.save_path) + torch.save(accelerator.unwrap_model(policy).state_dict(), args.save_path) if args.upload_model: repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" From 7866eb232f789d0278ac3438df20a84500c31695 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 3 Oct 2023 18:11:48 +0000 Subject: [PATCH 12/62] push changes so far --- ...in_policy_accelerate_summarize_separate.py | 1000 +++++++++++++++++ .../train_reward_accelerate_summarize.py | 76 +- .../train_sft_accelerate_summarize.py | 79 +- 3 files changed, 1088 insertions(+), 67 deletions(-) create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate.py diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py new file mode 100644 index 0000000..54e8115 --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -0,0 +1,1000 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + 
_get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + +INVALID_LOGPROB = 1.0 + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + sft_model_path: str = "models/sft_policy.pt" + """Where to load the SFT model""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + 
exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + last_reward_latents = reward_latents + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + # shape: [batch_size, 1] + reward = self.reward_gain * reward + self.reward_bias + return output, reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. 
TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, args): + attention_mask = query_responses != args.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def get_reward_complete(reward_model, query_responses, args): + reward = get_reward(reward_model, query_responses, args)[1] + last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + + +def normalize( + tokenizer, + accelerator, + device, + lm_backbone, + reward_model, + dataloader, + validation_dataloader, +): + idx = 0 + with torch.no_grad(): + # reset reward scales + # accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) + # accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) + # number of minibatches for computing the normalization statistics + rewards = [] + for data in dataloader: + idx += len(data["query_token"]) + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + accelerator.print(score.shape, accelerator.gather(score).mean()) + rewards.append(score) + accelerator.print(f"====number of samples per device: {idx}") + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + accelerator.unwrap_model(reward_model).reward_gain.data = gain + accelerator.unwrap_model(reward_model).reward_bias.data = bias + + # validate normalization + rewards = [] + for data in validation_dataloader: + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + 
return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + if args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + 
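As a reference for the PPO batch partitioning set up at the top of this script, here is a minimal sketch (not part of the patch) of how batch_size, local_mini_batch_size, and local_micro_batch_size decompose via exact_div, assuming the PpoHParams defaults shown above and a hypothetical 2-process run:

def exact_div(a, b):
    # Same helper as in the patch: integer division that refuses to lose a remainder.
    q = a // b
    if a != q * b:
        raise ValueError(f"Inexact division: {a} / {b} = {a / b}")
    return q

world_size = 2                              # hypothetical number of ranks
local_batch_size = 64
nminibatches = 1
gradient_accumulation_steps = 1
batch_size = local_batch_size * world_size                                               # 128 rollouts per PPO iteration
local_mini_batch_size = exact_div(local_batch_size, nminibatches)                        # 64
local_micro_batch_size = exact_div(local_mini_batch_size, gradient_accumulation_steps)   # 64
assert local_mini_batch_size >= 8           # required when whiten_rewards=True, per the assertion above
print(batch_size, local_mini_batch_size, local_micro_batch_size)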
print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = validation_dataset.map(process_query_data) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.shuffle(seed=local_seed) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + # print("===Normalize reward model *before* training===") + # print( + # "before normalization. 
" + # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + # ) + + # normalize( + # tokenizer, + # accelerator, + # device, + # reward_model, + # reward_model, + # dataloader, + # validation_dataloader, + # ) + # print( + # "after normalization. " + # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + # ) + # # # save model + # # if args.save_path: + # # os.makedirs(os.path.dirname("models/correct_reward.pt"), exist_ok=True) + # # torch.save(accelerator.unwrap_model(reward_model).state_dict(), "models/correct_reward.pt") + # raise + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkls_stats = torch.zeros(stats_shape, device=device) + clipfracs_stats = torch.zeros(stats_shape, device=device) + pg_losses_stats = torch.zeros(stats_shape, device=device) + vf_losses_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropies_stats = torch.zeros(stats_shape, device=device) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + """ + let's use `P` to denote the padding token, `T` to denote the truncate token, and `X` to denote the + actual tokens. + queries: `PPXXX` + query_responses: `PPXXX,XXXXTXX` # the space separates the query and response + response: `XXXXTXX` + postprocessed_responses: `XXXXTXX` -> `XXXXTPP` + postprocessed_query_responses: `PPXXX,XXXXTPP` + scores: ↑ # corresponding to this `X` token + + """ + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) + query_reference_responses = torch.cat((queries, reference_responses), dim=1) + + + + reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) + accelerator.print(accelerator.gather(reference_scores).mean()) + + + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + full_values = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer)[1] + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + padding_mask = (postprocessed_query_responses == tokenizer.pad_token_id)[:, context_length - 1 : -1] + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + values = torch.masked_fill(values, padding_mask, 0) + + scores = get_reward_complete(reward_model, postprocessed_query_responses, tokenizer) + reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) + + # 3. filter response. Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. 
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) + all_postprocessed_query_responses = tokenizer.batch_decode( + postprocessed_query_responses, skip_special_tokens=True + ) + all_postprocessed_responses = [ + x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) + ] + all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) + + kl_sum = kl.sum(axis=1) + all_df = pd.DataFrame( + { + "query": all_decode_queries, + "response": all_postprocessed_responses, + "reference_responses": all_reference_responses, + "score": scores.float().cpu().numpy(), + "reference_scores": reference_scores.float().cpu().numpy(), + "kl": kl_sum.float().cpu().numpy(), + "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table("stuff", all_df[:4], console) + except Exception as e: + print(e) + del ( + all_decode_queries, + all_postprocessed_query_responses, + all_postprocessed_responses, + kl_sum, + all_df, + ) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + # output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + output, (_, vpred_temp) = forward(model, mb_query_responses, tokenizer) + + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], 
INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), 
update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) + writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if accelerator.is_main_process and args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(policy.state_dict(), args.save_path) + + if args.upload_model: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index fabae52..4abcd4c 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -124,7 +124,7 @@ class Args: """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. 
(For comparisons, just use mean 0, var 1.)""" normalize_after: bool = True """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 120 + print_sample_output_freq: int = 506 """How often to print sample output""" save_path: str = "models/reward.pt" """Where to save the model""" @@ -134,9 +134,6 @@ class Args: labels: LabelHParams = field(default_factory=LabelHParams) -OPENAI_PAD_TOKEN_ID = 50259 - - def first_true_indices(bools, dtype=torch.long): """ Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving @@ -356,10 +353,10 @@ def exact_div(a, b): return q -def generate(lm_backbone, queries, args, generation_config): +def generate(lm_backbone, queries, tokenizer, generation_config): """generate in a way that does not affect padding tokens""" context_length = queries.shape[1] - attention_mask = queries != args.pad_token_id + attention_mask = queries != tokenizer.pad_token_id input_ids = queries.clone() input_ids[~attention_mask] = 0 # set padding tokens to 0 output = lm_backbone.generate( @@ -387,8 +384,18 @@ def get_reward(reward_model, query_responses, args): ) +def get_reward_complete(reward_model, query_responses, args): + reward = get_reward(reward_model, query_responses, args)[1] + last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + + def normalize( - args, + tokenizer, accelerator, device, lm_backbone, @@ -402,20 +409,16 @@ def normalize( accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) # number of minibatches for computing the normalization statistics - sample_queries_responses = [] + rewards = [] for data in dataloader: idx += len(data["query_token"]) queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) reference_response = data["reference_response"].to(device) query_responses = torch.cat((queries, reference_response), dim=1) - query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id).to(device) - sample_queries_responses.append(query_responses) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) accelerator.print(f"====number of samples per device: {idx}") - - # compute reward statistics - rewards = [] - for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, args)[1]) rewards = torch.cat(rewards) rewards = accelerator.gather(rewards) mean, std = rewards.mean(), rewards.std() @@ -430,16 +433,14 @@ def normalize( accelerator.unwrap_model(reward_model).reward_bias.data = bias # validate normalization - sample_queries_responses = [] + rewards = [] for data in validation_dataloader: queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) reference_response = data["reference_response"].to(device) query_responses = torch.cat((queries, reference_response), dim=1) - query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id).to(device) - sample_queries_responses.append(query_responses) - rewards = [] 
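    # --- illustrative sketch (editorial, not part of the patch) ------------------------
    # The surrounding normalize() hunks now score each query + reference summary with
    # get_reward_complete (one scalar per sequence, read off at the position just before
    # the first pad token) and then fit reward_gain / reward_bias so the gathered scores
    # land at the target mean 0 / std 1. A self-contained check of that affine step,
    # with synthetic scores standing in for the gathered rewards:
    #
    #     import torch
    #     rewards = torch.randn(1024) * 3.2 + 1.7        # stand-in for gathered raw scores
    #     mean, std = rewards.mean(), rewards.std()
    #     gain = 1.0 / std                               # target_std / std
    #     bias = 0.0 - gain * mean                       # target_mean - gain * mean
    #     normalized = gain * rewards + bias
    #     print(normalized.mean().item(), normalized.std().item())   # ~0.0, ~1.0
    # ------------------------------------------------------------------------------------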
- for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, args)[1]) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) rewards = torch.cat(rewards) rewards = accelerator.gather(rewards) mean, std = rewards.mean(), rewards.std() @@ -535,7 +536,8 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - x["summary"], padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response ), } @@ -592,7 +594,7 @@ def process_query_data(x): ) normalize( - args, + tokenizer, accelerator, device, untrained_model.lm_backbone, @@ -658,13 +660,8 @@ def process_response_data(x): predicted_rewards = [] for i in range(args.labels.num_labels): query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - reward = get_reward(reward_model, query_responses, args)[1] - last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - predicted_rewards.append(reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1)) + score = get_reward_complete(reward_model, query_responses, args) + predicted_rewards.append(score) predicted_rewards = torch.stack( predicted_rewards, dim=1 ) # shape (batch_size, num_labels), basically a reward prediction for each label @@ -702,6 +699,7 @@ def process_response_data(x): micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] mb_data = validation_label[micro_batch_inds] mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) @@ -710,13 +708,8 @@ def process_response_data(x): predicted_rewards = [] for i in range(args.labels.num_labels): query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - reward = get_reward(reward_model, query_responses, args)[1] - last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - predicted_rewards.append(reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1)) + score = get_reward_complete(reward_model, query_responses, args) + predicted_rewards.append(score) predicted_rewards = torch.stack( predicted_rewards, dim=1 ) # shape (batch_size, num_labels), basically a reward prediction for each label @@ -736,12 +729,12 @@ def process_response_data(x): query_responses = generate( accelerator.unwrap_model(reward_model).lm_backbone, queries, - args, + tokenizer, generation_config, ) responses = query_responses[:, context_length:] - output, reward = get_reward(reward_model, query_responses, args) + output, _ = get_reward(reward_model, query_responses, args) logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature all_logprobs = F.log_softmax(logits, dim=-1) @@ -795,7 +788,7 @@ def process_response_data(x): ) normalize( - args, + 
tokenizer, accelerator, device, untrained_model.lm_backbone, @@ -812,7 +805,8 @@ def process_response_data(x): # save model if args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + accelerator.save_model(reward_model, args.save_path) if accelerator.is_main_process and args.track: wandb.finish() diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index d0abedd..b9c7b66 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -1,3 +1,4 @@ +import collections import os import random import time @@ -10,6 +11,7 @@ import torch import torch.optim as optim import tyro +import evaluate from accelerate import Accelerator from datasets import load_dataset from rich.console import Console @@ -65,7 +67,7 @@ class TaskHParams: penalty_reward_value: int = -1 # LM params - temperature: float = 0.7 + temperature: float = 0.01 # a patch @@ -106,7 +108,7 @@ class Args: """the name of the pretrained model to use""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 80 + print_sample_output_freq: int = 220 """How often to print sample output""" save_path: str = "models/sft_policy.pt" """Where to save the model""" @@ -340,6 +342,8 @@ def train(args: Args): ) dataset = load_dataset(args.task.query_dataset, split="train") test_dataset = load_dataset(args.task.query_dataset, split="test") + accelerator.print("The number of samples in dataset", len(dataset)) + accelerator.print("The number of samples in test_dataset", len(test_dataset)) args.sft.total_episodes = len(dataset) args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size @@ -406,7 +410,8 @@ def process_query_data(x): test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) test_dataset = test_dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) + policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) iter_dataloader = iter(dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens @@ -418,6 +423,7 @@ def process_query_data(x): top_p=1.0, do_sample=True, ) + rouge = evaluate.load("rouge") print("===training policy===") global_step = 0 @@ -425,11 +431,18 @@ def process_query_data(x): test_data = {k: v.to(device) for k, v in test_data.items()} loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) gradient_accumulation_idx = 0 + + # Given parameters + eta_min = 0 + eta_max = 6.35e-5 + T_max = args.sft.num_updates + for update in range(1, args.sft.num_updates + 1): global_step += 1 * args.sft.batch_size accelerator.print(f"update {update}, global_step {global_step}") - frac = 1.0 - (update - 1.0) / args.sft.num_updates - lrnow = frac * args.sft.lr + # frac = 1.0 - (update - 1.0) / args.sft.num_updates + # lrnow = frac * args.sft.lr + lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) queries = data["query_token"].to(device) @@ -448,13 +461,15 @@ def process_query_data(x): writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) writer.add_scalar("lr", lrnow, update) if (update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: - with torch.no_grad(): - test_queries = test_data["query_token"] - test_reference_responses = test_data["reference_response"] - test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) - generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) - - try: + rouge_scores = collections.defaultdict(list) + for test_idx, test_data in enumerate(test_dataloader): + with torch.no_grad(): + test_queries = test_data["query_token"].to(device) + test_reference_responses = test_data["reference_response"].to(device) + test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) + generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) + accelerator.print(update, test_idx) + all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) all_decode_test_reference_responses = tokenizer.batch_decode( @@ -463,24 +478,36 @@ def process_query_data(x): all_decode_test_responses = [ x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) ] - - all_df = pd.DataFrame( - { - "query": all_decode_test_queries, - "response": all_decode_test_responses, - "reference": all_decode_test_reference_responses, - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) - except Exception as e: - print(e) + rouge_score = rouge.compute(predictions=all_decode_test_responses, references=all_decode_test_reference_responses) + rouge_scores["rouge1"].append(rouge_score["rouge1"]) + rouge_scores["rouge2"].append(rouge_score["rouge2"]) + rouge_scores["rougeL"].append(rouge_score["rougeL"]) + + if test_idx == 0: + try: + all_df = pd.DataFrame( + { + "query": all_decode_test_queries, + "response": all_decode_test_responses, + "reference": all_decode_test_reference_responses, + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": 
wandb.Table(dataframe=all_df)}, step=update) + print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) + except Exception as e: + print(e) + + for k, v in rouge_scores.items(): + rouge_metric = torch.tensor(v, device=device) + rouge_metric = accelerator.gather(rouge_metric) + writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) + accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") # save model if accelerator.is_main_process and args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(accelerator.unwrap_model(policy).state_dict(), args.save_path) + accelerator.save_model(policy, args.save_path) if args.upload_model: repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" From a2187cfacb93c1d409da9a52ca4a9cad993eeb6f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 4 Oct 2023 13:10:54 +0000 Subject: [PATCH 13/62] update --- benchmark/summarize.slurm_template | 18 +++ ...in_policy_accelerate_summarize_separate.py | 108 ++++++++++++------ .../train_sft_accelerate_summarize.py | 2 +- .../train_summarize.sh | 40 +++++++ 4 files changed, 134 insertions(+), 34 deletions(-) create mode 100644 benchmark/summarize.slurm_template create mode 100644 lm_human_preference_details/train_summarize.sh diff --git a/benchmark/summarize.slurm_template b/benchmark/summarize.slurm_template new file mode 100644 index 0000000..035feb7 --- /dev/null +++ b/benchmark/summarize.slurm_template @@ -0,0 +1,18 @@ +#!/bin/bash +#SBATCH --job-name=lm_human_preference_details +#SBATCH --partition=production-cluster +#SBATCH --gpus-per-task={{gpus_per_task}} +#SBATCH --cpus-per-gpu={{cpus_per_gpu}} +#SBATCH --ntasks={{ntasks}} +#SBATCH --output=slurm/logs/%x_%j.out +#SBATCH --array={{array}} +#SBATCH --exclude=ip-26-0-149-199 +#SBATCH --exclusive + +{{nodes}} + +seeds={{seeds}} +seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]} + +echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed" +SEED=$seed srun {{command}} diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 54e8115..525c426 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -617,9 +617,9 @@ def process_query_data(x): dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) validation_dataset = validation_dataset.map(process_query_data) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) - validation_dataset = validation_dataset.shuffle(seed=local_seed) validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) if args.deepspeed: import deepspeed @@ -652,6 +652,19 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well while True: yield from dataloader + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation = {k: v.to(device) for k, v in sample_validation.items()} + sample_validation_queries = 
sample_validation["query_token"] + with torch.no_grad(): + print(sample_validation_queries.shape) + sample_validation_queries = right_padding_to_left_padding(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = sample_validation["reference_response"] + sample_validation_query_reference_responses = torch.cat((sample_validation_queries, sample_validation_reference_response), dim=1) + sample_validation_reference_scores = get_reward_complete(reward_model, sample_validation_query_reference_responses, tokenizer) + + iter_dataloader = iter(repeat_generator()) kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated @@ -723,13 +736,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well reference_responses = data["reference_response"].to(device) queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) query_reference_responses = torch.cat((queries, reference_responses), dim=1) - - - - reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) - accelerator.print(accelerator.gather(reference_scores).mean()) - - query_responses = generate( accelerator.unwrap_model(model).policy, queries, @@ -739,6 +745,15 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well context_length = queries.shape[1] responses = query_responses[:, context_length:] + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + generation_config, + ) + + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) full_values = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer)[1] values = full_values[:, context_length - 1 : -1].squeeze(-1) @@ -787,6 +802,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well scores = get_reward_complete(reward_model, postprocessed_query_responses, tokenizer) reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) + # note that we do not truncate the validation responses + validation_score = get_reward_complete(reward_model, sample_validation_query_responses, tokenizer) + accelerator.print(accelerator.gather(reference_scores).mean()) # 3. filter response. 
Ensure that the sample contains truncate_token # responses not passing that filter will receive a low (fixed) score @@ -813,39 +831,62 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: - all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) - all_postprocessed_query_responses = tokenizer.batch_decode( - postprocessed_query_responses, skip_special_tokens=True + # all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) + # all_postprocessed_query_responses = tokenizer.batch_decode( + # postprocessed_query_responses, skip_special_tokens=True + # ) + # all_postprocessed_responses = [ + # x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) + # ] + # all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) + + # kl_sum = kl.sum(axis=1) + # all_df = pd.DataFrame( + # { + # "query": all_decode_queries, + # "response": all_postprocessed_responses, + # "reference_responses": all_reference_responses, + # "score": scores.float().cpu().numpy(), + # "reference_scores": reference_scores.float().cpu().numpy(), + # "kl": kl_sum.float().cpu().numpy(), + # "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), + # } + # ) + # if accelerator.is_main_process and args.track: + # wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) + # print_rich_table("stuff", all_df[:4], console) + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_query_responses = tokenizer.batch_decode( + sample_validation_query_responses, skip_special_tokens=True ) - all_postprocessed_responses = [ - x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) + all_sample_validation_responses = [ + x[len(y) :] for x, y in zip(all_sample_validation_query_responses, all_decode_validation_queries) ] - all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) - - kl_sum = kl.sum(axis=1) - all_df = pd.DataFrame( + all_sample_validation_reference_responses = tokenizer.batch_decode( + sample_validation_reference_response, skip_special_tokens=True + ) + all_sample_validation_df = pd.DataFrame( { - "query": all_decode_queries, - "response": all_postprocessed_responses, - "reference_responses": all_reference_responses, - "score": scores.float().cpu().numpy(), - "reference_scores": reference_scores.float().cpu().numpy(), - "kl": kl_sum.float().cpu().numpy(), - "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), } ) if accelerator.is_main_process and args.track: - wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table("stuff", all_df[:4], console) + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) + except Exception as e: print(e) - del ( - all_decode_queries, - all_postprocessed_query_responses, - all_postprocessed_responses, - kl_sum, - all_df, - ) + # del ( + # all_decode_queries, + # 
all_postprocessed_query_responses, + # all_postprocessed_responses, + # kl_sum, + # all_df, + # ) del postprocessed_query_responses torch.cuda.empty_cache() @@ -954,6 +995,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ) writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index b9c7b66..40c512c 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -505,7 +505,7 @@ def process_query_data(x): accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") # save model - if accelerator.is_main_process and args.save_path: + if args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) accelerator.save_model(policy, args.save_path) diff --git a/lm_human_preference_details/train_summarize.sh b/lm_human_preference_details/train_summarize.sh new file mode 100644 index 0000000..095f5a3 --- /dev/null +++ b/lm_human_preference_details/train_summarize.sh @@ -0,0 +1,40 @@ +# generate random seed and model paths +# set seed if not found in env +if [ -z "$SEED" ]; then + SEED=$RANDOM +fi + +REWARD_MODEL_PATH=models/reward_model_$SEED +SFT_MODEL_PATH=models/sft_model_$SEED +poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/train_reward_accelerate_summarize.py \ + --base_model=gpt2-xl \ + --local_rollout_batch_size=4 \ + --gradient_accumulation_steps=4 \ + --save_path=$REWARD_MODEL_PATH \ + --labels.num_train=92832 \ + --seed=$SEED \ + --deepspeed \ + --track + +poetry run accelerate launch --config_file deepspeed.yaml \ + --num_processes 8 lm_human_preference_details/train_sft_accelerate_summarize.py \ + --base_model=gpt2-xl \ + --save_path=$SFT_MODEL_PATH \ + --seed=$SEED \ + --deepspeed \ + --track + +poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/train_policy_accelerate_summarize_separate.py \ + --base_model=gpt2-xl \ + --sft_model_path=$SFT_MODEL_PATH/pytorch_model.bin \ + --ppo.gradient_accumulation_steps=64 \ + --ppo.lr=1.5e-5 \ + --rewards.kl_coef=0.05 \ + --rewards.no_use_adaptive_kl \ + --rewards.trained_model=$REWARD_MODEL_PATH/pytorch_model.bin \ + --task.temperature=1.0 \ + --seed=$SEED \ + --deepspeed \ + --track \ \ No newline at end of file From 09a705eecff03fb1f830fe9e6447b55e5a0e3e8e Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Oct 2023 16:01:52 +0000 Subject: [PATCH 14/62] test --- .../train_policy_accelerate_summarize_separate.py | 3 ++- .../train_reward_accelerate_summarize.py | 3 ++- lm_human_preference_details/train_sft_accelerate_summarize.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py 
index 525c426..fc4bbcc 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -603,10 +603,11 @@ def forward(policy, query_responses, tokenizer): validation_dataset = load_dataset(args.task.query_dataset, split="validation") def process_query_data(x): + pad_summary_w_leading_space = " " + x['summary'] return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, # with an extra leading space to account for the space between the query and response ), } diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index 4abcd4c..b5abc11 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -533,10 +533,11 @@ def train(args: Args): validation_dataset = load_dataset(args.task.query_dataset, split="validation") def process_query_data(x): + pad_summary_w_leading_space = " " + x['summary'] return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, # with an extra leading space to account for the space between the query and response ), } diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 40c512c..8882136 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -395,10 +395,11 @@ def train(args: Args): optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) def process_query_data(x): + pad_summary_w_leading_space = " " + x['summary'] return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, # with an extra leading space to account for the space between the query and response ), } From 2bcf98d49be0f08f6117e4ad9c80b8878daed716 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Oct 2023 16:03:06 +0000 Subject: [PATCH 15/62] add eos token in reference response --- .../train_policy_accelerate_summarize_separate.py | 3 +-- .../train_reward_accelerate_summarize.py | 7 +++---- .../train_sft_accelerate_summarize.py | 3 +-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index fc4bbcc..964fdd3 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -603,11 +603,10 @@ def forward(policy, query_responses, tokenizer): validation_dataset = load_dataset(args.task.query_dataset, split="validation") def 
process_query_data(x): - pad_summary_w_leading_space = " " + x['summary'] return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, # with an extra leading space to account for the space between the query and response ), } diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index b5abc11..b7a5426 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -533,11 +533,10 @@ def train(args: Args): validation_dataset = load_dataset(args.task.query_dataset, split="validation") def process_query_data(x): - pad_summary_w_leading_space = " " + x['summary'] return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -619,10 +618,10 @@ def process_response_data(x): return { **process_query(x["info"], encoder=tokenizer, hparams=patch_h), "response0_token": tokenizer.encode( - x["summaries"][0]["text"], padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True ), "response1_token": tokenizer.encode( - x["summaries"][1]["text"], padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True ), } diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 8882136..aa9b33b 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -395,11 +395,10 @@ def train(args: Args): optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) def process_query_data(x): - pad_summary_w_leading_space = " " + x['summary'] return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, # with an extra leading space to account for the space between the query and response ), } From da5dc3350d53ca7ed82934d20f717e04542ce2e4 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 25 Oct 2023 13:53:20 +0000 Subject: [PATCH 16/62] push changes --- ...in_policy_accelerate_summarize_separate.py | 245 +++++---- .../train_reward_accelerate_summarize.py | 465 +++++++++--------- .../train_sft_accelerate_summarize.py | 172 ++++--- 3 files changed, 452 insertions(+), 430 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py 
b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 964fdd3..90df20d 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -43,7 +43,7 @@ class RewardHParams: kl_coef: float = 0.15 use_adaptive_kl: bool = True adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" + trained_model: Optional[str] = "models/reward" label_dataset: tyro.conf.Suppress[Optional[str]] = None @@ -127,6 +127,8 @@ class Args: """Whether to use cuda if available.""" run_name: tyro.conf.Suppress[str] = None """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" upload_model: bool = False "whether to upload the saved model to huggingface" hf_entity: str = "" @@ -136,9 +138,9 @@ class Args: """the name of the pretrained model to use""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 10 + print_sample_output_freq: int = 1 """How often to print sample output""" - sft_model_path: str = "models/sft_policy.pt" + sft_model_path: str = "models/sft_policy" """Where to load the SFT model""" save_path: str = "models/policy.pt" """Where to save the model""" @@ -354,18 +356,18 @@ def __init__(self, lm_backbone): nn.Linear(lm_backbone.config.hidden_size, 1), std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - reward_latents = output.hidden_states[-1] + last_reward_latents = output.hidden_states[-1] # shape: [batch_size, length, hidden_size] - last_reward_latents = reward_latents + # last_reward_latents = reward_latents # shape: [batch_size, hidden_size] reward = self.scalar_head(last_reward_latents) - # shape: [batch_size, 1] - reward = self.reward_gain * reward + self.reward_bias + # # shape: [batch_size, 1] + # reward = self.reward_gain * reward + self.reward_bias return output, reward @@ -418,11 +420,10 @@ def generate(lm_backbone, queries, tokenizer, generation_config): return torch.cat((queries, output.sequences[:, context_length:]), dim=1) -def get_reward(reward_model, query_responses, args): - attention_mask = query_responses != args.pad_token_id +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return reward_model( input_ids=input_ids, attention_mask=attention_mask, @@ -432,9 +433,9 @@ def get_reward(reward_model, query_responses, args): ) -def get_reward_complete(reward_model, query_responses, args): - reward = get_reward(reward_model, query_responses, args)[1] - last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 +def get_reward_complete(reward_model, query_responses, tokenizer): + reward = get_reward(reward_model, query_responses, 
tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 last_response_indices = torch.max( last_response_indices, torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), @@ -442,60 +443,6 @@ def get_reward_complete(reward_model, query_responses, args): return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) -def normalize( - tokenizer, - accelerator, - device, - lm_backbone, - reward_model, - dataloader, - validation_dataloader, -): - idx = 0 - with torch.no_grad(): - # reset reward scales - # accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) - # accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) - # number of minibatches for computing the normalization statistics - rewards = [] - for data in dataloader: - idx += len(data["query_token"]) - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - accelerator.print(score.shape, accelerator.gather(score).mean()) - rewards.append(score) - accelerator.print(f"====number of samples per device: {idx}") - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - accelerator.unwrap_model(reward_model).reward_gain.data = gain - accelerator.unwrap_model(reward_model).reward_bias.data = bias - - # validate normalization - rewards = [] - for data in validation_dataloader: - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - def forward(policy, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum @@ -611,11 +558,11 @@ def process_query_data(x): ), } - dataset = dataset.map(process_query_data) + dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) dataset = dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - validation_dataset = validation_dataset.map(process_query_data) + validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) @@ -663,7 +610,7 @@ def 
repeat_generator(): # TODO: ideally we shuffle the dataloader as well sample_validation_reference_response = sample_validation["reference_response"] sample_validation_query_reference_responses = torch.cat((sample_validation_queries, sample_validation_reference_response), dim=1) sample_validation_reference_scores = get_reward_complete(reward_model, sample_validation_query_reference_responses, tokenizer) - + # breakpoint() iter_dataloader = iter(repeat_generator()) kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) @@ -752,6 +699,25 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well tokenizer, generation_config, ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + truncate_token_mask = sample_validation_responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_sample_validation_responses = torch.where( + truncate_mask, + torch.full_like(sample_validation_responses, tokenizer.pad_token_id), + sample_validation_responses, + ) + postprocessed_sample_validation_query_responses = torch.cat((sample_validation_queries, postprocessed_sample_validation_responses), 1) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) @@ -776,47 +742,67 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # 1. truncate at the first occurrence of `truncate_token` that appears at or after # position truncate_after in the responses # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + # truncate_token_mask = responses == args.task.truncate_token + # truncate_after_or_token_mask = torch.cat( + # [ + # torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + # truncate_token_mask[:, args.task.truncate_after :], + # ], + # dim=1, + # ) + # truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + # postprocessed_responses = torch.where( + # truncate_mask, + # torch.full_like(responses, tokenizer.pad_token_id), + # responses, + # ) + # del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + + trunc_idxs = first_true_indices(responses == args.task.truncate_token).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) 
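            # --- illustrative sketch (editorial, not part of the patch) --------------------
            # The four added lines above keep everything up to and including the first
            # `truncate_token` in each sampled response and overwrite every later position
            # with the pad token; a row that never emits the token is left untouched, since
            # first_true_indices returns the row length there and `idxs > trunc_idxs` never
            # fires -- those rows are the ones penalized later by the `contain_pad_token`
            # check. Toy stand-ins: truncate_token=0, pad_token_id=9, response_length=4.
            #
            #     responses = torch.tensor([[5, 0, 7, 8],    # truncate token at position 1
            #                               [3, 4, 5, 6]])   # no truncate token at all
            #     trunc_idxs = first_true_indices(responses == 0).unsqueeze(-1)  # [[1], [4]]
            #     idxs = torch.arange(responses.shape[-1]).view(1, -1)           # [[0, 1, 2, 3]]
            #     torch.masked_fill(responses, idxs > trunc_idxs, 9)
            #     # -> tensor([[5, 0, 9, 9],
            #     #            [3, 4, 5, 6]])
            # --------------------------------------------------------------------------------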
torch.cuda.empty_cache() # 2. run reward model on the truncated responses postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - padding_mask = (postprocessed_query_responses == tokenizer.pad_token_id)[:, context_length - 1 : -1] + padding_mask = postprocessed_responses == tokenizer.pad_token_id logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) values = torch.masked_fill(values, padding_mask, 0) scores = get_reward_complete(reward_model, postprocessed_query_responses, tokenizer) + rew = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] + + qr = postprocessed_query_responses + attention_mask = qr != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(qr, ~attention_mask, 0) + output = reward_model.lm_backbone(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, return_dict=True, output_hidden_states=True) + last_reward_latents = output.hidden_states[-1] # TODO: investigate whether it should be output.hidden_states[0] or output.hidden_states[-1] + reward = reward_model.scalar_head(last_reward_latents) + + print(postprocessed_query_responses[0:5,537:]) + print(rew.squeeze(-1)[0:5,537:]) + print(scores) + breakpoint() + + reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) # note that we do not truncate the validation responses - validation_score = get_reward_complete(reward_model, sample_validation_query_responses, tokenizer) - accelerator.print(accelerator.gather(reference_scores).mean()) + validation_score = get_reward_complete(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # carperAI-style score normaliation + accelerator.print("before score", scores, scores.mean()) + accelerator.print("reference_scores", reference_scores, reference_scores.mean()) + scores = scores - reference_scores + accelerator.print("after score", scores, scores.mean()) # 3. filter response. Ensure that the sample contains truncate_token # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) torch.cuda.empty_cache() # 4. 
compute rewards @@ -831,44 +817,27 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: - # all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) - # all_postprocessed_query_responses = tokenizer.batch_decode( - # postprocessed_query_responses, skip_special_tokens=True - # ) - # all_postprocessed_responses = [ - # x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) - # ] - # all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) - - # kl_sum = kl.sum(axis=1) - # all_df = pd.DataFrame( - # { - # "query": all_decode_queries, - # "response": all_postprocessed_responses, - # "reference_responses": all_reference_responses, - # "score": scores.float().cpu().numpy(), - # "reference_scores": reference_scores.float().cpu().numpy(), - # "kl": kl_sum.float().cpu().numpy(), - # "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), - # } - # ) - # if accelerator.is_main_process and args.track: - # wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - # print_rich_table("stuff", all_df[:4], console) - all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries) all_sample_validation_query_responses = tokenizer.batch_decode( - sample_validation_query_responses, skip_special_tokens=True + sample_validation_query_responses + ) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses ) all_sample_validation_responses = [ x[len(y) :] for x, y in zip(all_sample_validation_query_responses, all_decode_validation_queries) ] + all_sample_validation_postprocessed_responses = [ + x[len(y) :] for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] all_sample_validation_reference_responses = tokenizer.batch_decode( - sample_validation_reference_response, skip_special_tokens=True + sample_validation_reference_response ) all_sample_validation_df = pd.DataFrame( { "query": all_decode_validation_queries, "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, "reference_responses": all_sample_validation_reference_responses, "scores": validation_score.float().cpu().numpy(), "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), @@ -880,13 +849,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well except Exception as e: print(e) - # del ( - # all_decode_queries, - # all_postprocessed_query_responses, - # all_postprocessed_responses, - # kl_sum, - # all_df, - # ) + del ( + all_decode_validation_queries, + all_sample_validation_query_responses, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) del postprocessed_query_responses torch.cuda.empty_cache() @@ -926,6 +895,15 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # output, vpred_temp = forward(policy, mb_query_responses, tokenizer) output, (_, vpred_temp) = forward(model, mb_query_responses, tokenizer) + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! 
+ # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! logits = output.logits[:, context_length - 1 : -1] logits /= (args.task.temperature + 1e-7) @@ -951,6 +929,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well pg_clipfrac = (pg_losses2 > pg_losses).float().mean() loss = pg_loss + args.ppo.vf_coef * vf_loss accelerator.backward(loss) + breakpoint() optimizer.step() optimizer.zero_grad() prob_dist = torch.nn.functional.softmax(logits, dim=-1) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index b7a5426..443cc3b 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -3,7 +3,7 @@ import time from dataclasses import asdict, dataclass, field from types import SimpleNamespace -from typing import List, Optional +from typing import List, Literal, Optional import numpy as np import pandas as pd @@ -28,7 +28,7 @@ ) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler from lm_human_preference_details.data import process_query @@ -36,7 +36,7 @@ @dataclass class LabelHParams: type: str = None - num_train: int = 64832 + num_train: int = 92832 num_labels: int = 2 source: str = None @@ -89,6 +89,8 @@ class Args: """Whether to use cuda if available.""" run_name: tyro.conf.Suppress[str] = None """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" base_model: str = "gpt2" """the name of the pretrained model to use""" @@ -124,12 +126,26 @@ class Args: """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. 
(For comparisons, just use mean 0, var 1.)""" normalize_after: bool = True """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 506 + print_sample_output_freq: int = 300 """How often to print sample output""" - save_path: str = "models/reward.pt" + sft_model_path: str = "models/sft_policy" + """Where to load the SFT model""" + logsigmoid: bool = True + """Whether to use log-sigmoid loss instead of cross-entropy loss""" + trainable_param_percentage: float = 1.0 + """Percentage of parameters to train""" + num_epochs: int = 1 + """Number of epochs to train""" + num_updates: tyro.conf.Suppress[int] = None + """Number of updates to train""" + save_path: str = "models/reward" """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + scheduler: str = "constant_with_warmup" + """Which scheduler to use""" + warm_up_steps: int = 100 + """Number of warm up steps for the scheduler""" task: TaskHParams = field(default_factory=TaskHParams) labels: LabelHParams = field(default_factory=LabelHParams) @@ -318,18 +334,14 @@ def __init__(self, lm_backbone): nn.Linear(lm_backbone.config.hidden_size, 1), std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - last_reward_latents = reward_latents + last_reward_latents = output.hidden_states[-1] # shape: [batch_size, hidden_size] reward = self.scalar_head(last_reward_latents) - # shape: [batch_size, 1] - reward = self.reward_gain * reward + self.reward_bias return output, reward @@ -370,11 +382,10 @@ def generate(lm_backbone, queries, tokenizer, generation_config): return torch.cat((queries, output.sequences[:, context_length:]), dim=1) -def get_reward(reward_model, query_responses, args): - attention_mask = query_responses != args.pad_token_id +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return reward_model( input_ids=input_ids, attention_mask=attention_mask, @@ -384,14 +395,14 @@ def get_reward(reward_model, query_responses, args): ) -def get_reward_complete(reward_model, query_responses, args): - reward = get_reward(reward_model, query_responses, args)[1] - last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 +def get_reward_complete(reward_model, query_responses, tokenizer): + reward = get_reward(reward_model, query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 last_response_indices = torch.max( last_response_indices, torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), ) - 
return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward def normalize( @@ -447,11 +458,50 @@ def normalize( print(f"after mean: {mean}, after std: {std}") +def evaluate(args, accelerator, device, reward_model, validation_label): + # reward_model.eval() + with torch.no_grad(): + # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) + test_accuracies = [] + eval_len = len(validation_label) + len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full + new_all_inds = np.arange(len_labels) + for start in range(0, len_labels, args.batch_size): + end = start + args.batch_size + b_inds_all = new_all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = validation_label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) + ] + predicted_reward = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + score, _ = get_reward_complete(reward_model, query_responses, args) + predicted_reward.append(score) + predicted_reward = torch.stack( + predicted_reward, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + # reward_model.train() + return test_accuracy + + def train(args: Args): accelerator = Accelerator( kwargs_handlers=[ DistributedDataParallelKwargs( broadcast_buffers=False, + # find_unused_parameters=True, ) ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -460,6 +510,7 @@ def train(args: Args): args.batch_size = int(args.local_batch_size * args.world_size) args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + args.num_updates = args.labels.num_train // args.batch_size patch_h = TaskQueryHParams( length=args.task.query_length, dataset=args.task.query_dataset, @@ -507,16 +558,21 @@ def train(args: Args): # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) args.pad_token_id = tokenizer.pad_token_id - untrained_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) reward_model = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) ) - untrained_model.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just 
want to - ) - untrained_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + + # freeze the first 70% of layers + if args.trainable_param_percentage < 1.0: + layers = reward_model.lm_backbone.transformer.h + num_layers = len(layers) + num_unfrozen = int(args.trainable_param_percentage * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + if args.sft_model_path: + reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded SFT model from {args.sft_model_path}") reward_model.lm_backbone.generation_config.eos_token_id = ( None # disable `pad_token_id` and `eos_token_id` because we just want to ) @@ -525,67 +581,51 @@ def train(args: Args): # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): reward_model.lm_backbone.embed_out.requires_grad_(False) - if args.use_tensorflow_adam: + if args.optimizer == "tf_adam": optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) - else: + elif args.optimizer == "adam": optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - dataset = load_dataset(args.task.query_dataset, split="train") - validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } + elif args.optimizer == "adamw": + optimizer = optim.AdamW(reward_model.parameters(), lr=args.lr, eps=args.eps) + # TODO: use AdamW + scheduler = get_scheduler( + args.scheduler, + optimizer=optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_updates * args.num_epochs, + ) - dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) - validation_dataset = validation_dataset.map(process_query_data) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) - validation_dataset = validation_dataset.shuffle(seed=local_seed) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) - reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) if args.deepspeed: import deepspeed deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - # "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - untrained_model, *_ = deepspeed.initialize(model=untrained_model, config=eval_ds_config) - 
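# Illustrative sketch (not part of the patch): the `trainable_param_percentage` logic above
# freezes the bottom (1 - trainable_param_percentage) fraction of transformer blocks so only
# the top blocks receive gradients. Shown on a plain GPT-2 backbone (`transformer.h`); the
# 0.3 value is just an example.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
trainable_fraction = 0.3
blocks = model.transformer.h
num_unfrozen = int(trainable_fraction * len(blocks))
for block in blocks[: len(blocks) - num_unfrozen]:
    block.requires_grad_(False)  # bottom blocks stay fixed; only the top ones are trained

print(sum(p.requires_grad for p in model.parameters()), "trainable tensors remain")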
untrained_model.eval() - else: - untrained_model = untrained_model.to(device) - - iter_dataloader = iter(dataloader) - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size + + reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) if args.normalize_before: + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.shuffle(seed=local_seed) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) + dataloader = accelerator.prepare(dataloader) + iter_dataloader = iter(dataloader) print("===Normalize reward model *before* training===") print( "before normalization. 
" @@ -597,7 +637,7 @@ def process_query_data(x): tokenizer, accelerator, device, - untrained_model.lm_backbone, + reward_model, reward_model, dataloader, validation_dataloader, @@ -611,6 +651,8 @@ def process_query_data(x): # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` label = load_dataset(args.label_dataset, "comparisons", split="train") validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") + dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") + eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") accelerator.print("Num labels found in source:", len(label)) accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) @@ -625,158 +667,113 @@ def process_response_data(x): ), } - label = label.map(process_response_data) - validation_label = validation_label.map(process_response_data) - # tokenizer.encode(label[0]["summaries"][0]["text"]) - + label = label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + dev_validation_label = dev_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + eval_validation_label = eval_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + # TODO: check if all labels have eos token accelerator.print("===training reward model===") - all_inds = np.random.permutation(args.labels.num_train) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() - global_step = 0 - for start in range(0, args.labels.num_train, args.batch_size): - # linear rate annealing - lr = (1 - start / args.labels.num_train) * args.lr - optimizer.param_groups[0]["lr"] = lr - - global_step += 1 - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - losses = torch.zeros((args.gradient_accumulation_steps,), device=device) - accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) - gradient_accumulation_step = 0 - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - with accelerator.accumulate(reward_model): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) - ] - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - score = get_reward_complete(reward_model, query_responses, args) - predicted_rewards.append(score) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - reward_preferred = predicted_rewards.gather(1, mb_best.view(-1, 1)).view(-1) - reward_rejected = predicted_rewards.gather(1, (1 - mb_best).view(-1, 1)).view(-1) - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - loss = -nn.functional.logsigmoid(reward_preferred - reward_rejected).mean() - # loss = 
torch.nn.functional.cross_entropy(predicted_rewards, mb_best) - accelerator.backward(loss) - optimizer.step() # accelerate handles gradient accumulation automatically - optimizer.zero_grad() - losses[gradient_accumulation_step] = loss - accuracies[gradient_accumulation_step] = accuracy - gradient_accumulation_step += 1 - - train_accuracy = accelerator.gather(accuracies).mean().item() - writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) - writer.add_scalar("train/accuracy", train_accuracy, global_step) - writer.add_scalar("train/lr", lr, global_step) - accelerator.print("train/accuracy", train_accuracy) - - if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - with torch.no_grad(): - # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) - test_accuracies = [] - eval_len = len(validation_label) - len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full - new_all_inds = np.arange(len_labels) - for start in range(0, len_labels, args.batch_size): - end = start + args.batch_size - b_inds_all = new_all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = validation_label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - score = get_reward_complete(reward_model, query_responses, args) - predicted_rewards.append(score) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - reward_preferred = predicted_rewards.gather(1, mb_best.view(-1, 1)).view(-1) - reward_rejected = predicted_rewards.gather(1, (1 - mb_best).view(-1, 1)).view(-1) - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - writer.add_scalar("test/accuracy", test_accuracy, global_step) - accelerator.print("test/accuracy", test_accuracy, global_step) - - # the part below is testing out some generations and KLs, not presented in the original code - data = next(iter_dataloader) - queries = data["query_token"].to(device) - context_length = queries.shape[1] - queries = right_padding_to_left_padding(data["query_token"], args.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(reward_model).lm_backbone, - queries, - tokenizer, - generation_config, - ) - responses = query_responses[:, context_length:] - - output, _ = get_reward(reward_model, query_responses, args) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del 
output, logits, all_logprobs - torch.cuda.empty_cache() - - output, _ = get_reward(untrained_model, query_responses, args) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - kl = logprobs - ref_logprobs - kl_sum = kl.sum(axis=1) - all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) - all_query_responses = tokenizer.batch_decode(query_responses, skip_special_tokens=True) - all_responses = [x[len(y) :] for x, y in zip(all_query_responses, all_decode_queries)] - all_df = pd.DataFrame( - { - "query": all_decode_queries, - "response": all_responses, - "kl": kl_sum.float().cpu().numpy(), - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=global_step) - try: - print_rich_table(f"Sample Output at Step {global_step}", all_df[:4], console) - except Exception as e: - print(e) - del ( - query_responses, - all_decode_queries, - all_query_responses, - all_responses, - kl_sum, - all_df, - ) - writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) + num_train = (args.labels.num_train // args.batch_size) * args.batch_size + for epoch in range(args.num_epochs): + all_inds = np.random.permutation(args.labels.num_train) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + accelerator.print(f"epoch: {epoch}") + for (epoch_global_step, start) in enumerate(range(0, num_train, args.batch_size)): + global_step = epoch * args.num_updates + epoch_global_step + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") + if accelerator.is_main_process: pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) + gradient_accumulation_step = 0 + # reward_model.train() + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + with accelerator.accumulate(reward_model): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + # pprint({ + # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", + # "micro_batch_inds": micro_batch_inds, + # }) + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + predicted_reward, 
reward = get_reward_complete(reward_model, query_responses, tokenizer) + predicted_reward = predicted_reward.view(-1, len(mb_responses)) # TODO check shape for no gradienta ccumulation steps + + # print(tokenizer.decode(mb_query[0])) + # print(tokenizer.decode(mb_responses[0][0])) + # print(tokenizer.decode(mb_responses[1][0])) + # predicted_reward = [] + # rewards = [] + # for i in range(args.labels.num_labels): + # query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + # score, reward = get_reward_complete(reward_model, query_responses, tokenizer) + # rewards.append(reward.squeeze(-1)) + # predicted_reward.append(score) + # # shape (batch_size, num_labels), basically a reward prediction for each label + # predicted_reward = torch.stack(predicted_reward, dim=1) + # breakpoint() + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) + reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + if args.logsigmoid: + loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() + else: + loss = F.cross_entropy(predicted_reward, mb_best) + accelerator.backward(loss) + + # for k, v in reward_model.named_parameters(): + # if v.requires_grad: + # if v.grad is None: + # print(f"found unused param: {k}") + + optimizer.step() # accelerate handles gradient accumulation automatically + optimizer.zero_grad() + scheduler.step() + losses[gradient_accumulation_step] = loss + accuracies[gradient_accumulation_step] = accuracy + reward_preferreds[gradient_accumulation_step] = reward_preferred.mean() + reward_rejecteds[gradient_accumulation_step] = reward_rejected.mean() + gradient_accumulation_step += 1 + + train_accuracy = accelerator.gather(accuracies).mean().item() + writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/accuracy", train_accuracy, global_step) + writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) + writer.add_scalar("train/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) + lr = scheduler.get_last_lr() + writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) + accelerator.print("train/accuracy", train_accuracy) + + # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + if global_step == args.num_updates - 1: # first and last update + dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) + writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) + accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) + eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) + writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) + accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) + eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, label) + writer.add_scalar("train_full/accuracy", eval_validation_accuracy, global_step) + accelerator.print("train_full/accuracy", eval_validation_accuracy, global_step) torch.cuda.empty_cache() if args.normalize_after: @@ -791,7 +788,7 @@ def process_response_data(x): tokenizer, accelerator, device, - untrained_model.lm_backbone, + reward_model, reward_model, dataloader, validation_dataloader, diff --git 
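# Illustrative sketch (not part of the patch): the query-tiling trick used in the new reward
# training loop above, which scores both candidate responses in a single forward pass instead
# of one pass per label. Token ids and shapes are dummies; the reward-model call is replaced
# by random scores.
import torch

batch, q_len, r_len, num_labels = 2, 5, 3, 2
mb_query = torch.randint(0, 100, (batch, q_len))
mb_responses = [torch.randint(0, 100, (batch, r_len)) for _ in range(num_labels)]

query_tiled = mb_query.unsqueeze(1).repeat(1, num_labels, 1)                 # [batch, num_labels, q_len]
responses = torch.stack(mb_responses).transpose(0, 1)                        # [batch, num_labels, r_len]
query_responses = torch.cat([query_tiled, responses], dim=2).flatten(0, 1)   # [batch*num_labels, q_len+r_len]
# ...one reward-model forward pass over query_responses would go here...
scores = torch.randn(batch * num_labels)                                     # stand-in for the model output
predicted_reward = scores.view(batch, num_labels)                            # one score per (example, candidate)
print(query_responses.shape, predicted_reward.shape)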
a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index aa9b33b..0d6988f 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -4,12 +4,13 @@ import time from dataclasses import asdict, dataclass, field from types import SimpleNamespace -from typing import List, Optional +from typing import List, Literal, Optional import numpy as np import pandas as pd import torch import torch.optim as optim +from torch.nn import functional as F import tyro import evaluate from accelerate import Accelerator @@ -25,7 +26,7 @@ ) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler from lm_human_preference_details.data import process_query @@ -37,6 +38,7 @@ class SFTHParams: noptepochs: int = 1 lr: float = 6.35e-5 eps: float = 1e-5 + lm_loss_on_response_only: bool = False total_episodes: tyro.conf.Suppress[int] = None local_batch_size:tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None @@ -99,6 +101,8 @@ class Args: """Whether to use cuda if available.""" run_name: tyro.conf.Suppress[str] = None """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" upload_model: bool = False "whether to upload the saved model to huggingface" hf_entity: str = "" @@ -110,10 +114,14 @@ class Args: """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 220 """How often to print sample output""" - save_path: str = "models/sft_policy.pt" + save_path: str = "models/sft_policy" """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + scheduler: str = "cosine" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" task: TaskHParams = field(default_factory=TaskHParams) sft: SFTHParams = field(default_factory=SFTHParams) @@ -318,7 +326,6 @@ def forward(policy, query_responses, tokenizer): position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return policy( - labels=input_ids, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -326,7 +333,9 @@ def forward(policy, query_responses, tokenizer): ) -def train(args: Args): +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) args.sft.world_size = accelerator.num_processes args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps @@ -351,7 +360,6 @@ def train(args: Args): run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" writer = SimpleNamespace() # dummy writer writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None if accelerator.is_main_process: if args.track: import wandb @@ -389,10 +397,19 @@ def train(args: Args): policy.generation_config.pad_token_id = None # generate tokens without truncation / padding # IMPORTANT: 
Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: + if args.optimizer == "tf_adam": optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - else: + elif args.optimizer == "adam": optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + # TODO: use AdamW + scheduler = get_scheduler( + args.scheduler, + optimizer=optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.sft.num_updates // args.sft.gradient_accumulation_steps, + ) def process_query_data(x): return { @@ -403,15 +420,15 @@ def process_query_data(x): ), } - dataset = dataset.map(process_query_data) + dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) dataset = dataset.shuffle(seed=local_seed) - test_dataset = test_dataset.map(process_query_data) + test_dataset = test_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) test_dataset = test_dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) + policy, optimizer, dataloader, test_dataloader, scheduler = accelerator.prepare(policy, optimizer, dataloader, test_dataloader, scheduler) iter_dataloader = iter(dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
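# Illustrative sketch (not part of the patch): the `get_scheduler` wiring these scripts switch
# to in place of manual LR annealing, shown on a tiny model. "constant_with_warmup" is one of
# the standard transformers scheduler names; the step counts are placeholders.
import torch
from transformers import get_scheduler

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-5)
scheduler = get_scheduler(
    "constant_with_warmup",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()            # stepped once per optimizer update, as in the training loops above
print(scheduler.get_last_lr())  # list containing the current learning rate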
TODO: investigate further, we just want to generate a fixed number of tokens @@ -427,94 +444,123 @@ def process_query_data(x): print("===training policy===") global_step = 0 - test_data = test_dataset[0:10] - test_data = {k: v.to(device) for k, v in test_data.items()} loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) gradient_accumulation_idx = 0 - - # Given parameters - eta_min = 0 - eta_max = 6.35e-5 - T_max = args.sft.num_updates - + policy.train() for update in range(1, args.sft.num_updates + 1): - global_step += 1 * args.sft.batch_size + global_step += args.sft.batch_size accelerator.print(f"update {update}, global_step {global_step}") - # frac = 1.0 - (update - 1.0) / args.sft.num_updates - # lrnow = frac * args.sft.lr - lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) - optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) + reference_responses = data["reference_response"].to(device, non_blocking=True) + queries = data["query_token"].to(device, non_blocking=True) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) query_responses = torch.cat((queries, reference_responses), dim=1) - query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) with accelerator.accumulate(policy): output = forward(policy, query_responses, tokenizer) - loss = output.loss + # mask out gradient effects on response padding tokens + labels = query_responses.masked_fill(query_responses == tokenizer.pad_token_id, -1) + if args.sft.lm_loss_on_response_only: + # mask out gradient effects on query tokens + labels[:, :queries.shape[1]] = -1 + lm_logits = output.logits + # hand-rolled transformer loss: Shift so that tokens < n predict n + # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-1) accelerator.backward(loss) optimizer.step() optimizer.zero_grad() + scheduler.step() loss_stats[gradient_accumulation_idx] = loss gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) - writer.add_scalar("lr", lrnow, update) - if (update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: + writer.add_scalar("lr", optimizer.param_groups[0]["lr"], update) + if update == 1 or update == args.sft.num_updates - 1: + policy.eval() rouge_scores = collections.defaultdict(list) + all_decode_test_queries = [] + all_decode_test_query_responses = [] + all_decode_test_responses = [] + all_decode_test_reference_responses = [] + all_test_losses = [] for test_idx, test_data in enumerate(test_dataloader): with torch.no_grad(): - test_queries = test_data["query_token"].to(device) - test_reference_responses = test_data["reference_response"].to(device) + test_reference_responses = test_data["reference_response"].to(device, non_blocking=True) + test_queries = test_data["query_token"].to(device, non_blocking=True) test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) + test_query_reference_responses = 
torch.cat((test_queries, test_reference_responses), dim=1) + + test_output = forward(policy, test_query_reference_responses, tokenizer) + test_labels = test_query_reference_responses.masked_fill(test_query_reference_responses == tokenizer.pad_token_id, -1) + if args.sft.lm_loss_on_response_only: + test_labels[:, :queries.shape[1]] = -1 + test_lm_logits = test_output.logits + # hand-rolled transformer loss: Shift so that tokens < n predict n + # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` + test_shift_logits = test_lm_logits[..., :-1, :].contiguous() + test_shift_labels = test_labels[..., 1:].contiguous() + test_loss = F.cross_entropy(test_shift_logits.view(-1, test_shift_logits.size(-1)), test_shift_labels.view(-1), ignore_index=-1) + test_loss = accelerator.gather(test_loss) + all_test_losses.append(test_loss) + generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) - accelerator.print(update, test_idx) - - all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) - all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) - all_decode_test_reference_responses = tokenizer.batch_decode( - test_reference_responses, skip_special_tokens=True + decode_test_queries = tokenizer.batch_decode(accelerator.gather(test_queries)) + decode_test_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) + decode_test_reference_responses = tokenizer.batch_decode( + accelerator.gather(test_reference_responses) ) - all_decode_test_responses = [ - x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) + decode_test_responses = [ + x[len(y) :] for x, y in zip(decode_test_query_responses, decode_test_queries) ] - rouge_score = rouge.compute(predictions=all_decode_test_responses, references=all_decode_test_reference_responses) + rouge_score = rouge.compute(predictions=decode_test_responses, references=decode_test_reference_responses) rouge_scores["rouge1"].append(rouge_score["rouge1"]) rouge_scores["rouge2"].append(rouge_score["rouge2"]) rouge_scores["rougeL"].append(rouge_score["rougeL"]) - if test_idx == 0: - try: - all_df = pd.DataFrame( - { - "query": all_decode_test_queries, - "response": all_decode_test_responses, - "reference": all_decode_test_reference_responses, - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) - except Exception as e: - print(e) + all_decode_test_queries.extend(decode_test_queries) + accelerator.print("len(all_decode_test_queries)", len(all_decode_test_queries), decode_test_responses) + all_decode_test_query_responses.extend(decode_test_query_responses) + all_decode_test_responses.extend(decode_test_responses) + all_decode_test_reference_responses.extend(decode_test_reference_responses) + if test_idx == 10: + break + + try: + all_df = pd.DataFrame( + { + "query": all_decode_test_queries, + "response": all_decode_test_responses, + "reference": all_decode_test_reference_responses, + } + ) + accelerator.print(all_df) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) + except Exception as e: + print(e) for k, v in rouge_scores.items(): rouge_metric = 
torch.tensor(v, device=device) rouge_metric = accelerator.gather(rouge_metric) writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") + writer.add_scalar("test_loss", torch.stack(all_test_losses).mean().item(), update) + policy.train() # save model if args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - accelerator.save_model(policy, args.save_path) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") - if args.upload_model: - repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__tldr__seed{args.seed}__{int(time.time())}" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) tokenizer.save_pretrained(repo_id, push_to_hub=True) -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) From 332da0d6b04c6dfacd8e1862cfb45bd6b888893b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 25 Oct 2023 20:36:00 +0000 Subject: [PATCH 17/62] actually kind of work --- .../summarization/minimal_rm copy.py | 25 + .../summarization/minimal_rm.py | 45 + .../summarization/minisft.py | 295 +++++ .../train_policy_accelerate copy 2.py | 836 ++++++++++++++ .../train_policy_accelerate copy.py | 947 +++++++++++++++ .../train_policy_accelerate_new.py | 952 +++++++++++++++ .../train_policy_accelerate_old.py | 922 +++++++++++++++ ...in_policy_accelerate_summarize_ref_diff.py | 889 ++++++++++++++ .../train_reward_accelerate copy.py | 732 ++++++++++++ .../train_reward_accelerate_debug copy.py | 526 +++++++++ .../train_reward_accelerate_debug.py | 528 +++++++++ ...train_reward_accelerate_summarize_debug.py | 977 ++++++++++++++++ .../train_reward_accelerate_summarized.py | 778 +++++++++++++ .../train_reward_accelerate_summarizew.py | 824 +++++++++++++ .../train_sft_accelerate_summarize copy.py | 521 +++++++++ ...train_sft_accelerate_summarize_executor.py | 540 +++++++++ .../train_policy_accelerate_summarize.py | 0 ...in_policy_accelerate_summarize_separate.py | 1021 +++++++++++++++++ .../train_reward_accelerate_summarize.py | 814 +++++++++++++ lm_human_preference_details/tldr_dataset.py | 128 +++ ...in_policy_accelerate_summarize_separate.py | 231 +--- .../train_reward_accelerate_summarize.py | 251 +--- 22 files changed, 12383 insertions(+), 399 deletions(-) create mode 100644 lm_human_preference_details/summarization/minimal_rm copy.py create mode 100644 lm_human_preference_details/summarization/minimal_rm.py create mode 100644 lm_human_preference_details/summarization/minisft.py create mode 100644 lm_human_preference_details/summarization/train_policy_accelerate copy 2.py create mode 100644 lm_human_preference_details/summarization/train_policy_accelerate copy.py create mode 100644 lm_human_preference_details/summarization/train_policy_accelerate_new.py create mode 100644 lm_human_preference_details/summarization/train_policy_accelerate_old.py create mode 100644 lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py create mode 100644 lm_human_preference_details/summarization/train_reward_accelerate copy.py create mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py create mode 100644 
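# Illustrative sketch (not part of the patch): the ROUGE bookkeeping used in the SFT eval loop
# above, on made-up strings. Recent versions of `evaluate` (with the rouge_score backend
# installed) return plain floats per metric, which the loop then averages across batches and
# processes.
import evaluate

rouge = evaluate.load("rouge")
predictions = ["the cat sat on the mat", "dogs bark loudly"]
references = ["the cat is on the mat", "the dog barked loudly"]
scores = rouge.compute(predictions=predictions, references=references)
print(scores["rouge1"], scores["rouge2"], scores["rougeL"])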
lm_human_preference_details/summarization/train_reward_accelerate_debug.py create mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py create mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_summarized.py create mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py create mode 100644 lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py create mode 100644 lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py rename lm_human_preference_details/{ => summarize_old}/train_policy_accelerate_summarize.py (100%) create mode 100644 lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py create mode 100644 lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py create mode 100644 lm_human_preference_details/tldr_dataset.py diff --git a/lm_human_preference_details/summarization/minimal_rm copy.py b/lm_human_preference_details/summarization/minimal_rm copy.py new file mode 100644 index 0000000..7049f0f --- /dev/null +++ b/lm_human_preference_details/summarization/minimal_rm copy.py @@ -0,0 +1,25 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoTokenizer +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = nn.Linear(lm_backbone.config.hidden_size, 1) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + last_reward_latents = output.hidden_states[-1] + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + return output, reward +base_model = "gpt2" +tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left") +reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(base_model)) +mb_query = torch.randint(0, len(tokenizer), (1, 512)) +mb_responses = torch.randint(0, len(tokenizer), (1, 2, 80)) +mb_query_tiled = mb_query.unsqueeze(1).repeat(1, mb_responses.shape[1], 1) +query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) +_, score = reward_model(input_ids=query_responses, return_dict=True, output_hidden_states=True) +print(score.squeeze(2)) \ No newline at end of file diff --git a/lm_human_preference_details/summarization/minimal_rm.py b/lm_human_preference_details/summarization/minimal_rm.py new file mode 100644 index 0000000..0cb4179 --- /dev/null +++ b/lm_human_preference_details/summarization/minimal_rm.py @@ -0,0 +1,45 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = nn.Linear(lm_backbone.config.hidden_size, 1) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + last_reward_latents = output.hidden_states[-1] + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + return output, reward + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + 
input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + +base_model = "gpt2" +tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left") +tokenizer.add_special_tokens({"pad_token": "[PAD]"}) +reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(base_model)) +reward_model.train() +mb_query = torch.randint(0, len(tokenizer), (1, 10)) +mb_query[:,0:4] = tokenizer.pad_token_id +mb_responses = torch.randint(0, len(tokenizer), (1, 2, 10)) +mb_query_tiled = mb_query.unsqueeze(1).repeat(1, mb_responses.shape[1], 1) +query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) +_, score_all = get_reward(reward_model, query_responses, tokenizer) +print(score_all.squeeze(2)) \ No newline at end of file diff --git a/lm_human_preference_details/summarization/minisft.py b/lm_human_preference_details/summarization/minisft.py new file mode 100644 index 0000000..fede737 --- /dev/null +++ b/lm_human_preference_details/summarization/minisft.py @@ -0,0 +1,295 @@ +import collections +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.optim as optim +from torch.nn import functional as F +import tyro +import evaluate +from accelerate import Accelerator +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + + +@dataclass +class SFTHParams: + gradient_accumulation_steps: int = 1 + local_micro_batch_size: int = 16 + noptepochs: int = 1 + lr: float = 6.35e-5 + eps: float = 1e-5 + lm_loss_on_response_only: bool = False + total_episodes: tyro.conf.Suppress[int] = None + local_batch_size:tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
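# Illustrative sketch (not part of the patch): what the `truncate_token` / `truncate_after`
# fields here describe - sampled responses are cut at the first EOS (50256) occurring at or
# after `truncate_after`, and everything past it is replaced with padding. `pad_id` is an
# assumed stand-in; the real postprocessing lives in the policy-training scripts.
import torch

def truncate_response(responses: torch.Tensor, truncate_token: int, truncate_after: int, pad_id: int) -> torch.Tensor:
    # responses: [batch, response_length] of sampled token ids
    trunc_mask = responses == truncate_token
    trunc_mask[:, :truncate_after] = False  # only consider positions >= truncate_after
    seq_len = responses.shape[1]
    first_trunc = torch.where(
        trunc_mask.any(1),
        trunc_mask.long().argmax(1),                       # index of the first qualifying truncate token
        torch.full((responses.shape[0],), seq_len),        # or seq_len if there is none
    )
    idxs = torch.arange(seq_len).expand_as(responses)
    return responses.masked_fill(idxs > first_trunc.unsqueeze(1), pad_id)

# responses = torch.tensor([[7, 8, 50256, 9, 10]])
# truncate_response(responses, 50256, 0, 0)  # -> tensor([[7, 8, 50256, 0, 0]])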
+ truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.01 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 220 + """How often to print sample output""" + save_path: str = "models/sft_policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + sft: SFTHParams = field(default_factory=SFTHParams) + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + ) + + +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) + args.sft.world_size = accelerator.num_processes + args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps + args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + dataset = load_dataset(args.task.query_dataset, split="train") + test_dataset = load_dataset(args.task.query_dataset, split="test") + accelerator.print("The number of samples in dataset", len(dataset)) + accelerator.print("The number of samples in test_dataset", len(test_dataset)) + args.sft.total_episodes = len(dataset) + args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + 
"reference_response": tokenizer.encode( + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + test_dataset = test_dataset.map(process_query_data) + test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) + test_dataset = test_dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) + test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) + policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) + iter_dataloader = iter(dataloader) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + # generation_config = GenerationConfig( + # max_new_tokens=args.task.response_length, + # min_new_tokens=args.task.response_length, + # temperature=args.task.temperature, + # top_k=0.0, + # top_p=1.0, + # do_sample=True, + # ) + + print("===training policy===") + global_step = 0 + test_data = test_dataset[0:10] + test_data = {k: v.to(device) for k, v in test_data.items()} + + # Given parameters + eta_min = 0 + eta_max = 6.35e-5 + T_max = args.sft.num_updates + + for update in range(1, args.sft.num_updates + 1): + global_step += 1 * args.sft.batch_size + accelerator.print(f"update {update}, global_step {global_step}") + # frac = 1.0 - (update - 1.0) / args.sft.num_updates + # lrnow = frac * args.sft.lr + lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + query_responses = torch.cat((queries, reference_responses), dim=1) + with accelerator.accumulate(policy): + output = forward(policy, query_responses, tokenizer) + # mask out gradient effects on response padding tokens + labels = query_responses.masked_fill(query_responses == tokenizer.pad_token_id, -1) + if args.sft.lm_loss_on_response_only: + # mask out gradient effects on query tokens + labels[:, :queries.shape[1]] = -1 + lm_logits = output.logits + # hand-rolled transformer loss: Shift so that tokens < n predict n + # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-1) + raise + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() diff --git a/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py b/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py new file mode 100644 index 0000000..b77f275 --- /dev/null +++ b/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py @@ -0,0 +1,836 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field 
+from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import DATASET + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 64 + query_dataset: str = "books" + query_prefix: str = "" + query_suffix: str = "" + start_text: Optional[str] = None + end_text: Optional[str] = None + + # Response params + response_length: int = 24 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
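# Illustrative sketch (not part of the patch): what `kl_coef` in RewardHParams parameterizes -
# a per-token KL penalty between the policy and the frozen reference policy, with the
# reward-model score added on the final response token. A simplified version of the usual
# reward shaping; all tensors are dummies.
import torch

kl_coef = 0.15
logprobs = torch.randn(2, 5)       # log-probs of sampled response tokens under the policy
ref_logprobs = torch.randn(2, 5)   # log-probs of the same tokens under the reference policy
scores = torch.randn(2)            # reward-model score per sequence

kl = logprobs - ref_logprobs
rewards = -kl_coef * kl            # token-level KL penalty
rewards[:, -1] += scores           # scalar score lands on the last response token
print(rewards.shape)               # [2, 5], which then feeds the PPO advantage computation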
+ truncate_token: int = 13 + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: 
Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + +# a pytorch dataset +class MyDataset(IterableDataset): + def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): + self.generator = generator + self.tokenizer = tokenizer + self.query_length = query_length + self.start_text = start_text + self.end_text = end_text + self.seed = seed 
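
The gap between the PyTorch and TensorFlow Adam variants commented in _single_tensor_adam above comes down to where the bias corrections and eps are applied. A rough numerical illustration with made-up gradient statistics (not part of the training code):

import math

g = 1e-3                       # an illustrative first gradient
lr, eps = 1e-5, 1e-5
beta1, beta2 = 0.9, 0.999
step = 1
m = (1 - beta1) * g            # first-moment estimate after one step
v = (1 - beta2) * g * g        # second-moment estimate after one step

# PyTorch convention: bias-correct both moments, then add eps to the corrected denominator
pt_update = (lr / (1 - beta1 ** step)) * m / (math.sqrt(v) / math.sqrt(1 - beta2 ** step) + eps)

# TensorFlow convention (what AdamTensorFlowStyle reproduces): fold both corrections into the
# step size and add eps to the *uncorrected* sqrt(v)
lr_t = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
tf_update = lr_t * m / (math.sqrt(v) + eps)

print(pt_update, tf_update)    # the two differ most when sqrt(v) is comparable to eps
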
+ token_to_index = tokenizer.get_vocab() + self.start_token = token_to_index[start_text] if self.start_text else None + self.end_token = token_to_index[end_text] if self.end_text else None + + def __iter__(self): + for text in self.generator("train", self.seed, shuffle=True): + tokens = self.tokenizer.encode(text) + if self.start_token is not None: + try: + first_index = tokens.index(self.start_token) + 1 + if first_index < len(tokens): + tokens = tokens[first_index:] + except: + continue + tokens = tokens[: self.query_length] + if self.end_token is not None: + try: + last_index = len(tokens) - tokens[::-1].index(self.end_token) + tokens = tokens[:last_index] + except: + continue + output = self.tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=self.query_length, + return_tensors="pt", + return_attention_mask=True, + ) + yield output + + +def right_padding_to_left_padding(query, pad_id): + # Convert from right padding to left padding. + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], + device=query.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + output = reward_model.lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + reward = reward_model.scalar_head(output.hidden_states[-1]) + reward = reward_model.reward_gain * reward + reward_model.reward_bias + # but we only care about the reward of the last token + reward = reward[:, -1] + return reward + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + 
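
Several helpers above hinge on how padding is handled, so here is a small worked example (token ids made up) of what right_padding_to_left_padding produces and how the attention mask and "exclusive cumsum" position ids in get_reward/forward are derived from it:

import torch

pad_id = 50257  # illustrative id for the manually added "[PAD]" token
queries = torch.tensor([[11, 12, 13, pad_id, pad_id],
                        [21, 22, pad_id, pad_id, pad_id]])

# right padding -> left padding, as in right_padding_to_left_padding
left_padded = torch.tensor(
    [[pad_id] * int((row == pad_id).sum()) + [int(x) for x in row if x != pad_id]
     for row in queries]
)
# [[50257, 50257,    11,    12,    13],
#  [50257, 50257, 50257,    21,    22]]

# attention mask and position ids, as in get_reward/forward: real tokens are numbered
# 0, 1, 2, ... no matter how much left padding precedes them
attention_mask = left_padded != pad_id
position_ids = attention_mask.cumsum(1) - attention_mask.long()
# [[0, 0, 0, 1, 2],
#  [0, 0, 0, 0, 1]]
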
if args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLMWithScalarHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = MyDataset( + DATASET[args.task.query_dataset], + tokenizer, + args.task.query_length, + seed=local_seed, + start_text=args.task.start_text, + end_text=args.task.end_text, + ) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + iter_dataloader = iter(dataloader) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
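
The kl_ctl controller instantiated above adapts the KL coefficient with the proportional rule defined in AdaptiveKLController.update; restated with made-up numbers:

import numpy as np

kl_coef, target, horizon = 0.15, 6.0, 10000
current_kl, n_steps = 9.0, 512      # e.g. the measured mean KL for a batch of 512 episodes

proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)   # 0.5 clipped to 0.2
kl_coef *= 1 + proportional_error * n_steps / horizon              # 0.15 -> roughly 0.1515
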
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkls_stats = torch.zeros(stats_shape, device=device) + clipfracs_stats = torch.zeros(stats_shape, device=device) + pg_losses_stats = torch.zeros(stats_shape, device=device) + vf_losses_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropies_stats = torch.zeros(stats_shape, device=device) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(policy).lm_backbone, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output, full_values = forward(policy, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output, _ = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding( + postprocessed_query_responses, tokenizer.pad_token_id + ) + scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() + + # 3. filter response. 
Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) + all_postprocessed_query_responses = tokenizer.batch_decode( + postprocessed_query_responses, skip_special_tokens=True + ) + all_postprocessed_responses = [ + x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) + ] + + kl_sum = kl.sum(axis=1) + all_df = pd.DataFrame( + { + "query": all_decode_queries, + "response": all_postprocessed_responses, + "score": scores.float().cpu().numpy(), + "kl": kl_sum.float().cpu().numpy(), + "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table("stuff", all_df[:4], console) + except Exception as e: + print(e) + del ( + all_decode_queries, + all_postprocessed_query_responses, + all_postprocessed_responses, + kl_sum, + all_df, + ) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + 
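
The policy and value losses recorded in the statistics above follow the usual clipped-surrogate form; a rough stand-alone restatement (function and tensor names here are illustrative, not part of the patch):

import torch

def ppo_losses(new_logprobs, old_logprobs, advantages, vpred, old_values, returns,
               cliprange=0.2, cliprange_value=0.2, vf_coef=0.1):
    # probability ratio between the current policy and the rollout policy
    ratio = torch.exp(new_logprobs - old_logprobs)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.max(pg_losses1, pg_losses2).mean()  # pessimistic (clipped) policy loss
    # value loss clipped around the rollout-time value estimates
    vpred_clipped = torch.clamp(vpred, old_values - cliprange_value, old_values + cliprange_value)
    vf_loss = 0.5 * torch.max((vpred - returns) ** 2, (vpred_clipped - returns) ** 2).mean()
    return pg_loss + vf_coef * vf_loss
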
gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) + writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + 
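
The returns and advantages summarized in the scalars above come from the GAE recursion in step 6; a compact restatement under the same gamma/lam conventions (names illustrative):

import torch

def gae_advantages_and_returns(rewards, values, gamma=1.0, lam=0.95):
    # rewards, values: [batch, response_length]
    # delta_t = r_t + gamma * V_{t+1} - V_t, accumulated backwards with decay gamma * lam
    T = rewards.shape[1]
    lastgaelam = torch.zeros_like(rewards[:, 0])
    advantages_reversed = []
    for t in reversed(range(T)):
        nextvalues = values[:, t + 1] if t < T - 1 else torch.zeros_like(values[:, t])
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values
    return advantages, returns
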
writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(reward_model.state_dict(), args.save_path) + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate copy.py b/lm_human_preference_details/summarization/train_policy_accelerate copy.py new file mode 100644 index 0000000..a975666 --- /dev/null +++ b/lm_human_preference_details/summarization/train_policy_accelerate copy.py @@ -0,0 +1,947 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from rich.console import Console +from rich.pretty import pprint +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import DATASET + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 64 + query_dataset: str = "books" + query_prefix: str = "" + query_suffix: str = "" + start_text: Optional[str] = None + end_text: Optional[str] = None + + # Response params + response_length: int = 24 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: int = 13 + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 0 + """How often to print sample output""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +from typing import List, Optional + +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + 
weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + +# a pytorch dataset +class MyDataset(IterableDataset): + def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): + self.generator = generator + self.tokenizer = tokenizer + self.query_length = query_length + self.start_text = start_text + self.end_text = end_text + self.seed = seed + token_to_index = tokenizer.get_vocab() + self.start_token = token_to_index[start_text] if self.start_text else None + 
self.end_token = token_to_index[end_text] if self.end_text else None + + def __iter__(self): + for text in self.generator("train", self.seed, shuffle=True): + tokens = self.tokenizer.encode(text) + if self.start_token is not None: + try: + first_index = tokens.index(self.start_token) + 1 + if first_index < len(tokens): + tokens = tokens[first_index:] + except: + continue + tokens = tokens[: self.query_length] + if self.end_token is not None: + try: + last_index = len(tokens) - tokens[::-1].index(self.end_token) + tokens = tokens[:last_index] + except: + continue + output = self.tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=self.query_length, + return_tensors="pt", + return_attention_mask=True, + ) + yield output + + +def right_padding_to_left_padding(query, pad_id): + # Convert from right padding to left padding. + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], + device=query.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + output = reward_model.lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + reward = reward_model.scalar_head(output.hidden_states[-1]) + reward = reward_model.reward_gain * reward + reward_model.reward_bias + # but we only care about the reward of the last token + reward = reward[:, -1] + return reward + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, 
args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + if args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = MyDataset( + DATASET[args.task.query_dataset], + tokenizer, + args.task.query_length, + seed=local_seed, + start_text=args.task.start_text, + end_text=args.task.end_text, + ) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size + deepspeed_states.deepspeed_config['checkpoint'] = 
{'use_node_local_storage': True} + off_load_device = "cpu" + stage = 3 + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'], + "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": { + "enabled": True + }, + "prescale_gradients": False, + "wall_clock_breakdown": False + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + iter_dataloader = iter(dataloader) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + approxkls_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + clipfracs_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + pg_losses_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + vf_losses_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + vf_clipfrac_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + entropies_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(policy).lm_backbone, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output, full_values = forward(policy, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + + output4, _ = forward(policy, query_responses, tokenizer) + logits4 = output4.logits[:, context_length - 1 : -1] + logits4 /= args.task.temperature + all_logprobs4 = F.log_softmax(logits4, dim=-1) + logprobs4 = torch.gather(all_logprobs4, 2, responses.unsqueeze(-1)).squeeze(-1) + del 
output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output, _ = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding( + postprocessed_query_responses, tokenizer.pad_token_id + ) + scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() + + # 3. filter response. Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + try: + sample_kl = kl[0].sum().item() + postprocessed_responses = postprocessed_query_responses[:, context_length:] + console.print( + f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]\n[yellow]{tokenizer.decode(postprocessed_responses[0], skip_special_tokens=True)}[/]\n[blue](NO POST-PROCESSING){tokenizer.decode(responses[0], skip_special_tokens=True)}[/]\n[red]score: {scores[0]}, kl: {kl[0].sum().item()}, total reward: {scores[0] - kl_ctl.value * sample_kl} [/]" + ) + except Exception as e: + print(e) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + output2, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits2 = output2.logits[:, context_length - 1 : -1] + logits2 /= args.task.temperature + new_all_logprobs2 = F.log_softmax(logits2, dim=-1) + new_logprobs2 = torch.gather(new_all_logprobs2, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + + with accelerator.accumulate(policy): + + + + output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + + + + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + pprint({ + "new_logprobs": new_logprobs, + "new_logprobs2": new_logprobs2, + "mb_logprobs": mb_logprobs, + "mb_logprobs2": logprobs4[micro_batch_inds], + }) + ratio = torch.exp(logprobs_diff) + print(ratio.mean()) + breakpoint() + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, 
minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar( + "objective/entropy", + accelerator.gather(mean_entropy).mean().item(), + update, + ) + writer.add_scalar( + "objective/non_score_reward", + accelerator.gather(mean_non_score_reward).mean().item(), + update, + ) + writer.add_scalar( + "objective/score_total", + accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), + update, + ) + writer.add_scalar( + "objective/scores", + accelerator.gather(scores.mean()).mean().item(), + update, + ) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar( + "ppo/policy/entropy", + accelerator.gather(entropy.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/approxkl", + accelerator.gather(approxkl).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/clipfrac", + accelerator.gather(pg_clipfrac).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/approxkl_avg", + accelerator.gather(approxkls_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/clipfrac_avg", + accelerator.gather(clipfracs_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/loss/policy_avg", + accelerator.gather(pg_losses_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/loss/value_avg", + accelerator.gather(vf_losses_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/clipfrac_avg", + accelerator.gather(vf_clipfrac_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/entropy_avg", + accelerator.gather(entropies_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/returns/mean", + accelerator.gather(return_mean).mean().item(), + update, + ) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar( + "ppo/val/error", + accelerator.gather(vf_losses1.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/clipfrac", + accelerator.gather(vf_clipfrac).mean().item(), + update, + ) + 
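
For reference, the objective/non_score_reward and objective/score_total scalars logged above decompose the reward the policy is actually optimized against: a per-token KL penalty toward the reference policy plus the reward-model score added at the final response token. Roughly, with illustrative names:

import torch

def shaped_rewards(logprobs, ref_logprobs, scores, kl_coef):
    # logprobs, ref_logprobs: [batch, response_length]; scores: [batch]
    kl = logprobs - ref_logprobs          # per-token KL estimate vs. the reference policy
    non_score_reward = -kl_coef * kl      # penalty applied at every response token
    rewards = non_score_reward.clone()
    rewards[:, -1] += scores              # reward-model score only at the last token
    # per-sample total is then roughly: scores + non_score_reward.sum(1)
    return rewards
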
writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar( + "ppo/val/ratio_var", + accelerator.gather(ratio.mean()).var().item(), + update, + ) + writer.add_scalar( + "ppo/val/advantage", + accelerator.gather(advantages.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/advantage_var", + accelerator.gather(advantages.mean()).var().item(), + update, + ) + writer.add_scalar( + "ppo/val/num_eos_tokens", + (responses == tokenizer.eos_token_id).sum().item(), + update, + ) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(reward_model.state_dict(), args.save_path) + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_new.py b/lm_human_preference_details/summarization/train_policy_accelerate_new.py new file mode 100644 index 0000000..e9f296d --- /dev/null +++ b/lm_human_preference_details/summarization/train_policy_accelerate_new.py @@ -0,0 +1,952 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from rich.console import Console +from rich.pretty import pprint +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import DATASET + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 64 + query_dataset: str = "books" + query_prefix: str = "" + query_suffix: str = "" + start_text: Optional[str] = None + 
end_text: Optional[str] = None + + # Response params + response_length: int = 24 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: int = 13 + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 0 + """How often to print sample output""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +from typing import List, Optional + +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + 
fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + +# a pytorch dataset +class MyDataset(IterableDataset): + def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): + self.generator = generator + self.tokenizer = tokenizer + self.query_length = query_length + self.start_text = start_text + 
self.end_text = end_text + self.seed = seed + token_to_index = tokenizer.get_vocab() + self.start_token = token_to_index[start_text] if self.start_text else None + self.end_token = token_to_index[end_text] if self.end_text else None + + def __iter__(self): + for text in self.generator("train", self.seed, shuffle=True): + tokens = self.tokenizer.encode(text) + if self.start_token is not None: + try: + first_index = tokens.index(self.start_token) + 1 + if first_index < len(tokens): + tokens = tokens[first_index:] + except: + continue + tokens = tokens[: self.query_length] + if self.end_token is not None: + try: + last_index = len(tokens) - tokens[::-1].index(self.end_token) + tokens = tokens[:last_index] + except: + continue + output = self.tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=self.query_length, + return_tensors="pt", + return_attention_mask=True, + ) + yield output + + +def right_padding_to_left_padding(query, pad_id): + # Convert from right padding to left padding. + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], + device=query.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
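+        # One unverified hypothesis for the collapse noted above: `generate` already rebuilds
+        # position ids from `attention_mask` during incremental decoding, so also passing a fixed
+        # exclusive-cumsum `position_ids` here could leave later decoding steps with stale
+        # positions. Treat this as an assumption, not a confirmed explanation.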
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + output = reward_model.lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + reward = reward_model.scalar_head(output.hidden_states[-1]) + reward = reward_model.reward_gain * reward + reward_model.reward_bias + # but we only care about the reward of the last token + reward = reward[:, -1] + return reward + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + if 
args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = MyDataset( + DATASET[args.task.query_dataset], + tokenizer, + args.task.query_length, + seed=local_seed, + start_text=args.task.start_text, + end_text=args.task.end_text, + ) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size + deepspeed_states.deepspeed_config['checkpoint'] = {'use_node_local_storage': True} + off_load_device = "cpu" + stage = 3 + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'], + "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": { + "enabled": True + }, + "prescale_gradients": False, + "wall_clock_breakdown": False + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + iter_dataloader = iter(dataloader) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + approxkls_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + clipfracs_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + pg_losses_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + vf_losses_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + vf_clipfrac_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + entropies_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + ratio_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(policy).lm_backbone, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output, full_values = forward(policy, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output, _ = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding( + postprocessed_query_responses, tokenizer.pad_token_id + ) + scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() + + # 3. filter response. Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + try: + sample_kl = kl[0].sum().item() + postprocessed_responses = postprocessed_query_responses[:, context_length:] + console.print( + f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]\n[yellow]{tokenizer.decode(postprocessed_responses[0], skip_special_tokens=True)}[/]\n[blue](NO POST-PROCESSING){tokenizer.decode(responses[0], skip_special_tokens=True)}[/]\n[red]score: {scores[0]}, kl: {kl[0].sum().item()}, total reward: {scores[0] - kl_ctl.value * sample_kl} [/]" + ) + except Exception as e: + print(e) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + re_calculated_logprobs = torch.zeros_like(logprobs) + re_calculated_values = torch.zeros_like(values) + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + + # re-calculate logprobs and values for the first epoch, otherwise `bf16` will cause the logprobs to + # be much different because the logprobs are with a batch size of `local_batch_size` but the + # `new_logprobs` are with a batch size of `local_micro_batch_size` + if ppo_epoch_idx == 0: + with torch.no_grad(): + output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + re_calculated_logprobs[micro_batch_inds] = new_logprobs + re_calculated_values[micro_batch_inds] = vpred + del output, logits, new_all_logprobs + mb_values = re_calculated_values[micro_batch_inds] + mb_logprobs = re_calculated_logprobs[micro_batch_inds] + + with accelerator.accumulate(policy): + output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = 
torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + breakpoint() + if accelerator.is_main_process: + console.print("ratio_stats", ratio_stats.mean()) + breakpoint() + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar( + "objective/entropy", + accelerator.gather(mean_entropy).mean().item(), + update, + ) + writer.add_scalar( + "objective/non_score_reward", + accelerator.gather(mean_non_score_reward).mean().item(), + update, + ) + writer.add_scalar( + "objective/score_total", + accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), + update, + ) + writer.add_scalar( + "objective/scores", + accelerator.gather(scores.mean()).mean().item(), + update, + ) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar( + "ppo/policy/entropy", + accelerator.gather(entropy.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/approxkl", + accelerator.gather(approxkl).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/clipfrac", + accelerator.gather(pg_clipfrac).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/approxkl_avg", + accelerator.gather(approxkls_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/clipfrac_avg", + accelerator.gather(clipfracs_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/loss/policy_avg", + accelerator.gather(pg_losses_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/loss/value_avg", + accelerator.gather(vf_losses_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/clipfrac_avg", + accelerator.gather(vf_clipfrac_stats).mean().item(), + update, + ) 
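+            # The `*_avg` scalars in this block are means over the full
+            # (noptepochs, nminibatches, gradient_accumulation_steps) stat buffers filled during
+            # the PPO epochs, whereas the non-`_avg` scalars report only the values from the last
+            # microbatch of the last PPO epoch.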
+ writer.add_scalar( + "ppo/policy/entropy_avg", + accelerator.gather(entropies_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/returns/mean", + accelerator.gather(return_mean).mean().item(), + update, + ) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar( + "ppo/val/error", + accelerator.gather(vf_losses1.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/clipfrac", + accelerator.gather(vf_clipfrac).mean().item(), + update, + ) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar( + "ppo/val/ratio_var", + accelerator.gather(ratio.mean()).var().item(), + update, + ) + writer.add_scalar( + "ppo/val/advantage", + accelerator.gather(advantages.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/advantage_var", + accelerator.gather(advantages.mean()).var().item(), + update, + ) + writer.add_scalar( + "ppo/val/num_eos_tokens", + (responses == tokenizer.eos_token_id).sum().item(), + update, + ) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(reward_model.state_dict(), args.save_path) + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_old.py b/lm_human_preference_details/summarization/train_policy_accelerate_old.py new file mode 100644 index 0000000..de27920 --- /dev/null +++ b/lm_human_preference_details/summarization/train_policy_accelerate_old.py @@ -0,0 +1,922 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from rich.console import Console +from rich.pretty import pprint +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import DATASET + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch 
size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 64 + query_dataset: str = "books" + query_prefix: str = "" + query_suffix: str = "" + start_text: Optional[str] = None + end_text: Optional[str] = None + + # Response params + response_length: int = 24 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: int = 13 + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 0 + """How often to print sample output""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +from typing import List, Optional + +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam 
implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + 
super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + +# a pytorch dataset +class MyDataset(IterableDataset): + def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): + self.generator = generator + self.tokenizer = tokenizer + self.query_length = query_length + self.start_text = start_text + self.end_text = end_text + self.seed = seed + token_to_index = tokenizer.get_vocab() + self.start_token = token_to_index[start_text] if self.start_text else None + self.end_token = token_to_index[end_text] if self.end_text else None + + def __iter__(self): + for text in self.generator("train", self.seed, shuffle=True): + tokens = self.tokenizer.encode(text) + if self.start_token is not None: + try: + first_index = tokens.index(self.start_token) + 1 + if first_index < len(tokens): + tokens = tokens[first_index:] + except: + continue + tokens = tokens[: self.query_length] + if self.end_token is not None: + try: + last_index = len(tokens) - tokens[::-1].index(self.end_token) + tokens = tokens[:last_index] + except: + continue + output = self.tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=self.query_length, + return_tensors="pt", + return_attention_mask=True, + ) + yield output + + +def right_padding_to_left_padding(query, pad_id): + # Convert from right padding to left padding. + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], + device=query.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + output = reward_model.lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + reward = reward_model.scalar_head(output.hidden_states[-1]) + reward = reward_model.reward_gain * reward + reward_model.reward_bias + # but we only care about the reward of the last token + reward = reward[:, -1] + return reward + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + if 
args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + policy.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = MyDataset( + DATASET[args.task.query_dataset], + tokenizer, + args.task.query_length, + seed=local_seed, + start_text=args.task.start_text, + end_text=args.task.end_text, + ) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size + deepspeed_states.deepspeed_config['checkpoint'] = {'use_node_local_storage': True} + off_load_device = "cpu" + stage = 3 + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'], + "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": { + "enabled": True + }, + "prescale_gradients": False, + "wall_clock_breakdown": False + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + iter_dataloader = iter(dataloader) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + approxkls_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + clipfracs_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + pg_losses_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + vf_losses_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + vf_clipfrac_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + entropies_stats = torch.zeros( + ( + args.ppo.noptepochs, + args.ppo.nminibatches, + args.ppo.gradient_accumulation_steps, + ), + device=device, + ) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(policy).lm_backbone, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output, full_values = forward(policy, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output, _ = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding( + postprocessed_query_responses, tokenizer.pad_token_id + ) + scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() + + # 3. filter response. Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + try: + sample_kl = kl[0].sum().item() + postprocessed_responses = postprocessed_query_responses[:, context_length:] + console.print( + f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]\n[yellow]{tokenizer.decode(postprocessed_responses[0], skip_special_tokens=True)}[/]\n[blue](NO POST-PROCESSING){tokenizer.decode(responses[0], skip_special_tokens=True)}[/]\n[red]score: {scores[0]}, kl: {kl[0].sum().item()}, total reward: {scores[0] - kl_ctl.value * sample_kl} [/]" + ) + except Exception as e: + print(e) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + 
gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar( + "objective/entropy", + accelerator.gather(mean_entropy).mean().item(), + update, + ) + writer.add_scalar( + "objective/non_score_reward", + accelerator.gather(mean_non_score_reward).mean().item(), + update, + ) + writer.add_scalar( + "objective/score_total", + accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), + update, + ) + writer.add_scalar( + "objective/scores", + accelerator.gather(scores.mean()).mean().item(), + update, + ) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar( + "ppo/policy/entropy", + accelerator.gather(entropy.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/approxkl", + accelerator.gather(approxkl).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/clipfrac", + accelerator.gather(pg_clipfrac).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/approxkl_avg", + accelerator.gather(approxkls_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/clipfrac_avg", + accelerator.gather(clipfracs_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/loss/policy_avg", + accelerator.gather(pg_losses_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/loss/value_avg", + accelerator.gather(vf_losses_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/clipfrac_avg", + accelerator.gather(vf_clipfrac_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/policy/entropy_avg", + accelerator.gather(entropies_stats).mean().item(), + update, + ) + writer.add_scalar( + "ppo/returns/mean", + accelerator.gather(return_mean).mean().item(), + update, + ) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar( + "ppo/val/error", + accelerator.gather(vf_losses1.mean()).mean().item(), + update, + ) + writer.add_scalar( + "ppo/val/clipfrac", + accelerator.gather(vf_clipfrac).mean().item(), + update, + ) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar( + "ppo/val/ratio_var", + accelerator.gather(ratio.mean()).var().item(), + update, + ) + writer.add_scalar( + "ppo/val/advantage", + accelerator.gather(advantages.mean()).mean().item(), + update, + ) + writer.add_scalar( + 
"ppo/val/advantage_var", + accelerator.gather(advantages.mean()).var().item(), + update, + ) + writer.add_scalar( + "ppo/val/num_eos_tokens", + (responses == tokenizer.eos_token_id).sum().item(), + update, + ) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(reward_model.state_dict(), args.save_path) + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py b/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py new file mode 100644 index 0000000..ee56755 --- /dev/null +++ b/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py @@ -0,0 +1,889 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces 
+ query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + sft_model_path: str = "models/sft_policy.pt" + """Where to load the SFT model""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + 
exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + last_reward_latents = reward_latents + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + # shape: [batch_size, 1] + reward = self.reward_gain * reward + self.reward_bias + return output, reward + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
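+ # Likely explanation (unverified assumption): recent transformers versions of GPT-2 rebuild
+ # position_ids from `attention_mask` inside `prepare_inputs_for_generation`, so left-padded queries
+ # already receive positions that start at the first real token; that would make the explicit argument
+ # above redundant, consistent with the TODO about generation collapsing when it is passed.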
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, args): + attention_mask = query_responses != args.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead( + 
AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + if args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLMWithScalarHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + if args.sft_model_path: + policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = load_dataset(args.task.query_dataset, split="train") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even 
with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkls_stats = torch.zeros(stats_shape, device=device) + clipfracs_stats = torch.zeros(stats_shape, device=device) + pg_losses_stats = torch.zeros(stats_shape, device=device) + vf_losses_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropies_stats = torch.zeros(stats_shape, device=device) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + query_reference_responses = torch.cat((queries, reference_responses), dim=1) + queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) + query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(policy).lm_backbone, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output, full_values = forward(policy, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output, _ = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding( + postprocessed_query_responses, tokenizer.pad_token_id + ) + scores = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + scores = scores[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + + reference_scores = get_reward(reward_model, query_reference_responses, tokenizer)[1] + last_reference_response_indices = first_true_indices(query_reference_responses == tokenizer.pad_token_id) - 1 + last_reference_response_indices = torch.max( + last_reference_response_indices, + torch.zeros([1], dtype=last_reference_response_indices.dtype, device=query_reference_responses.device), + ) + reference_scores = reference_scores[:, :, 0].gather(1, last_reference_response_indices.unsqueeze(1)).view(-1) + + # 3. filter response. Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) + all_postprocessed_query_responses = tokenizer.batch_decode( + postprocessed_query_responses, skip_special_tokens=True + ) + all_postprocessed_responses = [ + x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) + ] + all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) + + kl_sum = kl.sum(axis=1) + all_df = pd.DataFrame( + { + "query": all_decode_queries, + "response": all_postprocessed_responses, + "reference_responses": all_reference_responses, + "score": scores.float().cpu().numpy(), + "reference_scores": reference_scores.float().cpu().numpy(), + "kl": kl_sum.float().cpu().numpy(), + "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table("stuff", all_df[:4], console) + except Exception as e: + print(e) + del ( + all_decode_queries, + all_postprocessed_query_responses, + all_postprocessed_responses, + kl_sum, + all_df, + ) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() 
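+ # Objective recap (describes the code above, no new behavior):
+ #   ratio       = exp(new_logprobs - mb_logprobs)
+ #   policy loss = mean(max(-A * ratio, -A * clip(ratio, 1 - cliprange, 1 + cliprange)))
+ #   value loss  = 0.5 * mean(max((vpred - return)^2, (clip(vpred, mb_values +/- cliprange_value) - return)^2))
+ #   total loss  = policy loss + vf_coef * value loss
+ # pg_clipfrac / vf_clipfrac above record how often the clipped branch was the active (larger) one.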
+ gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) + writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), 
update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if accelerator.is_main_process and args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(policy.state_dict(), args.save_path) + + if args.upload_model: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate copy.py b/lm_human_preference_details/summarization/train_reward_accelerate copy.py new file mode 100644 index 0000000..11e26d0 --- /dev/null +++ b/lm_human_preference_details/summarization/train_reward_accelerate copy.py @@ -0,0 +1,732 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs, broadcast +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import DATASET + + +@dataclass +class LabelHParams: + type: str = None + num_train: int = 4992 + num_labels: int = 4 + source: str = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 64 + query_dataset: str = "books" + query_prefix: str = "" + query_suffix: str = "" + start_text: Optional[str] = None + end_text: Optional[str] = None + + # Response params + response_length: int = 24 + + # LM params + temperature: float = 0.7 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + label_dataset: str = "sentiment/offline_5k.json" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + local_batch_size: int = 4 + """per rank batch size""" + gradient_accumulation_steps: int = 1 + """gradient 
accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + lr: float = 0.00005 + """the learning rate""" + eps: float = 1e-5 + """the epsilon for AdamW""" + rollout_batch_size: int = 512 + """rollout batch size""" + world_size: tyro.conf.Suppress[int] = None + """the number of processes to use""" + batch_size: tyro.conf.Suppress[int] = None + """the batch size across all ranks""" + local_normalize_samples: int = 256 + """Samples used to estimate reward mean and std""" + normalize_samples: tyro.conf.Suppress[int] = None + """Samples used to estimate reward mean and std across all ranks""" + debug_normalize: int = 0 + """Samples used to check that normalization worked""" + normalize_before: bool = True + """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" + normalize_after: bool = True + """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/reward.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + labels: LabelHParams = field(default_factory=LabelHParams) + + +OPENAI_PAD_TOKEN_ID = 50259 + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + 
lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + last_reward_latents = reward_latents[:, -1, :] + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + # shape: [batch_size, 1] + reward = self.reward_gain * reward + self.reward_bias + return output, reward + + +# Dataset for reward-model normalization +class NormalizationDataset(IterableDataset): + """A dataset for reward model normalization.""" + + def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): + self.generator = generator + self.tokenizer = tokenizer + self.query_length = query_length + self.start_text = start_text + self.end_text = end_text + self.seed = seed + token_to_index = tokenizer.get_vocab() + self.start_token = token_to_index[start_text] if self.start_text else None + self.end_token = token_to_index[end_text] if self.end_text else None + + def __iter__(self): + for text in self.generator("train", self.seed, shuffle=True): + tokens = self.tokenizer.encode(text) + if self.start_token is not None: + try: + first_index = tokens.index(self.start_token) + 1 + if first_index < len(tokens): + tokens = tokens[first_index:] + except: + continue + tokens = tokens[: self.query_length] + if self.end_token is not None: + try: + last_index = len(tokens) - tokens[::-1].index(self.end_token) + tokens = 
tokens[:last_index] + except: + continue + output = self.tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=self.query_length, + return_tensors="pt", + return_attention_mask=True, + ) + yield output + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, args, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != args.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, args): + attention_mask = query_responses != args.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def normalize( + args, + accelerator, + device, + lm_backbone, + reward_model, + iter_dataloader, + generation_config, +): + with torch.no_grad(): + # reset reward scales + accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) + accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) + + # sample queries and responses + n_batches = ceil_div(args.local_normalize_samples, args.rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = data["input_ids"].to(device) + queries = right_padding_to_left_padding(data["input_ids"], args.pad_token_id).to(device) + query_responses = generate(lm_backbone, queries, args, generation_config) + sample_queries_responses.append(query_responses) + + # compute reward statistics + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, args)[1]) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + # shape: [args.local_normalize_samples, 1] + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + accelerator.unwrap_model(reward_model).reward_gain.data = gain + accelerator.unwrap_model(reward_model).reward_bias.data = bias + + # validate normalization + n_batches = ceil_div(args.local_normalize_samples, args.rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = 
data["input_ids"].to(device) + queries = right_padding_to_left_padding(data["input_ids"], args.pad_token_id).to(device) + query_responses = generate(lm_backbone, queries, args, generation_config) + sample_queries_responses.append(query_responses) + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, args)[1]) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def train(args: Args): + accelerator = Accelerator( + kwargs_handlers=[ + DistributedDataParallelKwargs(broadcast_buffers=False) + ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + args.world_size = accelerator.num_processes + args.batch_size = int(args.local_batch_size * args.world_size) + args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + local_seed = args.seed + accelerator.process_index * 100003 # Prime + device = accelerator.device + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + args.pad_token_id = tokenizer.pad_token_id + untrained_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)).to(device) + reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)).to(device) + untrained_model.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + untrained_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + reward_model.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) + else: + optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + normalization_dataset = NormalizationDataset( + DATASET[args.task.query_dataset], + tokenizer, + args.task.query_length, + seed=local_seed, + start_text=args.task.start_text, + end_text=args.task.end_text, + ) + normalization_dataloader = DataLoader(normalization_dataset, 
batch_size=args.rollout_batch_size) + reward_model.lm_backbone._set_gradient_checkpointing(True) + reward_model, optimizer, normalization_dataloader = accelerator.prepare(reward_model, optimizer, normalization_dataloader) + iter_normalization_dataloader = iter(normalization_dataloader) + + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + if args.normalize_before: + print("===Normalize reward model *before* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + args, + accelerator, + device, + untrained_model.lm_backbone, + reward_model, + iter_normalization_dataloader, + generation_config, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` + label = load_dataset( + "vwxyzjn/lm-human-preferences", + data_files=[args.label_dataset], + )["train"] + print("Num labels found in source:", len(label)) + print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + + print("===training reward model===") + all_inds = np.random.permutation(args.labels.num_train) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + global_step = 0 + for start in range(0, args.labels.num_train, args.batch_size): + # linear rate annealing + lr = (1 - start / args.labels.num_train) * args.lr + optimizer.param_groups[0]["lr"] = lr + + global_step += 1 + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + gradient_accumulation_step = 0 + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + with accelerator.accumulate(reward_model): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query"])).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) + ] + # hack: deal with openai's padding token + mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = args.pad_token_id + for item in mb_responses: + item[item == OPENAI_PAD_TOKEN_ID] = args.pad_token_id + + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id) + reward = get_reward(reward_model, query_responses, args)[1] + predicted_rewards.append(reward.view(-1)) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_rewards.argmax(1) == 
mb_best).float().mean() + loss = torch.nn.functional.cross_entropy(predicted_rewards, mb_best) + accelerator.backward(loss) + optimizer.step() # accelerate handles gradient accumulation automatically + optimizer.zero_grad() + losses[gradient_accumulation_step] = loss + accuracies[gradient_accumulation_step] = accuracy + gradient_accumulation_step += 1 + + writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/accuracy", accelerator.gather(accuracies).mean().item(), global_step) + writer.add_scalar("train/lr", lr, global_step) + + if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + with torch.no_grad(): + # eval on test_label, some duplicate code (I don't want to make the training loop into a function...) + test_accuracies = [] + new_all_inds = np.arange(len(label)) + for start in range(args.labels.num_train, len(label), args.batch_size): + end = start + args.batch_size + b_inds_all = new_all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query"])) + mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) + ] + # hack: deal with openai's padding token + mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = args.pad_token_id + for item in mb_responses: + item[item == OPENAI_PAD_TOKEN_ID] = args.pad_token_id + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id) + reward = get_reward(reward_model, query_responses, args)[1] + predicted_rewards.append(reward.view(-1)) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + writer.add_scalar("test/accuracy", test_accuracy, global_step) + if accelerator.is_main_process: + print("test/accuracy", test_accuracy, global_step) + + # the part below is testing out some generations and KLs, not presented in the original code + data = next(iter_normalization_dataloader) + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + queries = right_padding_to_left_padding(data["input_ids"], args.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(reward_model).lm_backbone, + queries, + args, + generation_config, + ) + responses = query_responses[:, context_length:] + + output, reward = get_reward(reward_model, query_responses, args) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + 
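            # Shape walk-through of the log-prob extraction above and the reference
            # pass that follows: `logits` is (batch, response_length, vocab); the
            # slice [:, context_length - 1 : -1] keeps exactly the positions that
            # predicted each response token, and gathering along dim 2 with the
            # sampled token ids gives `logprobs` of shape (batch, response_length).
            # The same quantity under the untrained model gives `ref_logprobs`, and
            # logprobs - ref_logprobs is the per-token KL estimate that gets summed
            # over the response and logged below.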
output, _ = get_reward(untrained_model, query_responses, args) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + print(f"global_step {global_step}:") + kl = logprobs - ref_logprobs + console.print( + f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]" + f"\n[blue]{tokenizer.decode(responses[0], skip_special_tokens=True)}[/]" + f"\n[red]reward: {reward[0].item()}[/]" + f"\n[red]kl: {kl[0].sum().item()}[/]" + f"\n[red]average kl: {kl.sum(1).mean().item()}[/]" + ) + writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) + + torch.cuda.empty_cache() + if args.normalize_after: + print("===Normalize reward model *after* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + args, + accelerator, + device, + untrained_model.lm_backbone, + reward_model, + iter_normalization_dataloader, + generation_config, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + + if accelerator.is_main_process and args.track: + wandb.finish() + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py b/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py new file mode 100644 index 0000000..5113045 --- /dev/null +++ b/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py @@ -0,0 +1,526 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import Optional + +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs, broadcast +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import tyro +from rich.console import Console +from datasets import load_dataset +from rich.pretty import pprint +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.datamod import DATASET + +@dataclass +class LabelHParams: + type: str = None + num_train: int = 4992 + num_labels: int = 4 + source: str = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 64 + query_dataset: str = "books" + query_prefix: str = "" + query_suffix: str = "" + start_text: Optional[str] = None + end_text: Optional[str] = None + + # Response params + response_length: int = 24 + + # LM params + temperature: float = 0.7 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[:-len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + 
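    # These dataclass fields become command-line flags through `tyro.cli(Args)` at
    # the bottom of the file. An illustrative invocation only (flag spellings follow
    # tyro's default kebab-case and nested dotted names, and may differ across tyro
    # versions):
    #   python this_script.py --track --local-batch-size 8 --task.temperature 1.0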
wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + label_dataset: str = "sentiment/offline_5k.json" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + local_batch_size: int = 4 + """per rank batch size""" + lr: float = 0.00005 + """the learning rate""" + eps: float = 1e-5 + """the epsilon for AdamW""" + local_rollout_batch_size: int = 512 + """per rank rollot batch size""" + world_size: tyro.conf.Suppress[int] = None + """the number of processes to use""" + batch_size: tyro.conf.Suppress[int] = None + """the batch size across all ranks""" + normalize_samples: int = 256 + """Samples used to estimate reward mean and std""" + debug_normalize: int = 0 + """Samples used to check that normalization worked""" + normalize_before: bool = True + """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" + normalize_after: bool = True + """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/reward.pt" + """Where to save the model""" + task: TaskHParams = field(default_factory=TaskHParams) + labels: LabelHParams = field(default_factory=LabelHParams) + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +OPENAI_PAD_TOKEN_ID = 50259 + + +class ScalarHead(nn.Module): + def __init__(self, config, scale=None, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + # some models such as OPT have a projection layer before the word embeddings - e.g. 
OPT-350m + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + else: + hidden_size = config.hidden_size + if scale is None: + scale = 1 / np.sqrt(hidden_size + 1) + self.summary = layer_init(nn.Linear(hidden_size, 1), std=scale) + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, pretrained_model): + super().__init__() + self.pretrained_model = pretrained_model + self.scalar_head = ScalarHead(self.pretrained_model.config, scale=0.0) + + def forward(self, **kwargs): + output = self.pretrained_model(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, pretrained_model): + super().__init__() + self.pretrained_model = pretrained_model + self.scalar_head = ScalarHead(self.pretrained_model.config) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.pretrained_model(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) + reward = self.reward_gain * reward + self.reward_bias + # but we only care about the reward of the last token + reward = reward[:, -1] + return output, reward + + +# a pytorch dataset +class MyDataset(IterableDataset): + def __init__(self, generator, tokenizer, query_length, start_text=None, end_text=None, query_prefix="", query_suffix="", seed=None): + self.generator = generator + self.tokenizer = tokenizer + self.query_length = query_length + self.start_text = start_text + self.end_text = end_text + self.seed = seed + token_to_index = tokenizer.get_vocab() + self.start_token = token_to_index[start_text] if self.start_text else None + self.end_token = token_to_index[end_text] if self.end_text else None + self.query_prefix = query_prefix + self.query_suffix = query_suffix + self.query_prefix_tokens = torch.LongTensor(tokenizer.encode(query_prefix)) + self.query_suffix_tokens = torch.LongTensor(tokenizer.encode(query_suffix)) + + + def __iter__(self): + for text in self.generator("train", self.seed, shuffle=True): + tokens = self.tokenizer.encode(text) + if self.start_token is not None: + try: + first_index = tokens.index(self.start_token) + 1 + if first_index < len(tokens): + tokens = tokens[first_index:] + except: + continue + tokens = tokens[: self.query_length] + if self.end_token is not None: + try: + last_index = len(tokens) - tokens[::-1].index(self.end_token) + tokens = tokens[:last_index] + except: + continue + output = self.tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=self.query_length, + return_tensors="pt", + ) + output["input_ids"] = torch.cat((self.query_prefix_tokens, output["input_ids"], self.query_suffix_tokens)) + yield output + + +def left_padding_to_right_padding(query, pad_id): + # got to convert to right padding, otherwise `transformers` has weird issues + # even with `position_ids` + return torch.tensor([ + [pad_id]*(row==pad_id).sum() + [x for x in row if x != pad_id] + for row in query + ]) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def generate(pretrained_model, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != 
tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = pretrained_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + +def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_model, iter_dataloader, generation_config): + with torch.no_grad(): + # reset reward scales + reward_model.module.reward_gain.data.fill_(1.0) + reward_model.module.reward_bias.data.fill_(0.0) + + # sample queries and responses + n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = data["input_ids"].to(device) + queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate(pretrained_model, queries, tokenizer, generation_config) + sample_queries_responses.append(query_responses) + + # compute reward statistics + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) + rewards = torch.cat(rewards) + rewards= accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + reward_model.module.reward_gain.data = gain + reward_model.module.reward_bias.data = bias + + # after normalization statistics + n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = data["input_ids"].to(device) + queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate(pretrained_model, queries, tokenizer, generation_config) + sample_queries_responses.append(query_responses) + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) + rewards = torch.cat(rewards) + rewards= accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def train(args: Args): + args.task.query_prefix = args.task.query_prefix.replace("\\n", "\n") + args.task.query_suffix = args.task.query_suffix.replace("\\n", "\n") + accelerator = Accelerator( + kwargs_handlers=[DistributedDataParallelKwargs(broadcast_buffers=False)] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + ) + args.world_size = 
accelerator.num_processes + args.batch_size = int(args.local_batch_size * args.world_size) + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + args.seed += accelerator.process_index + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + untrained_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model)).to(device) + reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model)).to(device) + reward_model.pretrained_model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + reward_model.pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + dataset = MyDataset( + DATASET[args.task.query_dataset], + tokenizer, + args.task.query_length, + start_text=args.task.start_text, + end_text=args.task.end_text, + query_prefix=args.task.query_prefix, + query_suffix=args.task.query_suffix, + ) + dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) + print(reward_model) + iter_dataloader = iter(dataloader) + + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` + label = load_dataset( + "vwxyzjn/lm-human-preferences", + data_files=[args.label_dataset], + )["train"] + print("Num labels found in source:", len(label)) + print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + + print("before====", reward_model.module.reward_gain.data) + if args.normalize_before: + normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config) + print("after====", reward_model.module.reward_gain.data) + + print("===training reward model===") + all_inds = np.arange(args.labels.num_train) + np.random.shuffle(all_inds) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + global_step = 0 + for start in range(0, args.labels.num_train, args.batch_size): + global_step += 1 + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = 
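        # The strided slice below splits each global batch of label indices across
        # ranks; illustratively, with 2 processes and b_inds_all = [0, 1, ..., 7],
        # rank 0 takes [0, 2, 4, 6] and rank 1 takes [1, 3, 5, 7]. The linear
        # annealing two lines further down then scales the learning rate from
        # args.lr towards 0 as `start` approaches args.labels.num_train.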
b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + lr = (1 - start / args.labels.num_train) * args.lr + optimizer.param_groups[0]["lr"] = lr + mb_data = label[b_inds] + # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) + mb_query = torch.from_numpy(np.stack(mb_data["query"])) + print("mb_query.shape", mb_query.shape) + mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) + for i in range(args.labels.num_labels) + ] + # hack: deal with openai's padding token + # assert (mb_query == tokenizer.pad_token_id).sum() == 0 + mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + for item in mb_responses: + # assert (item == tokenizer.pad_token_id).sum() == 0 + item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + reward = get_reward(reward_model, query_responses, tokenizer)[1] + predicted_rewards.append( + reward.squeeze() + ) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + loss = torch.nn.functional.cross_entropy(predicted_rewards, mb_best) + optimizer.zero_grad() + accelerator.backward(loss) + optimizer.step() + writer.add_scalar("train/loss", accelerator.gather(loss).mean().item(), global_step) + writer.add_scalar("train/accuracy", accelerator.gather(accuracy).mean().item(), global_step) + writer.add_scalar("train/lr", lr, global_step) + + if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + with torch.no_grad(): + data = next(iter_dataloader) + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate(accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config) + responses = query_responses[:, context_length:] + + output, reward = get_reward(reward_model, query_responses, tokenizer) + logits = output.logits[:,context_length-1:-1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + + output, _ = get_reward(untrained_model, query_responses, tokenizer) + logits = output.logits[:,context_length-1:-1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + + print(f"global_step {global_step}:") + kl = logprobs - ref_logprobs + console.print( + f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]" + f"\n[blue]{tokenizer.decode(responses[0], skip_special_tokens=True)}[/]" + f"\n[red]reward: {reward[0].item()}[/]" + f"\n[red]kl: {kl[0].sum().item()}[/]" + f"\n[red]average kl: {kl.sum(1).mean().item()}[/]" + ) + writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) + + # eval on test_label + test_accuracies = [] + all_inds = np.arange(len(label)) + for start in range(args.labels.num_train, len(label), args.batch_size): + end = start + args.batch_size + b_inds_all = 
all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + mb_data = label[b_inds] + # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) + mb_query = torch.from_numpy(np.stack(mb_data["query"])) + mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) + for i in range(args.labels.num_labels) + ] + # hack: deal with openai's padding token + # assert (mb_query == tokenizer.pad_token_id).sum() == 0 + mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + for item in mb_responses: + # assert (item == tokenizer.pad_token_id).sum() == 0 + item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + if i == 0: + print(tokenizer.decode(query_responses[0], skip_special_tokens=True)) + print(tokenizer.decode(mb_responses[i], skip_special_tokens=True)) + breakpoint() + reward = get_reward(reward_model, query_responses, tokenizer)[1] + predicted_rewards.append( + reward.squeeze() + ) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + writer.add_scalar("test/accuracy", test_accuracy, global_step) + if accelerator.is_main_process: + print("test/accuracy", test_accuracy, global_step) + + torch.cuda.empty_cache() + if args.normalize_after: + normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config) + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + + if accelerator.is_main_process and args.track: + wandb.finish() + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_debug.py b/lm_human_preference_details/summarization/train_reward_accelerate_debug.py new file mode 100644 index 0000000..e4811b1 --- /dev/null +++ b/lm_human_preference_details/summarization/train_reward_accelerate_debug.py @@ -0,0 +1,528 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import Optional + +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs, broadcast +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import tyro +from rich.console import Console +from datasets import load_dataset +from rich.pretty import pprint +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.datamod import DATASET + +@dataclass +class LabelHParams: + type: str = None + num_train: int = 4992 + num_labels: int = 4 + source: str = None + + +@dataclass +class TaskHParams: + # 
Query params + query_length: int = 64 + query_dataset: str = "books" + query_prefix: str = "" + query_suffix: str = "" + start_text: Optional[str] = None + end_text: Optional[str] = None + + # Response params + response_length: int = 24 + + # LM params + temperature: float = 0.7 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[:-len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + label_dataset: str = "sentiment/offline_5k.json" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + local_batch_size: int = 4 + """per rank batch size""" + lr: float = 0.00005 + """the learning rate""" + eps: float = 1e-5 + """the epsilon for AdamW""" + local_rollout_batch_size: int = 512 + """per rank rollot batch size""" + world_size: tyro.conf.Suppress[int] = None + """the number of processes to use""" + batch_size: tyro.conf.Suppress[int] = None + """the batch size across all ranks""" + normalize_samples: int = 256 + """Samples used to estimate reward mean and std""" + debug_normalize: int = 0 + """Samples used to check that normalization worked""" + normalize_before: bool = True + """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" + normalize_after: bool = True + """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/reward.pt" + """Where to save the model""" + task: TaskHParams = field(default_factory=TaskHParams) + labels: LabelHParams = field(default_factory=LabelHParams) + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +OPENAI_PAD_TOKEN_ID = 50259 + + +class ScalarHead(nn.Module): + def __init__(self, config, scale=None, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + # some models such as OPT have a projection layer before the word embeddings - e.g. 
OPT-350m + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + else: + hidden_size = config.hidden_size + if scale is None: + scale = 1 / np.sqrt(hidden_size + 1) + self.summary = layer_init(nn.Linear(hidden_size, 1), std=scale) + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithScalarHead(nn.Module): + def __init__(self, pretrained_model): + super().__init__() + self.pretrained_model = pretrained_model + self.scalar_head = ScalarHead(self.pretrained_model.config, scale=0.0) + + def forward(self, **kwargs): + output = self.pretrained_model(**kwargs) + return output, self.scalar_head(output.hidden_states[-1]) + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, pretrained_model): + super().__init__() + self.pretrained_model = pretrained_model + self.scalar_head = ScalarHead(self.pretrained_model.config) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.pretrained_model(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) + reward = self.reward_gain * reward + self.reward_bias + # but we only care about the reward of the last token + reward = reward[:, -1] + return output, reward + + +# a pytorch dataset +class MyDataset(IterableDataset): + def __init__(self, generator, tokenizer, query_length, start_text=None, end_text=None, seed=None): + self.generator = generator + self.tokenizer = tokenizer + self.query_length = query_length + self.start_text = start_text + self.end_text = end_text + self.seed = seed + token_to_index = tokenizer.get_vocab() + self.start_token = token_to_index[start_text] if self.start_text else None + self.end_token = token_to_index[end_text] if self.end_text else None + + def __iter__(self): + for text in self.generator("train", self.seed, shuffle=True): + tokens = self.tokenizer.encode(text) + if self.start_token is not None: + try: + first_index = tokens.index(self.start_token) + 1 + if first_index < len(tokens): + tokens = tokens[first_index:] + except: + continue + tokens = tokens[: self.query_length] + if self.end_token is not None: + try: + last_index = len(tokens) - tokens[::-1].index(self.end_token) + tokens = tokens[:last_index] + except: + continue + output = self.tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=self.query_length, + return_tensors="pt", + return_attention_mask=True, + ) + yield output + + +def left_padding_to_right_padding(query, pad_id): + # got to convert to right padding, otherwise `transformers` has weird issues + # even with `position_ids` + return torch.tensor([ + [pad_id]*(row==pad_id).sum() + [x for x in row if x != pad_id] + for row in query + ]) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def generate(pretrained_model, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = pretrained_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. 
TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + +def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_model, iter_dataloader, generation_config, query_prefix_tokens, query_suffix_tokens): + with torch.no_grad(): + # reset reward scales + reward_model.module.reward_gain.data.fill_(1.0) + reward_model.module.reward_bias.data.fill_(0.0) + + # sample queries and responses + n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = data["input_ids"].to(device) + queries = format_query(query_prefix_tokens, queries, query_suffix_tokens) + queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) + query_responses = generate(pretrained_model, queries, tokenizer, generation_config) + sample_queries_responses.append(query_responses) + + # compute reward statistics + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) + rewards = torch.cat(rewards) + rewards= accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + reward_model.module.reward_gain.data = gain + reward_model.module.reward_bias.data = bias + + # after normalization statistics + n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) + sample_queries_responses = [] + for _ in range(n_batches): + data = next(iter_dataloader) + queries = data["input_ids"].to(device) + queries = format_query(query_prefix_tokens, queries, query_suffix_tokens) + queries = left_padding_to_right_padding(queries, tokenizer.pad_token_id).to(device) + query_responses = generate(pretrained_model, queries, tokenizer, generation_config) + sample_queries_responses.append(query_responses) + rewards = [] + for query_responses in sample_queries_responses: + rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) + rewards = torch.cat(rewards) + rewards= accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def format_query(query_prefix_tokens, query, query_suffix_tokens): + query_prefix_tokens_tiled = query_prefix_tokens.unsqueeze(0).repeat(query.shape[0], 1).to(query.device) + query_suffix_tokens_tiled = query_suffix_tokens.unsqueeze(0).repeat(query.shape[0], 1).to(query.device) + return torch.cat((query_prefix_tokens_tiled, query, query_suffix_tokens_tiled), dim=1) + + +def train(args: Args): + args.task.query_prefix = args.task.query_prefix.replace("\\n", "\n") + args.task.query_suffix = 
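    # The two `.replace` calls turn the literal "\n" sequences that arrive from the
    # command line into real newlines; the resulting strings are tokenized further
    # down into query_prefix_tokens / query_suffix_tokens, and `format_query`
    # (defined above) tiles them around every query, so a (batch, query_length)
    # batch becomes (batch, len(prefix) + query_length + len(suffix)).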
args.task.query_suffix.replace("\\n", "\n") + accelerator = Accelerator( + kwargs_handlers=[DistributedDataParallelKwargs(broadcast_buffers=False)] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + ) + args.world_size = accelerator.num_processes + args.batch_size = int(args.local_batch_size * args.world_size) + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + args.seed += accelerator.process_index + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + use_auth_token=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + query_prefix_tokens = torch.LongTensor(tokenizer.encode(args.task.query_prefix)) + query_suffix_tokens = torch.LongTensor(tokenizer.encode(args.task.query_suffix)) + untrained_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True)).to(device) + reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True)).to(device) + reward_model.pretrained_model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + reward_model.pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + dataset = MyDataset( + DATASET[args.task.query_dataset], + tokenizer, + args.task.query_length, + start_text=args.task.start_text, + end_text=args.task.end_text, + ) + dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) + iter_dataloader = iter(dataloader) + + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` + label = load_dataset( + "vwxyzjn/lm-human-preferences", + data_files=[args.label_dataset], + )["train"] + print("Num labels found in source:", len(label)) + print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + + print("before====", reward_model.module.reward_gain.data) + if args.normalize_before: + normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config, query_prefix_tokens, query_suffix_tokens) + print("after====", reward_model.module.reward_gain.data) + + print("===training 
reward model===") + all_inds = np.arange(args.labels.num_train) + np.random.shuffle(all_inds) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + global_step = 0 + for start in range(0, args.labels.num_train, args.batch_size): + global_step += 1 + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + lr = (1 - start / args.labels.num_train) * args.lr + optimizer.param_groups[0]["lr"] = lr + mb_data = label[b_inds] + # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) + mb_query = torch.from_numpy(np.stack(mb_data["query"])) + mb_query = format_query(query_prefix_tokens, mb_query, query_suffix_tokens) + mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) + for i in range(args.labels.num_labels) + ] + # hack: deal with openai's padding token + # assert (mb_query == tokenizer.pad_token_id).sum() == 0 + mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + for item in mb_responses: + # assert (item == tokenizer.pad_token_id).sum() == 0 + item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + reward = get_reward(reward_model, query_responses, tokenizer)[1] + predicted_rewards.append( + reward.squeeze() + ) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + loss = torch.nn.functional.cross_entropy(predicted_rewards, mb_best) + optimizer.zero_grad() + accelerator.backward(loss) + optimizer.step() + writer.add_scalar("train/loss", accelerator.gather(loss).mean().item(), global_step) + writer.add_scalar("train/accuracy", accelerator.gather(accuracy).mean().item(), global_step) + writer.add_scalar("train/lr", lr, global_step) + + if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + with torch.no_grad(): + data = next(iter_dataloader) + queries = data["input_ids"].to(device) + queries = format_query(query_prefix_tokens, queries, query_suffix_tokens) + context_length = queries.shape[1] + queries = left_padding_to_right_padding(queries, tokenizer.pad_token_id).to(device) + query_responses = generate(accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config) + responses = query_responses[:, context_length:] + + + output, reward = get_reward(reward_model, query_responses, tokenizer) + logits = output.logits[:,context_length-1:-1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + + output, _ = get_reward(untrained_model, query_responses, tokenizer) + logits = output.logits[:,context_length-1:-1] + logits /= args.task.temperature + all_logprobs = F.log_softmax(logits, dim=-1) + ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + + print(f"global_step {global_step}:") + kl = logprobs - ref_logprobs + console.print( + 
f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]" + f"\n[blue]{tokenizer.decode(responses[0], skip_special_tokens=True)}[/]" + f"\n[red]reward: {reward[0].item()}[/]" + f"\n[red]kl: {kl[0].sum().item()}[/]" + f"\n[red]average kl: {kl.sum(1).mean().item()}[/]" + ) + writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) + + # eval on test_label + test_accuracies = [] + all_inds = np.arange(len(label)) + for start in range(args.labels.num_train, len(label), args.batch_size): + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + mb_data = label[b_inds] + # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) + mb_query = torch.from_numpy(np.stack(mb_data["query"])) + mb_query = format_query(query_prefix_tokens, mb_query, query_suffix_tokens) + mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) + for i in range(args.labels.num_labels) + ] + # hack: deal with openai's padding token + # assert (mb_query == tokenizer.pad_token_id).sum() == 0 + mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + for item in mb_responses: + # assert (item == tokenizer.pad_token_id).sum() == 0 + item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id + + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + reward = get_reward(reward_model, query_responses, tokenizer)[1] + predicted_rewards.append( + reward.squeeze() + ) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + writer.add_scalar("test/accuracy", test_accuracy, global_step) + if accelerator.is_main_process: + print("test/accuracy", test_accuracy, global_step) + + torch.cuda.empty_cache() + if args.normalize_after: + normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config, query_prefix_tokens, query_suffix_tokens) + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + + if accelerator.is_main_process and args.track: + wandb.finish() + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py new file mode 100644 index 0000000..e52f3ee --- /dev/null +++ b/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py @@ -0,0 +1,977 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from 
accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "models/reward.pt" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
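    # Illustrative reading of the rule above: with truncate_after = 16 and
    # truncate_token = 50256 (the GPT-2 EOS id), an EOS sampled at position 10 of
    # the 48-token response is ignored, while the first EOS at position >= 16 ends
    # the response there; the tokens after that point are then treated as padding
    # by the post-processing.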
+ truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + sft_model_path: str = "models/sft_policy.pt" + """Where to load the SFT model""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
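
    Example (illustrative):

        >>> first_true_indices(torch.tensor([[False, True, False],
        ...                                  [False, False, False]]))
        tensor([1, 3])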
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + 
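                # This call routes into the TensorFlow-style update defined above:
                # instead of PyTorch's denom = sqrt(v) / sqrt(1 - beta2**t) + eps with
                # step_size = lr / (1 - beta1**t), it uses
                # lr_t = lr * sqrt(1 - beta2**t) / (1 - beta1**t) and
                # denom = sqrt(v) + eps, i.e. eps is added before the bias correction,
                # reproducing TensorFlow's epsilon placement and slightly changing
                # early-step behaviour.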
exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + last_reward_latents = reward_latents + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + # shape: [batch_size, 1] + reward = self.reward_gain * reward + self.reward_bias + return output, reward + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
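        # Padding handling in a nutshell: queries arrive left-padded (e.g.
        # right_padding_to_left_padding turns [[11, 12, PAD, PAD]] into
        # [[PAD, PAD, 11, 12]]), the pad positions are zeroed out and excluded via
        # `attention_mask`, and the sampled continuation is re-attached to the
        # original padded queries afterwards. The reward/forward passes later
        # rebuild position ids with an exclusive cumsum: an attention mask of
        # [0, 0, 1, 1, 1] gives cumsum [0, 0, 1, 2, 3] and, after subtracting the
        # mask, position_ids [0, 0, 0, 1, 2], so the first real token always sits
        # at position 0 regardless of how much left padding a row has.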
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, args): + attention_mask = query_responses != args.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def get_reward_complete(reward_model, query_responses, args): + reward = get_reward(reward_model, query_responses, args)[1] + last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + + +def normalize( + tokenizer, + accelerator, + device, + lm_backbone, + reward_model, + dataloader, + validation_dataloader, +): + idx = 0 + with torch.no_grad(): + # reset reward scales + # accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) + # accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) + # number of minibatches for computing the normalization statistics + rewards = [] + for data in dataloader: + idx += len(data["query_token"]) + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + accelerator.print(score.shape, accelerator.gather(score).mean()) + rewards.append(score) + accelerator.print(f"====number of samples per device: {idx}") + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + accelerator.unwrap_model(reward_model).reward_gain.data = gain + accelerator.unwrap_model(reward_model).reward_bias.data = bias + + # validate normalization + rewards = [] + for data in validation_dataloader: + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +if 
__name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + if args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # # each class should have a separate pretrained model that do not share weights + # ref_policy = AutoModelForCausalLMWithScalarHead( + # AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + # ) + # policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + # if args.sft_model_path: + # policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + # ref_policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + # print(f"loaded pretrained policy from {args.sft_model_path}") + # policy.lm_backbone.generation_config.eos_token_id = ( + # None # disable `pad_token_id` and `eos_token_id` because we 
just want to + # ) + # policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # # see https://github.com/pytorch/pytorch/issues/104857 for more details + # if args.use_tensorflow_adam: + # optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + # else: + # optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = validation_dataset.map(process_query_data) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.shuffle(seed=local_seed) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + dataloader = accelerator.prepare(dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + else: + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===Normalize reward model *before* training===") + print( + "before normalization. 
" + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + tokenizer, + accelerator, + device, + reward_model, + reward_model, + dataloader, + validation_dataloader, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + # # save model + # if args.save_path: + # os.makedirs(os.path.dirname("models/correct_reward.pt"), exist_ok=True) + # torch.save(accelerator.unwrap_model(reward_model).state_dict(), "models/correct_reward.pt") + raise + + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkls_stats = torch.zeros(stats_shape, device=device) + clipfracs_stats = torch.zeros(stats_shape, device=device) + pg_losses_stats = torch.zeros(stats_shape, device=device) + vf_losses_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropies_stats = torch.zeros(stats_shape, device=device) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + query_reference_responses = torch.cat((queries, reference_responses), dim=1) + queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) + query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to(device) + query_responses = generate( + accelerator.unwrap_model(policy).lm_backbone, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + output, full_values = forward(policy, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output, _ = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + truncate_token_mask = responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_responses = torch.where( + truncate_mask, + torch.full_like(responses, tokenizer.pad_token_id), + responses, + ) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + # 2. run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding( + postprocessed_query_responses, tokenizer.pad_token_id + ) + scores = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + scores = scores[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + + reference_scores = get_reward(reward_model, query_reference_responses, tokenizer)[1] + last_reference_response_indices = first_true_indices(query_reference_responses == tokenizer.pad_token_id) - 1 + last_reference_response_indices = torch.max( + last_reference_response_indices, + torch.zeros([1], dtype=last_reference_response_indices.dtype, device=query_reference_responses.device), + ) + reference_scores = reference_scores[:, :, 0].gather(1, last_reference_response_indices.unsqueeze(1)).view(-1) + + print(reference_scores.mean()) + # normalization again + scores = scores - reference_scores + + # 3. filter response. Ensure that the sample contains truncate_token + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token + filter_mask = torch.any(matches_token, dim=-1) + scores = torch.where( + filter_mask, + scores, + torch.full_like(scores, args.task.penalty_reward_value), + ) + del matches_token, filter_mask + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. 
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) + all_postprocessed_query_responses = tokenizer.batch_decode( + postprocessed_query_responses, skip_special_tokens=True + ) + all_postprocessed_responses = [ + x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) + ] + all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) + + kl_sum = kl.sum(axis=1) + all_df = pd.DataFrame( + { + "query": all_decode_queries, + "response": all_postprocessed_responses, + "reference_responses": all_reference_responses, + "score": scores.float().cpu().numpy(), + "reference_scores": reference_scores.float().cpu().numpy(), + "kl": kl_sum.float().cpu().numpy(), + "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table("stuff", all_df[:4], console) + except Exception as e: + print(e) + del ( + all_decode_queries, + all_postprocessed_query_responses, + all_postprocessed_responses, + kl_sum, + all_df, + ) + del postprocessed_query_responses + torch.cuda.empty_cache() + +# # 6. compute advantages and returns +# lastgaelam = 0 +# advantages_reversed = [] +# gen_length = args.task.response_length +# for t in reversed(range(gen_length)): +# nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 +# delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] +# lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam +# advantages_reversed.append(lastgaelam) +# advantages = torch.stack(advantages_reversed[::-1], axis=1) +# returns = advantages + values +# advantages = whiten(advantages) +# return_mean, return_var = returns.mean(), returns.var() +# value_mean, value_var = values.mean(), values.var() + +# # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch +# for ppo_epoch_idx in range(args.ppo.noptepochs): +# b_inds = np.random.permutation(args.ppo.local_batch_size) +# minibatch_idx = 0 +# for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): +# mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size +# mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] +# gradient_accumulation_idx = 0 +# for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): +# with accelerator.accumulate(policy): +# micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size +# micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] +# mb_return = returns[micro_batch_inds] +# mb_advantage = advantages[micro_batch_inds] +# mb_values = values[micro_batch_inds] +# mb_responses = responses[micro_batch_inds] +# mb_query_responses = query_responses[micro_batch_inds] +# mb_logprobs = logprobs[micro_batch_inds] + +# output, vpred_temp = forward(policy, mb_query_responses, tokenizer) +# logits = output.logits[:, context_length - 1 : -1] +# logits /= (args.task.temperature + 1e-7) +# new_all_logprobs = F.log_softmax(logits, dim=-1) +# new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) +# vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) +# vpredclipped = torch.clamp( +# vpred, +# mb_values - 
args.ppo.cliprange_value, +# mb_values + args.ppo.cliprange_value, +# ) +# vf_losses1 = torch.square(vpred - mb_return) +# vf_losses2 = torch.square(vpredclipped - mb_return) +# vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() +# vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() +# logprobs_diff = new_logprobs - mb_logprobs +# ratio = torch.exp(logprobs_diff) +# pg_losses = -mb_advantage * ratio +# pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) +# pg_loss = torch.max(pg_losses, pg_losses2).mean() +# pg_clipfrac = (pg_losses2 > pg_losses).float().mean() +# loss = pg_loss + args.ppo.vf_coef * vf_loss +# accelerator.backward(loss) +# optimizer.step() +# optimizer.zero_grad() +# prob_dist = torch.nn.functional.softmax(logits, dim=-1) +# entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) +# approxkl = 0.5 * (logprobs_diff**2).mean() +# with torch.no_grad(): +# approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl +# clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac +# pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss +# vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss +# vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac +# entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() +# gradient_accumulation_idx += 1 +# minibatch_idx += 1 +# if accelerator.is_main_process: +# console.print( +# f"ppo_epoch_idx", +# ppo_epoch_idx, +# "approxkl", +# approxkl.item(), +# "pg_loss", +# pg_loss.item(), +# "pg_clipfrac", +# pg_clipfrac.item(), +# "ratio", +# ratio.mean().item(), +# ) + +# with torch.no_grad(): +# if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` +# writer.add_histogram("ppo/val/ratio_hist", ratio, update) +# kl = logprobs - ref_logprobs +# mean_kl = kl.sum(1).mean() +# mean_entropy = (-logprobs).sum(1).mean() +# mean_non_score_reward = non_score_reward.sum(1).mean() +# writer.add_scalar("objective/kl_coef", kl_ctl.value, update) +# writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) +# writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) +# writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) +# writer.add_scalar( +# "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update +# ) +# writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) +# writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) +# writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) +# writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) +# writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) +# writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) +# writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) +# writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) +# writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) +# 
writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) +# writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) +# writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) +# writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) +# writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) +# writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) +# writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) +# writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) +# writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) +# writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) +# writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) +# writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) +# writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) +# writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) +# writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) +# writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) +# writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) +# writer.add_scalar("ppo/lr", lrnow, update) +# writer.add_scalar("ppo/episode", global_step, update) +# if args.rewards.use_adaptive_kl: +# kl_ctl.update(mean_kl.item(), args.ppo.batch_size) +# del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + +# # save model +# if accelerator.is_main_process and args.save_path: +# os.makedirs(os.path.dirname(args.save_path), exist_ok=True) +# torch.save(policy.state_dict(), args.save_path) + +# if args.upload_model: +# repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" +# repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name +# policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) +# tokenizer.save_pretrained(repo_id, push_to_hub=True) + + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py new file mode 100644 index 0000000..e9c49f8 --- /dev/null +++ b/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py @@ -0,0 +1,778 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import transformers +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import DistributedDataParallelKwargs, broadcast +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from 
torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + + +@dataclass +class LabelHParams: + type: str = None + num_train: int = 64832 + num_labels: int = 2 + source: str = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = True + """Whether to load data from the local cache file in `dataset.map`""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + label_dataset: str = "openai/summarize_from_feedback" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + local_batch_size: int = 4 + """per rank batch size""" + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + lr: float = 0.00005 + """the learning rate""" + eps: float = 1e-5 + """the epsilon for AdamW""" + local_rollout_batch_size: int = 512 + """per rank rollout batch size""" + rollout_batch_size: tyro.conf.Suppress[int] = None + """rollout batch size""" + world_size: tyro.conf.Suppress[int] = None + """the number of processes to use""" + batch_size: tyro.conf.Suppress[int] = None + """the batch size across all ranks""" + local_normalize_samples: int = 256 + """Samples used to estimate reward mean and std""" + normalize_samples: tyro.conf.Suppress[int] = None + """Samples used to estimate reward mean and std across all ranks""" + debug_normalize: int = 0 + """Samples used to check that normalization worked""" + normalize_before: bool = True + """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. 
(For comparisons, just use mean 0, var 1.)""" + normalize_after: bool = True + """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" + print_sample_output_freq: int = 506 + """How often to print sample output""" + sft_model_path: str = "models/sft_policy.pt" + """Where to load the SFT model""" + logsigmoid: bool = True + """Whether to use log-sigmoid loss instead of cross-entropy loss""" + trainable_param_percentage: float = 1.0 + """Percentage of parameters to train""" + save_path: str = "models/reward.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + labels: LabelHParams = field(default_factory=LabelHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: 
Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + last_reward_latents = output.hidden_states[-1] + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + return output, reward + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
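+        # NOTE: the "[PAD]" token added to the tokenizer has no row in the (un-resized) embedding matrix,
+        # which is presumably why padded positions are zeroed above; positions themselves are inferred by
+        # `generate` from attention_mask, since passing explicit position_ids reportedly broke generation.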
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def get_reward_complete(reward_model, query_responses, tokenizer): + reward = get_reward(reward_model, query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward + + +def normalize( + tokenizer, + accelerator, + device, + lm_backbone, + reward_model, + dataloader, + validation_dataloader, +): + idx = 0 + with torch.no_grad(): + # reset reward scales + accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) + accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) + # number of minibatches for computing the normalization statistics + rewards = [] + for data in dataloader: + idx += len(data["query_token"]) + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + accelerator.print(f"====number of samples per device: {idx}") + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + accelerator.unwrap_model(reward_model).reward_gain.data = gain + accelerator.unwrap_model(reward_model).reward_bias.data = bias + + # validate normalization + rewards = [] + for data in validation_dataloader: + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def evaluate(args, accelerator, device, reward_model, validation_label): + reward_model.eval() + with torch.no_grad(): + # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
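+        # NOTE: each global batch is sliced per process (process_index::num_processes) and then walked in
+        # micro-batches; the reported accuracy is the mean of per-micro-batch means, gathered and averaged
+        # across processes at the end.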
+ test_accuracies = [] + eval_len = len(validation_label) + len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full + new_all_inds = np.arange(len_labels) + for start in range(0, len_labels, args.batch_size): + end = start + args.batch_size + b_inds_all = new_all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = validation_label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) + ] + predicted_rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + score, _ = get_reward_complete(reward_model, query_responses, args) + predicted_rewards.append(score) + predicted_rewards = torch.stack( + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + reward_model.train() + return test_accuracy + + +def train(args: Args): + accelerator = Accelerator( + kwargs_handlers=[ + DistributedDataParallelKwargs( + broadcast_buffers=False, + ) + ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + args.world_size = accelerator.num_processes + args.batch_size = int(args.local_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + num_updates = args.labels.num_train // args.batch_size + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + pprint(patch_h) + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + 
tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + args.pad_token_id = tokenizer.pad_token_id + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + + # freeze the first 70% of layers + if args.trainable_param_percentage < 1.0: + layers = reward_model.lm_backbone.transformer.h + num_layers = len(layers) + num_unfrozen = int(args.trainable_param_percentage * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + if args.sft_model_path: + reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded SFT model from {args.sft_model_path}") + reward_model.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # make sure the `lm_head` or `embed_out` does not require gradients, otherwise + # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 + reward_model.lm_backbone.gradient_checkpointing_enable() + + if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): + reward_model.lm_backbone.embed_out.requires_grad_(False) + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) + else: + optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_updates) + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + # pprint(process_query_data(dataset[0])) + dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.shuffle(seed=local_seed) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) + reward_model, optimizer, dataloader, scheduler = accelerator.prepare(reward_model, optimizer, dataloader, scheduler) + + iter_dataloader = iter(dataloader) + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + if args.normalize_before: + print("===Normalize reward model *before* training===") + print( + "before normalization. 
" + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + tokenizer, + accelerator, + device, + reward_model, + reward_model, + dataloader, + validation_dataloader, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` + label = load_dataset(args.label_dataset, "comparisons", split="train") + validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") + dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") + eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") + accelerator.print("Num labels found in source:", len(label)) + accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + + def process_response_data(x): + return { + **process_query(x["info"], encoder=tokenizer, hparams=patch_h), + "response0_token": tokenizer.encode( + f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + ), + "response1_token": tokenizer.encode( + f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + ), + } + + label = label.map(process_response_data) + dev_validation_label = dev_validation_label.map(process_response_data) + eval_validation_label = eval_validation_label.map(process_response_data) + # tokenizer.encode(label[0]["summaries"][0]["text"]) + + accelerator.print("===training reward model===") + all_inds = np.random.permutation(args.labels.num_train) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + + for (global_step, start) in enumerate(range(0, args.labels.num_train, args.batch_size)): + # # linear rate annealing + # lr = (1 - start / args.labels.num_train) * args.lr + # optimizer.param_groups[0]["lr"] = lr + + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") + if accelerator.is_main_process: pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + gradient_accumulation_step = 0 + # reward_model.train() + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + with accelerator.accumulate(reward_model): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + # pprint({ + # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", + # "micro_batch_inds": micro_batch_inds, + # }) + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i 
in range(args.labels.num_labels) + ] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + predicted_rewards, score_all = get_reward_complete(reward_model, query_responses, tokenizer) + breakpoint() + + predicted_rewards = predicted_rewards.view(len(mb_responses), -1) + reward_preferred = predicted_rewards.gather(1, mb_best.view(-1, 1)).view(-1) + reward_rejected = predicted_rewards.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() + if args.logsigmoid: + loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() + else: + loss = F.cross_entropy(predicted_rewards, mb_best) + accelerator.backward(loss) + optimizer.step() # accelerate handles gradient accumulation automatically + optimizer.zero_grad() + scheduler.step() + losses[gradient_accumulation_step] = loss + accuracies[gradient_accumulation_step] = accuracy + gradient_accumulation_step += 1 + + train_accuracy = accelerator.gather(accuracies).mean().item() + writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/accuracy", train_accuracy, global_step) + lr = scheduler.get_last_lr() + writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) + accelerator.print("train/accuracy", train_accuracy) + + # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + if global_step == num_updates - 1: # first and last update + dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) + writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) + accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) + eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) + writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) + accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) + + torch.cuda.empty_cache() + if args.normalize_after: + print("===Normalize reward model *after* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + tokenizer, + accelerator, + device, + reward_model, + reward_model, + dataloader, + validation_dataloader, + ) + print( + "after normalization. 
" + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + accelerator.save_model(reward_model, args.save_path) + + if accelerator.is_main_process and args.track: + wandb.finish() + + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py new file mode 100644 index 0000000..cfbd58a --- /dev/null +++ b/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py @@ -0,0 +1,824 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import transformers +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import DistributedDataParallelKwargs, broadcast +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler + +from lm_human_preference_details.data import process_query + + +@dataclass +class LabelHParams: + type: str = None + num_train: int = 92832 + num_labels: int = 2 + source: str = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + 
"""Whether to load data from the local cache file in `dataset.map`""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + label_dataset: str = "openai/summarize_from_feedback" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + local_batch_size: int = 4 + """per rank batch size""" + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + lr: float = 0.00005 + """the learning rate""" + eps: float = 1e-5 + """the epsilon for AdamW""" + local_rollout_batch_size: int = 512 + """per rank rollout batch size""" + rollout_batch_size: tyro.conf.Suppress[int] = None + """rollout batch size""" + world_size: tyro.conf.Suppress[int] = None + """the number of processes to use""" + batch_size: tyro.conf.Suppress[int] = None + """the batch size across all ranks""" + local_normalize_samples: int = 256 + """Samples used to estimate reward mean and std""" + normalize_samples: tyro.conf.Suppress[int] = None + """Samples used to estimate reward mean and std across all ranks""" + debug_normalize: int = 0 + """Samples used to check that normalization worked""" + normalize_before: bool = True + """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" + normalize_after: bool = True + """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" + print_sample_output_freq: int = 300 + """How often to print sample output""" + sft_model_path: str = "models/sft_policy" + """Where to load the SFT model""" + logsigmoid: bool = True + """Whether to use log-sigmoid loss instead of cross-entropy loss""" + trainable_param_percentage: float = 1.0 + """Percentage of parameters to train""" + num_epochs: int = 1 + """Number of epochs to train""" + num_updates: tyro.conf.Suppress[int] = None + """Number of updates to train""" + save_path: str = "models/reward" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + scheduler: str = "constant_with_warmup" + """Which scheduler to use""" + warm_up_steps: int = 100 + """Number of warm up steps for the scheduler""" + task: TaskHParams = field(default_factory=TaskHParams) + labels: LabelHParams = field(default_factory=LabelHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + 
exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + last_reward_latents = output.hidden_states[-1] + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + return output, reward + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def get_reward_complete(reward_model, query_responses, tokenizer): + reward = get_reward(reward_model, query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward + + +def normalize( + tokenizer, + accelerator, + device, + lm_backbone, + reward_model, + dataloader, + validation_dataloader, +): + idx = 0 + with torch.no_grad(): + # reset reward scales + accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) + accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) + # number of minibatches for computing the normalization statistics + rewards = [] + for data in dataloader: + idx += len(data["query_token"]) + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + accelerator.print(f"====number of samples per device: {idx}") + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + accelerator.unwrap_model(reward_model).reward_gain.data = gain + accelerator.unwrap_model(reward_model).reward_bias.data = bias + + # validate normalization + rewards = [] + for data in validation_dataloader: + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def evaluate(args, accelerator, device, reward_model, validation_label): + # reward_model.eval() + with torch.no_grad(): + # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
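+        # Sketch of the loop below: both candidate summaries in each labeled
+        # pair are scored with the reward model, and accuracy is the fraction
+        # of pairs where the higher score lands on the human-preferred side:
+        #     accuracy = mean(argmax_i r(query, response_i) == choice)
+        # `len_labels` rounds the evaluation set down to a multiple of
+        # `args.batch_size`, so any trailing partial batch is skipped.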
+ test_accuracies = [] + eval_len = len(validation_label) + len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full + new_all_inds = np.arange(len_labels) + for start in range(0, len_labels, args.batch_size): + end = start + args.batch_size + b_inds_all = new_all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = validation_label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) + ] + predicted_reward = [] + rewards = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + score, reward = get_reward_complete(reward_model, query_responses, args) + rewards.append(reward) + predicted_reward.append(score) + predicted_reward = torch.stack( + predicted_reward, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + # reward_model.train() + return test_accuracy + + +def train(args: Args): + accelerator = Accelerator( + kwargs_handlers=[ + DistributedDataParallelKwargs( + broadcast_buffers=False, + # find_unused_parameters=True, + ) + ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + args.world_size = accelerator.num_processes + args.batch_size = int(args.local_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + args.num_updates = args.labels.num_train // args.batch_size + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + 
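+    # Seeding note: each rank offsets the base seed by a large prime times its
+    # process index, so e.g. seed=1 on 4 processes yields local seeds
+    # 1, 100004, 200007, and 300010; random streams differ across ranks while
+    # every run with the same base seed stays reproducible.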
torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + args.pad_token_id = tokenizer.pad_token_id + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + + # freeze the first 70% of layers + if args.trainable_param_percentage < 1.0: + layers = reward_model.lm_backbone.transformer.h + num_layers = len(layers) + num_unfrozen = int(args.trainable_param_percentage * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + if args.sft_model_path: + reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded SFT model from {args.sft_model_path}") + reward_model.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # make sure the `lm_head` or `embed_out` does not require gradients, otherwise + # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 + reward_model.load_state_dict(torch.load("models/gpt2-medium-rm/pytorch_model.bin", map_location=device)) + print("loaded reward model") + if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): + reward_model.lm_backbone.embed_out.requires_grad_(False) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(reward_model.parameters(), lr=args.lr, eps=args.eps) + # TODO: use AdamW + scheduler = get_scheduler( + args.scheduler, + optimizer=optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_updates * args.num_epochs, + ) + + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size + + + reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) + if args.normalize_before: + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + validation_dataset = validation_dataset.with_format("torch", 
columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.shuffle(seed=local_seed) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) + dataloader = accelerator.prepare(dataloader) + iter_dataloader = iter(dataloader) + print("===Normalize reward model *before* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + tokenizer, + accelerator, + device, + reward_model, + reward_model, + dataloader, + validation_dataloader, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` + label = load_dataset(args.label_dataset, "comparisons", split="train") + validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") + dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") + eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") + accelerator.print("Num labels found in source:", len(label)) + accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + + def process_response_data(x): + return { + **process_query(x["info"], encoder=tokenizer, hparams=patch_h), + "response0_token": tokenizer.encode( + f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + ), + "response1_token": tokenizer.encode( + f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + ), + } + + label = label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + dev_validation_label = dev_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + eval_validation_label = eval_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + # TODO: check if all labels have eos token + accelerator.print("===training reward model===") + num_train = (args.labels.num_train // args.batch_size) * args.batch_size + for epoch in range(args.num_epochs): + all_inds = np.random.permutation(args.labels.num_train) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + accelerator.print(f"epoch: {epoch}") + for (epoch_global_step, start) in enumerate(range(0, num_train, args.batch_size)): + global_step = epoch * args.num_updates + epoch_global_step + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") + if accelerator.is_main_process: pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_rejecteds = 
torch.zeros((args.gradient_accumulation_steps,), device=device) + gradient_accumulation_step = 0 + # reward_model.train() + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + with accelerator.accumulate(reward_model): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + # pprint({ + # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", + # "micro_batch_inds": micro_batch_inds, + # }) + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + predicted_reward, reward = get_reward_complete(reward_model, query_responses, tokenizer) + predicted_reward = predicted_reward.view(-1, len(mb_responses)) # TODO check shape for no gradienta ccumulation steps + + # print(tokenizer.decode(mb_query[0])) + # print(tokenizer.decode(mb_responses[0][0])) + # print(tokenizer.decode(mb_responses[1][0])) + # predicted_reward = [] + # rewards = [] + # for i in range(args.labels.num_labels): + # query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + # score, reward = get_reward_complete(reward_model, query_responses, tokenizer) + # rewards.append(reward.squeeze(-1)) + # predicted_reward.append(score) + # # shape (batch_size, num_labels), basically a reward prediction for each label + # predicted_reward = torch.stack(predicted_reward, dim=1) + # breakpoint() + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) + reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + # if args.logsigmoid: + # reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) + # reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + # loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() + # else: + # loss = F.cross_entropy(predicted_reward, mb_best) + # accelerator.backward(loss) + + # # for k, v in reward_model.named_parameters(): + # # if v.requires_grad: + # # if v.grad is None: + # # print(f"found unused param: {k}") + + # optimizer.step() # accelerate handles gradient accumulation automatically + # optimizer.zero_grad() + # scheduler.step() + # losses[gradient_accumulation_step] = loss + accuracies[gradient_accumulation_step] = accuracy + reward_preferreds[gradient_accumulation_step] = reward_preferred.mean() + reward_rejecteds[gradient_accumulation_step] = reward_rejected.mean() + gradient_accumulation_step += 1 + + train_accuracy = accelerator.gather(accuracies).mean().item() + print("train/accuracy", train_accuracy) + print("train/reward_preferred", accelerator.gather(reward_preferreds)) + print("train/reward_rejected", accelerator.gather(reward_rejecteds)) + breakpoint() + writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/accuracy", train_accuracy, global_step) + writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) + 
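+            # `reward_preferred` / `reward_rejected` are the scalar scores of the
+            # human-chosen summary and the other candidate; under the pairwise
+            # objective sketched in the commented-out branch above,
+            #     loss = -log(sigmoid(reward_preferred - reward_rejected)),
+            # and the gap between them is what that loss drives the model to widen.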
writer.add_scalar("train/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) + lr = scheduler.get_last_lr() + writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) + accelerator.print("train/accuracy", train_accuracy) + + # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + if global_step == args.num_updates - 1: # first and last update + dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) + writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) + accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) + eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) + writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) + accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) + eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, label) + writer.add_scalar("train_full/accuracy", eval_validation_accuracy, global_step) + accelerator.print("train_full/accuracy", eval_validation_accuracy, global_step) + + torch.cuda.empty_cache() + if args.normalize_after: + print("===Normalize reward model *after* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + tokenizer, + accelerator, + device, + reward_model, + reward_model, + dataloader, + validation_dataloader, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + accelerator.save_model(reward_model, args.save_path) + + if accelerator.is_main_process and args.track: + wandb.finish() + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py b/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py new file mode 100644 index 0000000..31fb58b --- /dev/null +++ b/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py @@ -0,0 +1,521 @@ +import collections +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.optim as optim +import tyro +import evaluate +from accelerate import Accelerator +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + + +@dataclass +class SFTHParams: + gradient_accumulation_steps: int = 16 + local_micro_batch_size: int = 1 + noptepochs: int = 1 + lr: float = 6.35e-5 + eps: float = 1e-5 
+ total_episodes: tyro.conf.Suppress[int] = None + local_batch_size:tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.01 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 220 + """How often to print sample output""" + save_path: str = "models/sft_policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + sft: SFTHParams = field(default_factory=SFTHParams) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param 
in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + 
"""generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + labels=input_ids, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + ) + + +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) + args.sft.world_size = accelerator.num_processes + args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps + args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + dataset = load_dataset(args.task.query_dataset, split="train") + test_dataset = load_dataset(args.task.query_dataset, split="test") + accelerator.print("The number of samples in dataset", len(dataset)) + accelerator.print("The number of samples in test_dataset", len(test_dataset)) + args.sft.total_episodes = len(dataset) + args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + 
policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + + def process_query_data(x): + pad_summary_w_leading_space = " " + x['summary'] + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + test_dataset = test_dataset.map(process_query_data) + test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) + test_dataset = test_dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) + test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) + policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) + iter_dataloader = iter(dataloader) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + rouge = evaluate.load("rouge") + + print("===training policy===") + global_step = 0 + test_data = test_dataset[0:10] + test_data = {k: v.to(device) for k, v in test_data.items()} + loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) + gradient_accumulation_idx = 0 + + # Given parameters + eta_min = 0 + eta_max = 6.35e-5 + T_max = args.sft.num_updates + + for update in range(1, args.sft.num_updates + 1): + global_step += 1 * args.sft.batch_size + accelerator.print(f"update {update}, global_step {global_step}") + # frac = 1.0 - (update - 1.0) / args.sft.num_updates + # lrnow = frac * args.sft.lr + lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_responses), dim=1) + query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) + with accelerator.accumulate(policy): + output = forward(policy, query_responses, tokenizer) + loss = output.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + loss_stats[gradient_accumulation_idx] = loss + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps + if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: + writer.add_scalar("loss", 
accelerator.gather(loss_stats).mean().item(), update) + writer.add_scalar("lr", lrnow, update) + if (update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: + rouge_scores = collections.defaultdict(list) + for test_idx, test_data in enumerate(test_dataloader): + with torch.no_grad(): + test_queries = test_data["query_token"].to(device) + test_reference_responses = test_data["reference_response"].to(device) + test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) + generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) + accelerator.print(update, test_idx) + + all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) + all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) + all_decode_test_reference_responses = tokenizer.batch_decode( + test_reference_responses, skip_special_tokens=True + ) + all_decode_test_responses = [ + x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) + ] + rouge_score = rouge.compute(predictions=all_decode_test_responses, references=all_decode_test_reference_responses) + rouge_scores["rouge1"].append(rouge_score["rouge1"]) + rouge_scores["rouge2"].append(rouge_score["rouge2"]) + rouge_scores["rougeL"].append(rouge_score["rougeL"]) + + if test_idx == 0: + try: + all_df = pd.DataFrame( + { + "query": all_decode_test_queries, + "response": all_decode_test_responses, + "reference": all_decode_test_reference_responses, + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) + except Exception as e: + print(e) + + for k, v in rouge_scores.items(): + rouge_metric = torch.tensor(v, device=device) + rouge_metric = accelerator.gather(rouge_metric) + writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) + accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path) + + if args.upload_model: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) \ No newline at end of file diff --git a/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py b/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py new file mode 100644 index 0000000..618a1e9 --- /dev/null +++ b/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py @@ -0,0 +1,540 @@ +import collections +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.optim as optim +import tyro +import evaluate +from accelerate import Accelerator +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch 
import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query +from concurrent.futures import ProcessPoolExecutor + + +@dataclass +class SFTHParams: + gradient_accumulation_steps: int = 16 + local_micro_batch_size: int = 1 + noptepochs: int = 1 + lr: float = 6.35e-5 + eps: float = 1e-5 + total_episodes: tyro.conf.Suppress[int] = None + local_batch_size:tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.01 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 180 + """How often to print sample output""" + save_path: str = "models/sft_policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + sft: SFTHParams = field(default_factory=SFTHParams) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + 
table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def calculate_rouge( + base_model: str, + test_queries: List[List[str]], + generated_responses: List[List[str]], + test_reference_responses: List[List[str]], +): + tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) + all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) + all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) + all_decode_test_reference_responses = tokenizer.batch_decode( + test_reference_responses, skip_special_tokens=True + ) + all_decode_test_responses = [ + x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) + ] + rouge = evaluate.load("rouge") + return rouge.compute(predictions=predictions, references=references) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with 
torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + labels=input_ids, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + ) + + +def train(args: Args): + accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) + args.sft.world_size = accelerator.num_processes + args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps + args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + dataset = load_dataset(args.task.query_dataset, split="train") + test_dataset = load_dataset(args.task.query_dataset, split="test") + accelerator.print("The number of samples in dataset", len(dataset)) + accelerator.print("The number of samples in test_dataset", len(test_dataset)) + args.sft.total_episodes = len(dataset) + args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically + # see https://github.com/pytorch/pytorch/issues/104857 for more details + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + else: + optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) 
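+    # The hand-rolled schedule below anneals the learning rate along a cosine
+    # curve from eta_max down to eta_min over num_updates steps:
+    #     lr(t) = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max))
+    # which should match the closed form used by torch.optim.lr_scheduler.CosineAnnealingLR;
+    # it is computed manually here so the value can be written straight into
+    # optimizer.param_groups[0]["lr"] each update.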
+ + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + test_dataset = test_dataset.map(process_query_data) + test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) + test_dataset = test_dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) + test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) + policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) + iter_dataloader = iter(dataloader) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=args.task.temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + executor = ProcessPoolExecutor() + # rouge = evaluate.load("rouge") + + print("===training policy===") + global_step = 0 + test_data = test_dataset[0:10] + test_data = {k: v.to(device) for k, v in test_data.items()} + loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) + gradient_accumulation_idx = 0 + + # Given parameters + eta_min = 0 + eta_max = 6.35e-5 + T_max = args.sft.num_updates + + for update in range(1, args.sft.num_updates + 1): + global_step += 1 * args.sft.batch_size + accelerator.print(f"update {update}, global_step {global_step}") + # frac = 1.0 - (update - 1.0) / args.sft.num_updates + # lrnow = frac * args.sft.lr + lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_responses), dim=1) + query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) + with accelerator.accumulate(policy): + output = forward(policy, query_responses, tokenizer) + loss = output.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + loss_stats[gradient_accumulation_idx] = loss + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps + if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: + writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) + writer.add_scalar("lr", lrnow, update) + if (update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: + rouge_scores = collections.defaultdict(list) + futures = [] + for test_idx, test_data in enumerate(test_dataloader): + with torch.no_grad(): + test_queries = test_data["query_token"].to(device) + test_reference_responses = test_data["reference_response"] + # test_queries = right_padding_to_left_padding(test_queries, 
tokenizer.pad_token_id) + generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) + accelerator.print(update, test_idx) + + + # futures.append( + # executor.submit( + # calculate_rouge, + # args.base_model, + # test_queries.cpu(), + # generated_responses.cpu(), + # test_reference_responses.cpu(), + # ) + # ) + # if test_idx == 0: + # try: + # all_df = pd.DataFrame( + # { + # "query": all_decode_test_queries, + # "response": all_decode_test_responses, + # "reference": all_decode_test_reference_responses, + # } + # ) + # if accelerator.is_main_process and args.track: + # wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) + # print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) + # except Exception as e: + # print(e) + + rouge_scores = [f.result() for f in futures] # list of dicts + rouge_scores = {k: np.mean([x[k] for x in rouge_scores]) for k in rouge_scores[0].keys()} + for k, v in rouge_scores.items(): + rouge_metric = torch.tensor(v, device=device) + rouge_metric = accelerator.gather(rouge_metric) + writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) + accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") + + # save model + if accelerator.is_main_process and args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(accelerator.unwrap_model(policy).state_dict(), args.save_path) + + if args.upload_model: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize.py b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py similarity index 100% rename from lm_human_preference_details/train_policy_accelerate_summarize.py rename to lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py diff --git a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py new file mode 100644 index 0000000..90df20d --- /dev/null +++ b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py @@ -0,0 +1,1021 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from lm_human_preference_details.data import process_query + +INVALID_LOGPROB = 1.0 + +@dataclass +class 
AdaptiveKLParams:
+    target: float = 6.0
+    horizon: int = 10000  # in episodes
+
+
+@dataclass
+class RewardHParams:
+    kl_coef: float = 0.15
+    use_adaptive_kl: bool = True
+    adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams)
+    trained_model: Optional[str] = "models/reward"
+    label_dataset: tyro.conf.Suppress[Optional[str]] = None
+
+
+@dataclass
+class PpoHParams:
+    total_episodes: int = 1000000
+    local_batch_size: int = 64
+    local_mini_batch_size: tyro.conf.Suppress[int] = None
+    mini_batch_size: tyro.conf.Suppress[int] = None
+    gradient_accumulation_steps: int = 1
+    """gradient accumulation steps"""
+    local_micro_batch_size: tyro.conf.Suppress[int] = None
+    """per rank micro batch size"""
+    world_size: tyro.conf.Suppress[int] = None
+    batch_size: tyro.conf.Suppress[int] = None
+    minibatch_size: tyro.conf.Suppress[int] = None
+    num_updates: tyro.conf.Suppress[int] = None
+    nminibatches: int = 1
+    noptepochs: int = 4
+    lr: float = 0.00001
+    eps: float = 1e-5
+    vf_coef: float = 0.1
+    cliprange: float = 0.2
+    cliprange_value: float = 0.2
+    gamma: float = 1
+    lam: float = 0.95
+    whiten_rewards: bool = True
+
+
+@dataclass
+class TaskHParams:
+    # Query params
+    query_length: int = 512
+    query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered"
+
+    query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
+    query_truncate_field: Optional[str] = "post"
+    query_truncate_text: Optional[str] = "\n"
+    query_padding: Optional[str] = None  # defaults to repeated spaces
+    query_pad_side: Optional[str] = "left"
+
+    # Response params
+    response_length: int = 48
+
+    # Truncate response after the first occurrence of this token at or after index `truncate_after` when sampling.
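+    # `truncate_token` is 50256, GPT-2's `<|endoftext|>` id; tokens appearing before index
+    # `truncate_after` are never used for truncation (see the response post-processing in
+    # the rollout loop below).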
+ truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 1 + """How often to print sample output""" + sft_model_path: str = "models/sft_policy" + """Where to load the SFT model""" + save_path: str = "models/policy.pt" + """Where to save the model""" + use_tensorflow_adam: bool = True + """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + 
exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + last_reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + # last_reward_latents = reward_latents + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + # # shape: [batch_size, 1] + # reward = self.reward_gain * reward + self.reward_bias + return output, reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. 
TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def get_reward_complete(reward_model, query_responses, tokenizer): + reward = get_reward(reward_model, query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + 
device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + if args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.use_tensorflow_adam: + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # 
deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation = {k: v.to(device) for k, v in sample_validation.items()} + sample_validation_queries = sample_validation["query_token"] + with torch.no_grad(): + print(sample_validation_queries.shape) + sample_validation_queries = right_padding_to_left_padding(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = sample_validation["reference_response"] + sample_validation_query_reference_responses = torch.cat((sample_validation_queries, sample_validation_reference_response), dim=1) + sample_validation_reference_scores = get_reward_complete(reward_model, sample_validation_query_reference_responses, tokenizer) + # breakpoint() + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + # print("===Normalize reward model *before* training===") + # print( + # "before normalization. " + # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + # ) + + # normalize( + # tokenizer, + # accelerator, + # device, + # reward_model, + # reward_model, + # dataloader, + # validation_dataloader, + # ) + # print( + # "after normalization. 
" + # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + # ) + # # # save model + # # if args.save_path: + # # os.makedirs(os.path.dirname("models/correct_reward.pt"), exist_ok=True) + # # torch.save(accelerator.unwrap_model(reward_model).state_dict(), "models/correct_reward.pt") + # raise + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkls_stats = torch.zeros(stats_shape, device=device) + clipfracs_stats = torch.zeros(stats_shape, device=device) + pg_losses_stats = torch.zeros(stats_shape, device=device) + vf_losses_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropies_stats = torch.zeros(stats_shape, device=device) + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + """ + let's use `P` to denote the padding token, `T` to denote the truncate token, and `X` to denote the + actual tokens. + queries: `PPXXX` + query_responses: `PPXXX,XXXXTXX` # the space separates the query and response + response: `XXXXTXX` + postprocessed_responses: `XXXXTXX` -> `XXXXTPP` + postprocessed_query_responses: `PPXXX,XXXXTPP` + scores: ↑ # corresponding to this `X` token + + """ + queries = data["query_token"].to(device) + reference_responses = data["reference_response"].to(device) + queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) + query_reference_responses = torch.cat((queries, reference_responses), dim=1) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + truncate_token_mask = sample_validation_responses == args.task.truncate_token + truncate_after_or_token_mask = torch.cat( + [ + torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + truncate_token_mask[:, args.task.truncate_after :], + ], + dim=1, + ) + truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + postprocessed_sample_validation_responses = torch.where( + truncate_mask, + torch.full_like(sample_validation_responses, tokenizer.pad_token_id), + sample_validation_responses, + ) + postprocessed_sample_validation_query_responses = torch.cat((sample_validation_queries, postprocessed_sample_validation_responses), 1) + del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + torch.cuda.empty_cache() + + + + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + full_values = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer)[1] + values = full_values[:, context_length - 1 : -1].squeeze(-1) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = 
torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + # 1. truncate at the first occurrence of `truncate_token` that appears at or after + # position truncate_after in the responses + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 + # truncate_token_mask = responses == args.task.truncate_token + # truncate_after_or_token_mask = torch.cat( + # [ + # torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], + # truncate_token_mask[:, args.task.truncate_after :], + # ], + # dim=1, + # ) + # truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() + # postprocessed_responses = torch.where( + # truncate_mask, + # torch.full_like(responses, tokenizer.pad_token_id), + # responses, + # ) + # del truncate_token_mask, truncate_after_or_token_mask, truncate_mask + + trunc_idxs = first_true_indices(responses == args.task.truncate_token).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses
+            postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1)
+            padding_mask = postprocessed_responses == tokenizer.pad_token_id
+            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
+            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
+            values = torch.masked_fill(values, padding_mask, 0)
+
+            # TODO: investigate whether the reward head should read output.hidden_states[0] or output.hidden_states[-1]
+            scores = get_reward_complete(reward_model, postprocessed_query_responses, tokenizer)
+            reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer)
+            # note that we do not truncate the validation responses
+            validation_score = get_reward_complete(reward_model, postprocessed_sample_validation_query_responses, tokenizer)
+
+            # carperAI-style score normalization
+            accelerator.print("before score", scores, scores.mean())
+            accelerator.print("reference_scores", reference_scores, reference_scores.mean())
+            scores = scores - reference_scores
+            accelerator.print("after score", scores, scores.mean())
+
+            # 3. filter response. Ensure that the sample contains truncate_token
+            # responses not passing that filter will receive a low (fixed) score
+            # only query humans on responses that pass that filter
+            contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1)
+            scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value))
+            torch.cuda.empty_cache()
+
+            # 4. compute rewards
+            kl = logprobs - ref_logprobs
+            non_score_reward = -kl_ctl.value * kl
+            rewards = non_score_reward.clone()
+            rewards[:, -1] += scores
+
+            # 5. 
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries) + all_sample_validation_query_responses = tokenizer.batch_decode( + sample_validation_query_responses + ) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses + ) + all_sample_validation_responses = [ + x[len(y) :] for x, y in zip(all_sample_validation_query_responses, all_decode_validation_queries) + ] + all_sample_validation_postprocessed_responses = [ + x[len(y) :] for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode( + sample_validation_reference_response + ) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_query_responses, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + # output, vpred_temp = forward(policy, mb_query_responses, tokenizer) + output, (_, vpred_temp) = forward(model, mb_query_responses, tokenizer) + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! 
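+                        # The block below is the standard PPO objective on the sampled responses:
+                        #   ratio  = exp(new_logprob - old_logprob)
+                        #   pg     = mean(max(-A * ratio, -A * clip(ratio, 1 - cliprange, 1 + cliprange)))
+                        #   vf     = 0.5 * mean(max((v - R)^2, (clip(v, v_old +/- cliprange_value) - R)^2))
+                        #   loss   = pg + vf_coef * vf
+                        # with new logprobs and value predictions masked at padded response positions.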
+ + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + breakpoint() + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl.item(), + "pg_loss", + pg_loss.item(), + "pg_clipfrac", + pg_clipfrac.item(), + "ratio", + ratio.mean().item(), + ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + 
writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) + writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if accelerator.is_main_process and args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + torch.save(policy.state_dict(), args.save_path) + + if args.upload_model: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py b/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py new file mode 100644 index 0000000..443cc3b --- /dev/null +++ b/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py @@ -0,0 +1,814 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from 
typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import transformers +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import DistributedDataParallelKwargs, broadcast +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler + +from lm_human_preference_details.data import process_query + + +@dataclass +class LabelHParams: + type: str = None + num_train: int = 92832 + num_labels: int = 2 + source: str = None + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 48 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanrl" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + + base_model: str = "gpt2" + """the name of the pretrained model to use""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + label_dataset: str = "openai/summarize_from_feedback" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + local_batch_size: int = 4 + """per rank batch size""" + gradient_accumulation_steps: int = 1 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + lr: float = 0.00005 + """the learning rate""" + eps: float = 1e-5 + """the epsilon for AdamW""" + local_rollout_batch_size: int = 512 + """per rank rollout batch size""" + rollout_batch_size: tyro.conf.Suppress[int] = None + """rollout batch size""" + world_size: tyro.conf.Suppress[int] = None + """the number of processes to use""" + batch_size: tyro.conf.Suppress[int] = 
None + """the batch size across all ranks""" + local_normalize_samples: int = 256 + """Samples used to estimate reward mean and std""" + normalize_samples: tyro.conf.Suppress[int] = None + """Samples used to estimate reward mean and std across all ranks""" + debug_normalize: int = 0 + """Samples used to check that normalization worked""" + normalize_before: bool = True + """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" + normalize_after: bool = True + """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" + print_sample_output_freq: int = 300 + """How often to print sample output""" + sft_model_path: str = "models/sft_policy" + """Where to load the SFT model""" + logsigmoid: bool = True + """Whether to use log-sigmoid loss instead of cross-entropy loss""" + trainable_param_percentage: float = 1.0 + """Percentage of parameters to train""" + num_epochs: int = 1 + """Number of epochs to train""" + num_updates: tyro.conf.Suppress[int] = None + """Number of updates to train""" + save_path: str = "models/reward" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + scheduler: str = "constant_with_warmup" + """Which scheduler to use""" + warm_up_steps: int = 100 + """Number of warm up steps for the scheduler""" + task: TaskHParams = field(default_factory=TaskHParams) + labels: LabelHParams = field(default_factory=LabelHParams) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + 
exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + last_reward_latents = output.hidden_states[-1] + # shape: [batch_size, hidden_size] + reward = self.scalar_head(last_reward_latents) + return output, reward + + +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = queries.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + # restore padding tokens + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def get_reward_complete(reward_model, query_responses, tokenizer): + reward = get_reward(reward_model, query_responses, tokenizer)[1] + last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + last_response_indices = torch.max( + last_response_indices, + torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), + ) + return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward + + +def normalize( + tokenizer, + accelerator, + device, + lm_backbone, + reward_model, + dataloader, + validation_dataloader, +): + idx = 0 + with torch.no_grad(): + # reset reward scales + accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) + accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) + # number of minibatches for computing the normalization statistics + rewards = [] + for data in dataloader: + idx += len(data["query_token"]) + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + accelerator.print(f"====number of samples per device: {idx}") + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"mean: {mean}, std: {std}") + + # reward normalization + target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) + gain = target_std / std + bias = target_mean - gain * mean + print(f"gain: {gain}, bias: {bias}") + accelerator.unwrap_model(reward_model).reward_gain.data = gain + accelerator.unwrap_model(reward_model).reward_bias.data = bias + + # validate normalization + rewards = [] + for data in validation_dataloader: + queries = data["query_token"].to(device) + queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) + reference_response = data["reference_response"].to(device) + query_responses = torch.cat((queries, reference_response), dim=1) + score = get_reward_complete(reward_model, query_responses, tokenizer) + rewards.append(score) + rewards = torch.cat(rewards) + rewards = accelerator.gather(rewards) + mean, std = rewards.mean(), rewards.std() + print(f"after mean: {mean}, after std: {std}") + + +def evaluate(args, accelerator, device, reward_model, validation_label): + # reward_model.eval() + with torch.no_grad(): + # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
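+ # note: the loop below scores every candidate summary with the reward model and
+ # measures pairwise accuracy, i.e. how often the higher-scoring summary is the
+ # one the human labeler chose (`mb_best`)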
+ test_accuracies = [] + eval_len = len(validation_label) + len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full + new_all_inds = np.arange(len_labels) + for start in range(0, len_labels, args.batch_size): + end = start + args.batch_size + b_inds_all = new_all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = validation_label[micro_batch_inds] + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) + ] + predicted_reward = [] + for i in range(args.labels.num_labels): + query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + score, _ = get_reward_complete(reward_model, query_responses, args) + predicted_reward.append(score) + predicted_reward = torch.stack( + predicted_reward, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + test_accuracies.append(accuracy) + test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + # reward_model.train() + return test_accuracy + + +def train(args: Args): + accelerator = Accelerator( + kwargs_handlers=[ + DistributedDataParallelKwargs( + broadcast_buffers=False, + # find_unused_parameters=True, + ) + ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + args.world_size = accelerator.num_processes + args.batch_size = int(args.local_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + args.num_updates = args.labels.num_train // args.batch_size + patch_h = TaskQueryHParams( + length=args.task.query_length, + dataset=args.task.query_dataset, + format_str=args.task.query_format_str, + truncate_field=args.task.query_truncate_field, + truncate_text=args.task.query_truncate_text, + padding=args.task.query_padding, + pad_side=args.task.query_pad_side, + ) + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + 
torch.backends.cudnn.deterministic = True + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + args.pad_token_id = tokenizer.pad_token_id + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) + + # freeze the first 70% of layers + if args.trainable_param_percentage < 1.0: + layers = reward_model.lm_backbone.transformer.h + num_layers = len(layers) + num_unfrozen = int(args.trainable_param_percentage * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + if args.sft_model_path: + reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded SFT model from {args.sft_model_path}") + reward_model.lm_backbone.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding + # make sure the `lm_head` or `embed_out` does not require gradients, otherwise + # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 + if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): + reward_model.lm_backbone.embed_out.requires_grad_(False) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(reward_model.parameters(), lr=args.lr, eps=args.eps) + # TODO: use AdamW + scheduler = get_scheduler( + args.scheduler, + optimizer=optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_updates * args.num_epochs, + ) + + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size + + + reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) + if args.normalize_before: + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + + def process_query_data(x): + return { + **process_query(x, encoder=tokenizer, hparams=patch_h), + "reference_response": tokenizer.encode( + f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + # with an extra leading space to account for the space between the query and response + ), + } + + dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) + validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.shuffle(seed=local_seed) + validation_dataloader = DataLoader(validation_dataset, 
batch_size=args.local_rollout_batch_size) + dataloader = accelerator.prepare(dataloader) + iter_dataloader = iter(dataloader) + print("===Normalize reward model *before* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + tokenizer, + accelerator, + device, + reward_model, + reward_model, + dataloader, + validation_dataloader, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` + label = load_dataset(args.label_dataset, "comparisons", split="train") + validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") + dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") + eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") + accelerator.print("Num labels found in source:", len(label)) + accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) + + def process_response_data(x): + return { + **process_query(x["info"], encoder=tokenizer, hparams=patch_h), + "response0_token": tokenizer.encode( + f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + ), + "response1_token": tokenizer.encode( + f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + ), + } + + label = label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + dev_validation_label = dev_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + eval_validation_label = eval_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) + # TODO: check if all labels have eos token + accelerator.print("===training reward model===") + num_train = (args.labels.num_train // args.batch_size) * args.batch_size + for epoch in range(args.num_epochs): + all_inds = np.random.permutation(args.labels.num_train) + # ensure that all processes have the same shuffled indices + all_inds = broadcast(torch.tensor(all_inds, device=device), 0) + all_inds = all_inds.cpu().numpy() + accelerator.print(f"epoch: {epoch}") + for (epoch_global_step, start) in enumerate(range(0, num_train, args.batch_size)): + global_step = epoch * args.num_updates + epoch_global_step + end = start + args.batch_size + b_inds_all = all_inds[start:end] + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing + # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") + if accelerator.is_main_process: pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) + gradient_accumulation_step = 0 + # reward_model.train() + for micro_batch_start in range(0, args.local_batch_size, 
args.local_micro_batch_size): + with accelerator.accumulate(reward_model): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] + mb_data = label[micro_batch_inds] + # pprint({ + # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", + # "micro_batch_inds": micro_batch_inds, + # }) + mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) + mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) + mb_responses = [ + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + predicted_reward, reward = get_reward_complete(reward_model, query_responses, tokenizer) + predicted_reward = predicted_reward.view(-1, len(mb_responses)) # TODO check shape for no gradienta ccumulation steps + + # print(tokenizer.decode(mb_query[0])) + # print(tokenizer.decode(mb_responses[0][0])) + # print(tokenizer.decode(mb_responses[1][0])) + # predicted_reward = [] + # rewards = [] + # for i in range(args.labels.num_labels): + # query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) + # score, reward = get_reward_complete(reward_model, query_responses, tokenizer) + # rewards.append(reward.squeeze(-1)) + # predicted_reward.append(score) + # # shape (batch_size, num_labels), basically a reward prediction for each label + # predicted_reward = torch.stack(predicted_reward, dim=1) + # breakpoint() + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) + reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + if args.logsigmoid: + loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() + else: + loss = F.cross_entropy(predicted_reward, mb_best) + accelerator.backward(loss) + + # for k, v in reward_model.named_parameters(): + # if v.requires_grad: + # if v.grad is None: + # print(f"found unused param: {k}") + + optimizer.step() # accelerate handles gradient accumulation automatically + optimizer.zero_grad() + scheduler.step() + losses[gradient_accumulation_step] = loss + accuracies[gradient_accumulation_step] = accuracy + reward_preferreds[gradient_accumulation_step] = reward_preferred.mean() + reward_rejecteds[gradient_accumulation_step] = reward_rejected.mean() + gradient_accumulation_step += 1 + + train_accuracy = accelerator.gather(accuracies).mean().item() + writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/accuracy", train_accuracy, global_step) + writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) + writer.add_scalar("train/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) + lr = scheduler.get_last_lr() + writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) + accelerator.print("train/accuracy", train_accuracy) + + # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + if global_step == args.num_updates - 1: # first and last update + dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) + writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, 
global_step) + accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) + eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) + writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) + accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) + eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, label) + writer.add_scalar("train_full/accuracy", eval_validation_accuracy, global_step) + accelerator.print("train_full/accuracy", eval_validation_accuracy, global_step) + + torch.cuda.empty_cache() + if args.normalize_after: + print("===Normalize reward model *after* training===") + print( + "before normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + normalize( + tokenizer, + accelerator, + device, + reward_model, + reward_model, + dataloader, + validation_dataloader, + ) + print( + "after normalization. " + + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" + + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" + ) + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) + accelerator.save_model(reward_model, args.save_path) + + if accelerator.is_main_process and args.track: + wandb.finish() + + +if __name__ == "__main__": + args = tyro.cli(Args) + train(args) diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py new file mode 100644 index 0000000..cee1642 --- /dev/null +++ b/lm_human_preference_details/tldr_dataset.py @@ -0,0 +1,128 @@ +from datasets import load_dataset +from dataclasses import dataclass +from typing import Dict, Optional, Union +from transformers import AutoTokenizer +from rich.pretty import pprint + +import numpy as np + + +@dataclass +class TaskQueryHParams: + length: int = 512 + dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = "post" + truncate_text: Optional[str] = "\n" + padding: Optional[Union[str, int]] = 50257 + pad_side: Optional[str] = "left" + + +def _ensure_length(toks, l, pad_sequence=None, pad_side=None, truncate_side=None): + assert pad_side in (None, "left", "right") + assert truncate_side in (None, "left", "right") + if len(toks) < l: + assert pad_sequence is not None + pad_amt = l - len(toks) + assert len(pad_sequence) >= pad_amt, f"{len(pad_sequence)} < {pad_amt}" + if pad_side is None: + assert len(toks) == l, f"Needed to pad! {len(toks)} < {l}" + return toks + elif pad_side == "left": + return pad_sequence[-pad_amt:] + toks + else: + assert pad_side == "right" + return toks + pad_sequence[:pad_amt] + if truncate_side is None: + assert len(toks) == l, f"Needed to truncate! 
{len(toks)} > {l}" + return toks + elif truncate_side == "left": + return toks[-l:] + else: + assert truncate_side == "right" + return toks[:l] + + +def _get_query_padding_for_task(encoder, hparams: TaskQueryHParams): + return hparams.padding * hparams.length + + +def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHParams, pad_sequence=None): + if pad_sequence is None: + pad_sequence = _get_query_padding_for_task(encoder, hparams) + if isinstance(query_info, str): + query_info = dict(query=query_info) + else: + # copy to avoid mutating input + query_info = dict(**query_info) + + format_str = hparams.format_str or "{query}" + query_tokens = encoder.encode(format_str.format(**query_info)) + truncate_field = hparams.truncate_field or "query" + + if truncate_field not in query_info: + raise ValueError(f"Could not truncate field {truncate_field}, found fields: {query_info.keys()}!") + while len(query_tokens) > hparams.length: + if not len(query_info[truncate_field]): + raise ValueError("Could not truncate enough!") + + i = -1 # default to just remove one character + if hparams.truncate_text: + try: + i = query_info[truncate_field].rindex(hparams.truncate_text) + except ValueError: + pass + query_info[truncate_field] = query_info[truncate_field][:i] + query_tokens = encoder.encode(format_str.format(**query_info)) + + query_token = _ensure_length(query_tokens, hparams.length, pad_side=hparams.pad_side, pad_sequence=pad_sequence) + query = encoder.decode(query_token).lstrip() + return dict( + query_token=query_token, + query=query, + ) + + +if __name__ == "__main__": + tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + max_response_length = 48 + oai_h = TaskQueryHParams() + if isinstance(oai_h.padding, str): + oai_h.padding = tokenizer.encode(oai_h.padding) + else: + oai_h.padding = [oai_h.padding] + pprint(oai_h) + dataset = load_dataset(oai_h.dataset) + def process_query_data(x): + # with an extra leading space to account for the space between the query and response + reference_response = f" {x['summary']}<|endoftext|>" + return { + **process_query(x, encoder=tokenizer, hparams=oai_h), + "reference_response": reference_response, + "reference_response_token": tokenizer.encode( + reference_response, padding="max_length", max_length=max_response_length, truncation=True, + ), + } + dataset = dataset.map(process_query_data, load_from_cache_file=False) + push_result = dataset.push_to_hub("vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing") + print(push_result) + + label = load_dataset("openai/summarize_from_feedback", "comparisons") + def process_response_data(x): + # with an extra leading space to account for the space between the query and response + response0 = x['summaries'][0]['text'] + response1 = x['summaries'][1]['text'] + return { + **process_query(x["info"], encoder=tokenizer, hparams=oai_h), + "response0": response0, + "response0_token": tokenizer.encode( + response0, padding="max_length", max_length=max_response_length, truncation=True + ), + "response1": response1, + "response1_token": tokenizer.encode( + response1, padding="max_length", max_length=max_response_length, truncation=True + ), + } + label = label.map(process_response_data, load_from_cache_file=False) + push_result = label.push_to_hub("vwxyzjn/summarize_from_feedback_oai_preprocessing") diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py 
b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 90df20d..71a5dd5 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -3,7 +3,7 @@ import time from dataclasses import asdict, dataclass, field from types import SimpleNamespace -from typing import List, Optional +from typing import List, Literal, Optional import numpy as np import pandas as pd @@ -28,7 +28,6 @@ from torch.utils.tensorboard import SummaryWriter from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -from lm_human_preference_details.data import process_query INVALID_LOGPROB = 1.0 @@ -43,7 +42,7 @@ class RewardHParams: kl_coef: float = 0.15 use_adaptive_kl: bool = True adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward" + trained_model: Optional[str] = "models/gpt2medium_last_index_reward/pytorch_model.bin" label_dataset: tyro.conf.Suppress[Optional[str]] = None @@ -78,7 +77,7 @@ class PpoHParams: class TaskHParams: # Query params query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing" query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" query_truncate_field: Optional[str] = "post" @@ -140,12 +139,12 @@ class Args: """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 1 """How often to print sample output""" - sft_model_path: str = "models/sft_policy" + sft_model_path: str = "" """Where to load the SFT model""" save_path: str = "models/policy.pt" """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" task: TaskHParams = field(default_factory=TaskHParams) rewards: RewardHParams = field(default_factory=RewardHParams) ppo: PpoHParams = field(default_factory=PpoHParams) @@ -361,14 +360,10 @@ def __init__(self, lm_backbone): def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - # last_reward_latents = reward_latents - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - # # shape: [batch_size, 1] - # reward = self.reward_gain * reward + self.reward_bias - return output, reward + latents = output.hidden_states[-1] # shape: [batch_size, length, hidden_size] + scalars = self.scalar_head(latents).squeeze(-1) # shape: [batch_size, length] + last_scalar = scalars[:, -1] # shape: [batch_size, 1] + return scalars, last_scalar # taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 @@ -407,8 +402,7 @@ def generate(lm_backbone, queries, tokenizer, generation_config): """generate in a way that does not affect padding tokens""" context_length = queries.shape[1] attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 + input_ids = torch.masked_fill(queries, ~attention_mask, 0) output = lm_backbone.generate( input_ids=input_ids, attention_mask=attention_mask, @@ -416,7 +410,6 @@ def generate(lm_backbone, queries, tokenizer, 
generation_config): generation_config=generation_config, return_dict_in_generate=True, ) - # restore padding tokens return torch.cat((queries, output.sequences[:, context_length:]), dim=1) @@ -433,21 +426,10 @@ def get_reward(reward_model, query_responses, tokenizer): ) -def get_reward_complete(reward_model, query_responses, tokenizer): - reward = get_reward(reward_model, query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) - - def forward(policy, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return policy( input_ids=input_ids, attention_mask=attention_mask, @@ -457,6 +439,14 @@ def forward(policy, query_responses, tokenizer): ) +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + postprocessed_responses = right_padding_to_left_padding(postprocessed_responses, tokenizer.pad_token_id) + return postprocessed_responses + # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) @@ -466,15 +456,6 @@ def forward(policy, query_responses, tokenizer): args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) if args.ppo.whiten_rewards: assert ( args.ppo.local_mini_batch_size >= 8 @@ -542,28 +523,19 @@ def forward(policy, query_responses, tokenizer): ) policy.generation_config.pad_token_id = None # generate tokens without truncation / padding model = PolicyAndValueWrapper(policy, critic) - if args.use_tensorflow_adam: + if args.optimizer == "tf_adam": optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: + elif args.optimizer == "adam": optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + dataset = load_dataset(args.task.query_dataset, split="train") validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, 
truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) dataset = dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) validation_dataloader = accelerator.prepare(validation_dataloader) @@ -602,15 +574,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well sample_validation_inds = np.arange(args.ppo.batch_size) local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] sample_validation = validation_dataset[local_sample_validation_inds] - sample_validation = {k: v.to(device) for k, v in sample_validation.items()} - sample_validation_queries = sample_validation["query_token"] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) with torch.no_grad(): - print(sample_validation_queries.shape) sample_validation_queries = right_padding_to_left_padding(sample_validation_queries, tokenizer.pad_token_id) - sample_validation_reference_response = sample_validation["reference_response"] + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) sample_validation_query_reference_responses = torch.cat((sample_validation_queries, sample_validation_reference_response), dim=1) - sample_validation_reference_scores = get_reward_complete(reward_model, sample_validation_query_reference_responses, tokenizer) - # breakpoint() + sample_validation_query_reference_responses = right_padding_to_left_padding(sample_validation_query_reference_responses, tokenizer.pad_token_id) + _, sample_validation_reference_scores = get_reward(reward_model, sample_validation_query_reference_responses, tokenizer) iter_dataloader = iter(repeat_generator()) kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) @@ -625,33 +595,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) - # print("===Normalize reward model *before* training===") - # print( - # "before normalization. " - # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - # ) - - # normalize( - # tokenizer, - # accelerator, - # device, - # reward_model, - # reward_model, - # dataloader, - # validation_dataloader, - # ) - # print( - # "after normalization. 
" - # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - # ) - # # # save model - # # if args.save_path: - # # os.makedirs(os.path.dirname("models/correct_reward.pt"), exist_ok=True) - # # torch.save(accelerator.unwrap_model(reward_model).state_dict(), "models/correct_reward.pt") - # raise - print("===training policy===") global_step = 0 stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) @@ -668,21 +611,11 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) with torch.no_grad(): - """ - let's use `P` to denote the padding token, `T` to denote the truncate token, and `X` to denote the - actual tokens. - queries: `PPXXX` - query_responses: `PPXXX,XXXXTXX` # the space separates the query and response - response: `XXXXTXX` - postprocessed_responses: `XXXXTXX` -> `XXXXTPP` - postprocessed_query_responses: `PPXXX,XXXXTPP` - scores: ↑ # corresponding to this `X` token - - """ queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) + reference_responses = data["reference_response_token"].to(device) queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) query_reference_responses = torch.cat((queries, reference_responses), dim=1) + query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id) query_responses = generate( accelerator.unwrap_model(model).policy, queries, @@ -700,29 +633,11 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well generation_config, ) sample_validation_responses = sample_validation_query_responses[:, context_length:] - truncate_token_mask = sample_validation_responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_sample_validation_responses = torch.where( - truncate_mask, - torch.full_like(sample_validation_responses, tokenizer.pad_token_id), - sample_validation_responses, - ) + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) postprocessed_sample_validation_query_responses = torch.cat((sample_validation_queries, postprocessed_sample_validation_responses), 1) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask torch.cuda.empty_cache() - - output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) - full_values = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer)[1] - values = full_values[:, context_length - 1 : -1].squeeze(-1) logits = output.logits[:, context_length - 1 : -1] logits /= (args.task.temperature + 1e-7) all_logprobs = F.log_softmax(logits, dim=-1) @@ -739,58 +654,23 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well torch.cuda.empty_cache() # **Response Processing** - # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - # truncate_token_mask = responses == args.task.truncate_token - # truncate_after_or_token_mask = torch.cat( - # [ - # torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - # truncate_token_mask[:, args.task.truncate_after :], - # ], - # dim=1, - # ) - # truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - # postprocessed_responses = torch.where( - # truncate_mask, - # torch.full_like(responses, tokenizer.pad_token_id), - # responses, - # ) - # del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - - trunc_idxs = first_true_indices(responses == args.task.truncate_token).unsqueeze(-1) - new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] - idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) - postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + postprocessed_responses = truncate_response(args, tokenizer, responses) torch.cuda.empty_cache() # 2. run reward model on the truncated responses postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + postprocessed_query_responses = right_padding_to_left_padding(postprocessed_query_responses, tokenizer.pad_token_id) + full_values, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) padding_mask = postprocessed_responses == tokenizer.pad_token_id - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - values = torch.masked_fill(values, padding_mask, 0) - - scores = get_reward_complete(reward_model, postprocessed_query_responses, tokenizer) - rew = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] - - qr = postprocessed_query_responses - attention_mask = qr != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(qr, ~attention_mask, 0) - output = reward_model.lm_backbone(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, return_dict=True, output_hidden_states=True) - last_reward_latents = output.hidden_states[-1] # TODO: investigate whether it should be output.hidden_states[0] or output.hidden_states[-1] - reward = reward_model.scalar_head(last_reward_latents) - - print(postprocessed_query_responses[0:5,537:]) - print(rew.squeeze(-1)[0:5,537:]) - print(scores) - breakpoint() - - - reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) - # note that we do not truncate the validation responses - validation_score = get_reward_complete(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + # values = torch.masked_fill(values, padding_mask, 0) + + rew, scores = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, reference_scores = get_reward(reward_model, query_reference_responses, tokenizer) + _, 
validation_score = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) # carperAI-style score normaliation accelerator.print("before score", scores, scores.mean()) @@ -893,25 +773,15 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well mb_query_responses = query_responses[micro_batch_inds] mb_logprobs = logprobs[micro_batch_inds] - # output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - output, (_, vpred_temp) = forward(model, mb_query_responses, tokenizer) - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! + output, (vpred_temp, _) = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] logits /= (args.task.temperature + 1e-7) new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1] + # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) vpredclipped = torch.clamp( vpred, mb_values - args.ppo.cliprange_value, @@ -929,7 +799,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well pg_clipfrac = (pg_losses2 > pg_losses).float().mean() loss = pg_loss + args.ppo.vf_coef * vf_loss accelerator.backward(loss) - breakpoint() optimizer.step() optimizer.zero_grad() prob_dist = torch.nn.functional.softmax(logits, dim=-1) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index 443cc3b..cf9b122 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -11,6 +11,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim +from tqdm import tqdm import transformers import tyro from accelerate import Accelerator @@ -96,7 +97,7 @@ class Args: """the name of the pretrained model to use""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - label_dataset: str = "openai/summarize_from_feedback" + label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing" """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" local_batch_size: int = 4 """per rank batch size""" @@ -122,13 +123,13 @@ class Args: """Samples used to estimate reward mean and std across all ranks""" debug_normalize: int = 0 """Samples used to check that normalization worked""" - normalize_before: bool = True + normalize_before: bool = False """Whether, before training, to normalize the rewards on the policy to the scales on 
the training buffer. (For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = True + normalize_after: bool = False """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" print_sample_output_freq: int = 300 """How often to print sample output""" - sft_model_path: str = "models/sft_policy" + sft_model_path: str = "" """Where to load the SFT model""" logsigmoid: bool = True """Whether to use log-sigmoid loss instead of cross-entropy loss""" @@ -146,21 +147,12 @@ class Args: """Which scheduler to use""" warm_up_steps: int = 100 """Number of warm up steps for the scheduler""" + model_dot_train: bool = False + """Whether to call `model.train()`""" task: TaskHParams = field(default_factory=TaskHParams) labels: LabelHParams = field(default_factory=LabelHParams) -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. - """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: table = Table(show_lines=True) @@ -339,10 +331,12 @@ def __init__(self, lm_backbone): def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] + reward_latents = output.hidden_states[-1] + # shape: [batch_size, length, hidden_size] + last_reward_latents = reward_latents[:, -1, :] # shape: [batch_size, hidden_size] reward = self.scalar_head(last_reward_latents) - return output, reward + return reward def right_padding_to_left_padding(tokens, pad_id): @@ -369,8 +363,7 @@ def generate(lm_backbone, queries, tokenizer, generation_config): """generate in a way that does not affect padding tokens""" context_length = queries.shape[1] attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 + input_ids = torch.masked_fill(queries, ~attention_mask, 0) output = lm_backbone.generate( input_ids=input_ids, attention_mask=attention_mask, @@ -395,78 +388,15 @@ def get_reward(reward_model, query_responses, tokenizer): ) -def get_reward_complete(reward_model, query_responses, tokenizer): - reward = get_reward(reward_model, query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward - - -def normalize( - tokenizer, - accelerator, - device, - lm_backbone, - reward_model, - dataloader, - validation_dataloader, -): - idx = 0 - with torch.no_grad(): - # reset reward scales - accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) - accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) - # number of minibatches for computing the normalization statistics - rewards = [] - for data in dataloader: - idx += len(data["query_token"]) - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - 
reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - accelerator.print(f"====number of samples per device: {idx}") - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - accelerator.unwrap_model(reward_model).reward_gain.data = gain - accelerator.unwrap_model(reward_model).reward_bias.data = bias - - # validate normalization - rewards = [] - for data in validation_dataloader: - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def evaluate(args, accelerator, device, reward_model, validation_label): - # reward_model.eval() +def evaluate(args, accelerator, tokenizer, device, reward_model, validation_label): + reward_model.eval() with torch.no_grad(): # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) test_accuracies = [] eval_len = len(validation_label) len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full new_all_inds = np.arange(len_labels) - for start in range(0, len_labels, args.batch_size): + for start in tqdm(range(0, len_labels, args.batch_size)): end = start + args.batch_size b_inds_all = new_all_inds[start:end] b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing @@ -475,24 +405,20 @@ def evaluate(args, accelerator, device, reward_model, validation_label): micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] mb_data = validation_label[micro_batch_inds] mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] - predicted_reward = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - score, _ = get_reward_complete(reward_model, query_responses, args) - predicted_reward.append(score) - predicted_reward = torch.stack( - predicted_reward, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id) + predicted_reward = 
get_reward(reward_model, query_responses, tokenizer) + predicted_reward = predicted_reward.view(-1, len(mb_responses)) accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() test_accuracies.append(accuracy) test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - # reward_model.train() + if args.model_dot_train: + reward_model.train() return test_accuracy @@ -510,16 +436,6 @@ def train(args: Args): args.batch_size = int(args.local_batch_size * args.world_size) args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) - args.num_updates = args.labels.num_train // args.batch_size - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -557,7 +473,6 @@ def train(args: Args): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - args.pad_token_id = tokenizer.pad_token_id reward_model = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) ) @@ -603,50 +518,6 @@ def train(args: Args): reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) - if args.normalize_before: - dataset = load_dataset(args.task.query_dataset, split="train") - validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) - validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) - validation_dataset = validation_dataset.shuffle(seed=local_seed) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) - dataloader = accelerator.prepare(dataloader) - iter_dataloader = iter(dataloader) - print("===Normalize reward model *before* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. 
" - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` label = load_dataset(args.label_dataset, "comparisons", split="train") @@ -655,24 +526,10 @@ def process_query_data(x): eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") accelerator.print("Num labels found in source:", len(label)) accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) - - def process_response_data(x): - return { - **process_query(x["info"], encoder=tokenizer, hparams=patch_h), - "response0_token": tokenizer.encode( - f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True - ), - "response1_token": tokenizer.encode( - f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True - ), - } - - label = label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - dev_validation_label = dev_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - eval_validation_label = eval_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - # TODO: check if all labels have eos token accelerator.print("===training reward model===") num_train = (args.labels.num_train // args.batch_size) * args.batch_size + if args.model_dot_train: + reward_model.train() for epoch in range(args.num_epochs): all_inds = np.random.permutation(args.labels.num_train) # ensure that all processes have the same shuffled indices @@ -684,7 +541,6 @@ def process_response_data(x): end = start + args.batch_size b_inds_all = all_inds[start:end] b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") if accelerator.is_main_process: pprint( { "global_step": global_step, @@ -698,16 +554,11 @@ def process_response_data(x): reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) gradient_accumulation_step = 0 - # reward_model.train() for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): with accelerator.accumulate(reward_model): micro_batch_end = micro_batch_start + args.local_micro_batch_size micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] mb_data = label[micro_batch_inds] - # pprint({ - # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", - # "micro_batch_inds": micro_batch_inds, - # }) mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ @@ -715,22 +566,9 @@ def process_response_data(x): ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) - predicted_reward, reward = get_reward_complete(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, len(mb_responses)) # TODO check shape for no gradienta ccumulation steps - - # print(tokenizer.decode(mb_query[0])) - # print(tokenizer.decode(mb_responses[0][0])) - # 
print(tokenizer.decode(mb_responses[1][0])) - # predicted_reward = [] - # rewards = [] - # for i in range(args.labels.num_labels): - # query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - # score, reward = get_reward_complete(reward_model, query_responses, tokenizer) - # rewards.append(reward.squeeze(-1)) - # predicted_reward.append(score) - # # shape (batch_size, num_labels), basically a reward prediction for each label - # predicted_reward = torch.stack(predicted_reward, dim=1) - # breakpoint() + query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id) + predicted_reward = get_reward(reward_model, query_responses, tokenizer) + predicted_reward = predicted_reward.view(-1, len(mb_responses)) accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) @@ -739,12 +577,10 @@ def process_response_data(x): else: loss = F.cross_entropy(predicted_reward, mb_best) accelerator.backward(loss) - # for k, v in reward_model.named_parameters(): # if v.requires_grad: # if v.grad is None: # print(f"found unused param: {k}") - optimizer.step() # accelerate handles gradient accumulation automatically optimizer.zero_grad() scheduler.step() @@ -765,39 +601,18 @@ def process_response_data(x): # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: if global_step == args.num_updates - 1: # first and last update - dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) + # if global_step == 1: + dev_validation_accuracy = evaluate(args, accelerator, tokenizer, device, reward_model, dev_validation_label) writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) + eval_validation_accuracy = evaluate(args, accelerator, tokenizer, device, reward_model, eval_validation_label) writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, label) + eval_validation_accuracy = evaluate(args, accelerator, tokenizer, device, reward_model, label) writer.add_scalar("train_full/accuracy", eval_validation_accuracy, global_step) accelerator.print("train_full/accuracy", eval_validation_accuracy, global_step) torch.cuda.empty_cache() - if args.normalize_after: - print("===Normalize reward model *after* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. 
" - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) # save model if args.save_path: From b6d4984d447bdfc997b6a801d1eafba88a9f45d1 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 28 Oct 2023 22:55:25 +0000 Subject: [PATCH 18/62] push changes --- ...in_policy_accelerate_summarize_separate.py | 51 +++--- .../train_reward_accelerate_summarize.py | 24 +-- .../train_sft_accelerate_summarize.py | 159 ++++++++---------- 3 files changed, 114 insertions(+), 120 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 71a5dd5..d284d66 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -137,14 +137,14 @@ class Args: """the name of the pretrained model to use""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 1 + print_sample_output_freq: int = 10 """How often to print sample output""" - sft_model_path: str = "" - """Where to load the SFT model""" - save_path: str = "models/policy.pt" + save_path: str = "models/ppo_policy" """Where to save the model""" optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" task: TaskHParams = field(default_factory=TaskHParams) rewards: RewardHParams = field(default_factory=RewardHParams) ppo: PpoHParams = field(default_factory=PpoHParams) @@ -378,7 +378,7 @@ def forward(self, **kwargs): return self.policy(**kwargs), self.critic(**kwargs) -def right_padding_to_left_padding(tokens, pad_id): +def shift_pad_id_left(tokens, pad_id): """Convert from right padding to left padding.""" assert tokens.ndim == 2 return torch.tensor( @@ -386,6 +386,16 @@ def right_padding_to_left_padding(tokens, pad_id): device=tokens.device, ) +def shift_pad_id_left(data, pad_id): + # Step 1: Create a boolean mask + mask = (data == pad_id).long() + # Step 3: Use argsort on the inverted boolean mask to get sorted indices + sorted_indices = torch.argsort(~mask, axis=1) + # Step 4: Use advanced indexing to rearrange the elements + rows_range = torch.arange(data.shape[0], device=data.device) + shifted_data = data[rows_range[:, None], sorted_indices] + return shifted_data + def ceil_div(a, b): return (a - 1) // b + 1 @@ -444,7 +454,6 @@ def truncate_response(args, tokenizer, responses): new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) - postprocessed_responses = right_padding_to_left_padding(postprocessed_responses, tokenizer.pad_token_id) return postprocessed_responses # def train(args: Args): @@ -576,10 +585,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well sample_validation = validation_dataset[local_sample_validation_inds] sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) with torch.no_grad(): - sample_validation_queries = right_padding_to_left_padding(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) 
sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) sample_validation_query_reference_responses = torch.cat((sample_validation_queries, sample_validation_reference_response), dim=1) - sample_validation_query_reference_responses = right_padding_to_left_padding(sample_validation_query_reference_responses, tokenizer.pad_token_id) + sample_validation_query_reference_responses = shift_pad_id_left(sample_validation_query_reference_responses, tokenizer.pad_token_id) _, sample_validation_reference_scores = get_reward(reward_model, sample_validation_query_reference_responses, tokenizer) iter_dataloader = iter(repeat_generator()) @@ -613,9 +622,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well with torch.no_grad(): queries = data["query_token"].to(device) reference_responses = data["reference_response_token"].to(device) - queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) + queries = shift_pad_id_left(queries, tokenizer.pad_token_id) query_reference_responses = torch.cat((queries, reference_responses), dim=1) - query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id) + query_reference_responses= shift_pad_id_left(query_reference_responses, tokenizer.pad_token_id) query_responses = generate( accelerator.unwrap_model(model).policy, queries, @@ -635,8 +644,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well sample_validation_responses = sample_validation_query_responses[:, context_length:] postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) postprocessed_sample_validation_query_responses = torch.cat((sample_validation_queries, postprocessed_sample_validation_responses), 1) + postprocessed_sample_validation_query_responses = shift_pad_id_left(postprocessed_sample_validation_query_responses, tokenizer.pad_token_id) torch.cuda.empty_cache() + # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] logits /= (args.task.temperature + 1e-7) @@ -659,7 +670,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # 2. 
run reward model on the truncated responses postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding(postprocessed_query_responses, tokenizer.pad_token_id) + postprocessed_query_responses = shift_pad_id_left(postprocessed_query_responses, tokenizer.pad_token_id) full_values, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) padding_mask = postprocessed_responses == tokenizer.pad_token_id @@ -673,10 +684,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well _, validation_score = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) # carperAI-style score normaliation - accelerator.print("before score", scores, scores.mean()) - accelerator.print("reference_scores", reference_scores, reference_scores.mean()) + # accelerator.print("before score", scores, scores.mean()) + # accelerator.print("reference_scores", reference_scores, reference_scores.mean()) scores = scores - reference_scores - accelerator.print("after score", scores, scores.mean()) + # accelerator.print("after score", scores, scores.mean()) # 3. filter response. Ensure that the sample contains truncate_token # responses not passing that filter will receive a low (fixed) score @@ -697,16 +708,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: - all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries) - all_sample_validation_query_responses = tokenizer.batch_decode( - sample_validation_query_responses + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode( + postprocessed_sample_validation_responses ) all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( - postprocessed_sample_validation_query_responses + postprocessed_sample_validation_query_responses, skip_special_tokens=True ) - all_sample_validation_responses = [ - x[len(y) :] for x, y in zip(all_sample_validation_query_responses, all_decode_validation_queries) - ] all_sample_validation_postprocessed_responses = [ x[len(y) :] for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) ] @@ -731,7 +739,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well print(e) del ( all_decode_validation_queries, - all_sample_validation_query_responses, all_sample_validation_responses, all_sample_validation_reference_responses, all_sample_validation_df, diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index cf9b122..afb18d9 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -46,7 +46,7 @@ class LabelHParams: class TaskHParams: # Query params query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing" query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" query_truncate_field: Optional[str] = "post" @@ -339,13 +339,15 @@ def 
forward(self, **kwargs): return reward -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) +def shift_pad_id_left(data, pad_id): + # Step 1: Create a boolean mask + mask = (data == pad_id).long() + # Step 3: Use argsort on the inverted boolean mask to get sorted indices + sorted_indices = torch.argsort(~mask, axis=1) + # Step 4: Use advanced indexing to rearrange the elements + rows_range = torch.arange(data.shape[0], device=data.device) + shifted_data = data[rows_range[:, None], sorted_indices] + return shifted_data def ceil_div(a, b): @@ -371,7 +373,6 @@ def generate(lm_backbone, queries, tokenizer, generation_config): generation_config=generation_config, return_dict_in_generate=True, ) - # restore padding tokens return torch.cat((queries, output.sequences[:, context_length:]), dim=1) @@ -411,7 +412,7 @@ def evaluate(args, accelerator, tokenizer, device, reward_model, validation_labe ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) - query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id) + query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, len(mb_responses)) accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() @@ -436,6 +437,7 @@ def train(args: Args): args.batch_size = int(args.local_batch_size * args.world_size) args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + args.num_updates = args.labels.num_train // args.batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -566,7 +568,7 @@ def train(args: Args): ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) - query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id) + query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, len(mb_responses)) accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 0d6988f..cf16d90 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -26,9 +26,7 @@ ) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler - -from lm_human_preference_details.data import process_query +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler @dataclass @@ -38,7 +36,6 @@ class SFTHParams: noptepochs: int = 1 lr: float = 6.35e-5 eps: float = 1e-5 - lm_loss_on_response_only: bool = False 
total_episodes: tyro.conf.Suppress[int] = None local_batch_size:tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None @@ -52,7 +49,7 @@ class SFTHParams: class TaskHParams: # Query params query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing" query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" query_truncate_field: Optional[str] = "post" @@ -126,6 +123,16 @@ class Args: sft: SFTHParams = field(default_factory=SFTHParams) +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout): + if dropout is not None: + for key in ('dropout', 'attention_dropout', 'hidden_dropout', + 'activation_dropout'): + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: table = Table(show_lines=True) for column in df.columns: @@ -284,13 +291,15 @@ def step(self, closure=None): return loss -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) +def shift_pad_id_left(data, pad_id): + # Step 1: Create a boolean mask + mask = (data == pad_id).long() + # Step 3: Use argsort on the inverted boolean mask to get sorted indices + sorted_indices = torch.argsort(~mask, axis=1) + # Step 4: Use advanced indexing to rearrange the elements + rows_range = torch.arange(data.shape[0], device=data.device) + shifted_data = data[rows_range[:, None], sorted_indices] + return shifted_data def ceil_div(a, b): @@ -308,8 +317,7 @@ def generate(lm_backbone, queries, tokenizer, generation_config): """generate in a way that does not affect padding tokens""" context_length = queries.shape[1] attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 + input_ids = torch.masked_fill(queries, ~attention_mask, 0) output = lm_backbone.generate( input_ids=input_ids, attention_mask=attention_mask, @@ -317,7 +325,6 @@ def generate(lm_backbone, queries, tokenizer, generation_config): generation_config=generation_config, return_dict_in_generate=True, ) - # restore padding tokens return torch.cat((queries, output.sequences[:, context_length:]), dim=1) @@ -340,19 +347,10 @@ def forward(policy, query_responses, tokenizer): args.sft.world_size = accelerator.num_processes args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) dataset = load_dataset(args.task.query_dataset, split="train") - test_dataset = load_dataset(args.task.query_dataset, split="test") + validation_dataset = load_dataset(args.task.query_dataset, 
split="validation") accelerator.print("The number of samples in dataset", len(dataset)) - accelerator.print("The number of samples in test_dataset", len(test_dataset)) + accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) args.sft.total_episodes = len(dataset) args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size @@ -392,7 +390,9 @@ def forward(policy, query_responses, tokenizer): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, 0.0) # disable dropout + policy = AutoConfig, AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True, config=model_config) policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically @@ -411,31 +411,19 @@ def forward(policy, query_responses, tokenizer): num_training_steps=args.sft.num_updates // args.sft.gradient_accumulation_steps, ) - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) dataset = dataset.shuffle(seed=local_seed) - test_dataset = test_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) - test_dataset = test_dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) - test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader, test_dataloader, scheduler = accelerator.prepare(policy, optimizer, dataloader, test_dataloader, scheduler) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.sft.local_micro_batch_size) + policy, optimizer, dataloader, validation_dataloader, scheduler = accelerator.prepare(policy, optimizer, dataloader, validation_dataloader, scheduler) iter_dataloader = iter(dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, - temperature=args.task.temperature, + temperature=(args.task.temperature + 1e-7), top_k=0.0, top_p=1.0, do_sample=True, @@ -451,17 +439,14 @@ def process_query_data(x): global_step += args.sft.batch_size accelerator.print(f"update {update}, global_step {global_step}") data = next(iter_dataloader) - reference_responses = data["reference_response"].to(device, non_blocking=True) + reference_responses = data["reference_response_token"].to(device, non_blocking=True) queries = data["query_token"].to(device, non_blocking=True) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) query_responses = torch.cat((queries, reference_responses), dim=1) + query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) with accelerator.accumulate(policy): output = forward(policy, query_responses, tokenizer) # mask out gradient effects on response padding tokens labels = query_responses.masked_fill(query_responses == tokenizer.pad_token_id, -1) - if args.sft.lm_loss_on_response_only: - # mask out gradient effects on query tokens - labels[:, :queries.shape[1]] = -1 lm_logits = output.logits # hand-rolled transformer loss: Shift so that tokens < n predict n # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` @@ -480,59 +465,59 @@ def process_query_data(x): if update == 1 or update == args.sft.num_updates - 1: policy.eval() rouge_scores = collections.defaultdict(list) - all_decode_test_queries = [] - all_decode_test_query_responses = [] - all_decode_test_responses = [] - all_decode_test_reference_responses = [] - all_test_losses = [] - for test_idx, test_data in enumerate(test_dataloader): + all_decode_validation_queries = [] + all_decode_validation_query_responses = [] + all_decode_validation_responses = [] + all_decode_validation_reference_responses = [] + all_validation_losses = [] + for validation_idx, validation_data in enumerate(validation_dataloader): with torch.no_grad(): - test_reference_responses = test_data["reference_response"].to(device, non_blocking=True) - test_queries = test_data["query_token"].to(device, non_blocking=True) - test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) - test_query_reference_responses = torch.cat((test_queries, test_reference_responses), dim=1) + validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) + validation_queries = validation_data["query_token"].to(device, non_blocking=True) + validation_queries = shift_pad_id_left(validation_queries, tokenizer.pad_token_id) + validation_query_reference_responses = torch.cat((validation_queries, validation_reference_responses), dim=1) - test_output = forward(policy, test_query_reference_responses, tokenizer) - test_labels = test_query_reference_responses.masked_fill(test_query_reference_responses == tokenizer.pad_token_id, -1) + validation_output = forward(policy, validation_query_reference_responses, tokenizer) + validation_labels = validation_query_reference_responses.masked_fill(validation_query_reference_responses == tokenizer.pad_token_id, -1) if args.sft.lm_loss_on_response_only: - test_labels[:, :queries.shape[1]] = -1 - test_lm_logits = test_output.logits + validation_labels[:, :queries.shape[1]] = -1 + validation_lm_logits = validation_output.logits # hand-rolled 
transformer loss: Shift so that tokens < n predict n # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` - test_shift_logits = test_lm_logits[..., :-1, :].contiguous() - test_shift_labels = test_labels[..., 1:].contiguous() - test_loss = F.cross_entropy(test_shift_logits.view(-1, test_shift_logits.size(-1)), test_shift_labels.view(-1), ignore_index=-1) - test_loss = accelerator.gather(test_loss) - all_test_losses.append(test_loss) - - generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) - decode_test_queries = tokenizer.batch_decode(accelerator.gather(test_queries)) - decode_test_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) - decode_test_reference_responses = tokenizer.batch_decode( - accelerator.gather(test_reference_responses) + validation_shift_logits = validation_lm_logits[..., :-1, :].contiguous() + validation_shift_labels = validation_labels[..., 1:].contiguous() + validation_loss = F.cross_entropy(validation_shift_logits.view(-1, validation_shift_logits.size(-1)), validation_shift_labels.view(-1), ignore_index=-1) + validation_loss = accelerator.gather(validation_loss) + all_validation_losses.append(validation_loss) + + generated_responses = generate(accelerator.unwrap_model(policy), validation_queries, tokenizer, generation_config) + decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries)) + decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) + decode_validation_reference_responses = tokenizer.batch_decode( + accelerator.gather(validation_reference_responses) ) - decode_test_responses = [ - x[len(y) :] for x, y in zip(decode_test_query_responses, decode_test_queries) + decode_validation_responses = [ + x[len(y) :] for x, y in zip(decode_validation_query_responses, decode_validation_queries) ] - rouge_score = rouge.compute(predictions=decode_test_responses, references=decode_test_reference_responses) + rouge_score = rouge.compute(predictions=decode_validation_responses, references=decode_validation_reference_responses) rouge_scores["rouge1"].append(rouge_score["rouge1"]) rouge_scores["rouge2"].append(rouge_score["rouge2"]) rouge_scores["rougeL"].append(rouge_score["rougeL"]) - all_decode_test_queries.extend(decode_test_queries) - accelerator.print("len(all_decode_test_queries)", len(all_decode_test_queries), decode_test_responses) - all_decode_test_query_responses.extend(decode_test_query_responses) - all_decode_test_responses.extend(decode_test_responses) - all_decode_test_reference_responses.extend(decode_test_reference_responses) - if test_idx == 10: + all_decode_validation_queries.extend(decode_validation_queries) + accelerator.print("len(all_decode_validation_queries)", len(all_decode_validation_queries), decode_validation_responses) + all_decode_validation_query_responses.extend(decode_validation_query_responses) + all_decode_validation_responses.extend(decode_validation_responses) + all_decode_validation_reference_responses.extend(decode_validation_reference_responses) + if validation_idx == 10: break try: all_df = pd.DataFrame( { - "query": all_decode_test_queries, - "response": all_decode_test_responses, - "reference": all_decode_test_reference_responses, + "query": all_decode_validation_queries, + "response": all_decode_validation_responses, + "reference": all_decode_validation_reference_responses, } ) accelerator.print(all_df) @@ -547,7 +532,7 @@ def process_query_data(x): 
rouge_metric = accelerator.gather(rouge_metric) writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") - writer.add_scalar("test_loss", torch.stack(all_test_losses).mean().item(), update) + writer.add_scalar("validation_loss", torch.stack(all_validation_losses).mean().item(), update) policy.train() # save model From 3b7cbb7d8951860536de3dfa4f25ec1e6416761d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 28 Oct 2023 18:58:06 -0400 Subject: [PATCH 19/62] pre-commit --- .../summarization/minimal_rm copy.py | 7 +- .../summarization/minimal_rm.py | 6 +- .../summarization/minisft.py | 26 ++-- .../train_policy_accelerate copy 2.py | 2 +- .../train_policy_accelerate copy.py | 44 ++++--- .../train_policy_accelerate_new.py | 28 ++--- .../train_policy_accelerate_old.py | 24 ++-- ...in_policy_accelerate_summarize_ref_diff.py | 15 ++- .../train_reward_accelerate copy.py | 8 +- .../train_reward_accelerate_debug copy.py | 94 ++++++++------ .../train_reward_accelerate_debug.py | 115 +++++++++++------- ...train_reward_accelerate_summarize_debug.py | 16 ++- .../train_reward_accelerate_summarized.py | 51 ++++---- .../train_reward_accelerate_summarizew.py | 60 +++++---- .../train_sft_accelerate_summarize copy.py | 26 ++-- ...train_sft_accelerate_summarize_executor.py | 29 +++-- .../train_policy_accelerate_summarize.py | 6 +- ...in_policy_accelerate_summarize_separate.py | 76 ++++++------ .../train_reward_accelerate_summarize.py | 59 +++++---- lm_human_preference_details/tldr_dataset.py | 24 ++-- ...in_policy_accelerate_summarize_separate.py | 69 ++++++----- .../train_reward_accelerate_summarize.py | 41 +++---- .../train_sft_accelerate_summarize.py | 53 +++++--- 23 files changed, 508 insertions(+), 371 deletions(-) diff --git a/lm_human_preference_details/summarization/minimal_rm copy.py b/lm_human_preference_details/summarization/minimal_rm copy.py index 7049f0f..0b6ae67 100644 --- a/lm_human_preference_details/summarization/minimal_rm copy.py +++ b/lm_human_preference_details/summarization/minimal_rm copy.py @@ -1,7 +1,8 @@ -import numpy as np import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer + + class AutoModelForCausalLMWithRewardHead(nn.Module): def __init__(self, lm_backbone): super().__init__() @@ -14,6 +15,8 @@ def forward(self, **kwargs): # shape: [batch_size, hidden_size] reward = self.scalar_head(last_reward_latents) return output, reward + + base_model = "gpt2" tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left") reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(base_model)) @@ -22,4 +25,4 @@ def forward(self, **kwargs): mb_query_tiled = mb_query.unsqueeze(1).repeat(1, mb_responses.shape[1], 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) _, score = reward_model(input_ids=query_responses, return_dict=True, output_hidden_states=True) -print(score.squeeze(2)) \ No newline at end of file +print(score.squeeze(2)) diff --git a/lm_human_preference_details/summarization/minimal_rm.py b/lm_human_preference_details/summarization/minimal_rm.py index 0cb4179..1c993d0 100644 --- a/lm_human_preference_details/summarization/minimal_rm.py +++ b/lm_human_preference_details/summarization/minimal_rm.py @@ -1,4 +1,3 @@ -import numpy as np import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer @@ -31,15 +30,16 @@ def 
get_reward(reward_model, query_responses, tokenizer): output_hidden_states=True, ) + base_model = "gpt2" tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left") tokenizer.add_special_tokens({"pad_token": "[PAD]"}) reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(base_model)) reward_model.train() mb_query = torch.randint(0, len(tokenizer), (1, 10)) -mb_query[:,0:4] = tokenizer.pad_token_id +mb_query[:, 0:4] = tokenizer.pad_token_id mb_responses = torch.randint(0, len(tokenizer), (1, 2, 10)) mb_query_tiled = mb_query.unsqueeze(1).repeat(1, mb_responses.shape[1], 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) _, score_all = get_reward(reward_model, query_responses, tokenizer) -print(score_all.squeeze(2)) \ No newline at end of file +print(score_all.squeeze(2)) diff --git a/lm_human_preference_details/summarization/minisft.py b/lm_human_preference_details/summarization/minisft.py index fede737..85b7cbd 100644 --- a/lm_human_preference_details/summarization/minisft.py +++ b/lm_human_preference_details/summarization/minisft.py @@ -1,32 +1,23 @@ -import collections import os import random import time from dataclasses import asdict, dataclass, field from types import SimpleNamespace -from typing import List, Optional +from typing import Optional import numpy as np -import pandas as pd import torch import torch.optim as optim -from torch.nn import functional as F import tyro -import evaluate from accelerate import Accelerator from datasets import load_dataset from rich.console import Console from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) +from torch import optim +from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers import AutoModelForCausalLM, AutoTokenizer from lm_human_preference_details.data import process_query @@ -40,7 +31,7 @@ class SFTHParams: eps: float = 1e-5 lm_loss_on_response_only: bool = False total_episodes: tyro.conf.Suppress[int] = None - local_batch_size:tyro.conf.Suppress[int] = None + local_batch_size: tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None mini_batch_size: tyro.conf.Suppress[int] = None world_size: tyro.conf.Suppress[int] = None @@ -228,7 +219,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -282,7 +276,7 @@ def process_query_data(x): labels = query_responses.masked_fill(query_responses == tokenizer.pad_token_id, -1) if args.sft.lm_loss_on_response_only: # mask out gradient effects on query tokens - labels[:, :queries.shape[1]] = -1 + labels[:, : queries.shape[1]] = -1 lm_logits = output.logits # hand-rolled transformer loss: Shift so that tokens < n predict n # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` diff --git a/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py 
b/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py index b77f275..1b5943d 100644 --- a/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py +++ b/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py @@ -660,7 +660,7 @@ def train(args: Args): # 5. whiten rewards if args.ppo.whiten_rewards: rewards = whiten(rewards, shift_mean=False) - + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate copy.py b/lm_human_preference_details/summarization/train_policy_accelerate copy.py index a975666..e9e6d84 100644 --- a/lm_human_preference_details/summarization/train_policy_accelerate copy.py +++ b/lm_human_preference_details/summarization/train_policy_accelerate copy.py @@ -487,12 +487,16 @@ def train(args: Args): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) if args.rewards.trained_model: reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) print(f"loaded pretrained reward model from {args.rewards.trained_model}") # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + ref_policy = AutoModelForCausalLMWithScalarHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) policy.lm_backbone.generation_config.eos_token_id = ( None # disable `pad_token_id` and `eos_token_id` because we just want to @@ -518,12 +522,10 @@ def train(args: Args): import deepspeed deepspeed_states = AcceleratorState().deepspeed_plugin - deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size - deepspeed_states.deepspeed_config['checkpoint'] = {'use_node_local_storage': True} - off_load_device = "cpu" - stage = 3 + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'], + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], "steps_per_print": 10, # "zero_optimization": { # "stage": stage, @@ -532,11 +534,9 @@ def train(args: Args): # "device": off_load_device # } # }, - "bf16": { - "enabled": True - }, + "bf16": {"enabled": True}, "prescale_gradients": False, - "wall_clock_breakdown": False + "wall_clock_breakdown": False, } reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) reward_model.eval() @@ -632,7 +632,7 @@ def train(args: Args): logits /= args.task.temperature all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) 
- + output4, _ = forward(policy, query_responses, tokenizer) logits4 = output4.logits[:, context_length - 1 : -1] logits4 /= args.task.temperature @@ -750,16 +750,12 @@ def train(args: Args): with accelerator.accumulate(policy): - - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) vpredclipped = torch.clamp( vpred, @@ -771,12 +767,14 @@ def train(args: Args): vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() logprobs_diff = new_logprobs - mb_logprobs - pprint({ - "new_logprobs": new_logprobs, - "new_logprobs2": new_logprobs2, - "mb_logprobs": mb_logprobs, - "mb_logprobs2": logprobs4[micro_batch_inds], - }) + pprint( + { + "new_logprobs": new_logprobs, + "new_logprobs2": new_logprobs2, + "mb_logprobs": mb_logprobs, + "mb_logprobs2": logprobs4[micro_batch_inds], + } + ) ratio = torch.exp(logprobs_diff) print(ratio.mean()) breakpoint() @@ -815,7 +813,7 @@ def train(args: Args): ) with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` writer.add_histogram("ppo/val/ratio_hist", ratio, update) kl = logprobs - ref_logprobs mean_kl = kl.sum(1).mean() diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_new.py b/lm_human_preference_details/summarization/train_policy_accelerate_new.py index e9f296d..3187431 100644 --- a/lm_human_preference_details/summarization/train_policy_accelerate_new.py +++ b/lm_human_preference_details/summarization/train_policy_accelerate_new.py @@ -487,12 +487,16 @@ def train(args: Args): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) if args.rewards.trained_model: reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) print(f"loaded pretrained reward model from {args.rewards.trained_model}") # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + ref_policy = AutoModelForCausalLMWithScalarHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) policy.lm_backbone.generation_config.eos_token_id = ( None # disable `pad_token_id` and `eos_token_id` because we just want to @@ -518,12 +522,10 @@ def train(args: Args): import deepspeed deepspeed_states = AcceleratorState().deepspeed_plugin - deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size - deepspeed_states.deepspeed_config['checkpoint'] = {'use_node_local_storage': True} - off_load_device = "cpu" - stage = 3 + 
deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'], + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], "steps_per_print": 10, # "zero_optimization": { # "stage": stage, @@ -532,11 +534,9 @@ def train(args: Args): # "device": off_load_device # } # }, - "bf16": { - "enabled": True - }, + "bf16": {"enabled": True}, "prescale_gradients": False, - "wall_clock_breakdown": False + "wall_clock_breakdown": False, } reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) reward_model.eval() @@ -745,8 +745,8 @@ def train(args: Args): mb_responses = responses[micro_batch_inds] mb_query_responses = query_responses[micro_batch_inds] - # re-calculate logprobs and values for the first epoch, otherwise `bf16` will cause the logprobs to - # be much different because the logprobs are with a batch size of `local_batch_size` but the + # re-calculate logprobs and values for the first epoch, otherwise `bf16` will cause the logprobs to + # be much different because the logprobs are with a batch size of `local_batch_size` but the # `new_logprobs` are with a batch size of `local_micro_batch_size` if ppo_epoch_idx == 0: with torch.no_grad(): @@ -820,7 +820,7 @@ def train(args: Args): breakpoint() with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` writer.add_histogram("ppo/val/ratio_hist", ratio, update) kl = logprobs - ref_logprobs mean_kl = kl.sum(1).mean() diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_old.py b/lm_human_preference_details/summarization/train_policy_accelerate_old.py index de27920..1c9fc5a 100644 --- a/lm_human_preference_details/summarization/train_policy_accelerate_old.py +++ b/lm_human_preference_details/summarization/train_policy_accelerate_old.py @@ -487,12 +487,16 @@ def train(args: Args): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) if args.rewards.trained_model: reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) print(f"loaded pretrained reward model from {args.rewards.trained_model}") # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) + ref_policy = AutoModelForCausalLMWithScalarHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ) policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) policy.lm_backbone.generation_config.eos_token_id = ( None # disable `pad_token_id` and `eos_token_id` because we just want to @@ -518,12 +522,10 @@ def train(args: Args): import deepspeed deepspeed_states = AcceleratorState().deepspeed_plugin 
- deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size - deepspeed_states.deepspeed_config['checkpoint'] = {'use_node_local_storage': True} - off_load_device = "cpu" - stage = 3 + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'], + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], "steps_per_print": 10, # "zero_optimization": { # "stage": stage, @@ -532,11 +534,9 @@ def train(args: Args): # "device": off_load_device # } # }, - "bf16": { - "enabled": True - }, + "bf16": {"enabled": True}, "prescale_gradients": False, - "wall_clock_breakdown": False + "wall_clock_breakdown": False, } reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) reward_model.eval() @@ -790,7 +790,7 @@ def train(args: Args): ) with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` writer.add_histogram("ppo/val/ratio_hist", ratio, update) kl = logprobs - ref_logprobs mean_kl = kl.sum(1).mean() diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py b/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py index ee56755..50aca9e 100644 --- a/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py +++ b/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py @@ -536,7 +536,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -611,7 +614,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well reference_responses = data["reference_response"].to(device) query_reference_responses = torch.cat((queries, reference_responses), dim=1) queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) - query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to(device) + query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to( + device + ) query_responses = generate( accelerator.unwrap_model(policy).lm_backbone, queries, @@ -624,7 +629,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, full_values = forward(policy, query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -632,7 +637,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output, _ = 
forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -777,7 +782,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(policy, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate copy.py b/lm_human_preference_details/summarization/train_reward_accelerate copy.py index 11e26d0..aae124c 100644 --- a/lm_human_preference_details/summarization/train_reward_accelerate copy.py +++ b/lm_human_preference_details/summarization/train_reward_accelerate copy.py @@ -493,8 +493,12 @@ def train(args: Args): # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) args.pad_token_id = tokenizer.pad_token_id - untrained_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)).to(device) - reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)).to(device) + untrained_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ).to(device) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ).to(device) untrained_model.lm_backbone.generation_config.eos_token_id = ( None # disable `pad_token_id` and `eos_token_id` because we just want to ) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py b/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py index 5113045..91e7a56 100644 --- a/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py +++ b/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py @@ -5,16 +5,16 @@ from types import SimpleNamespace from typing import Optional -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs, broadcast import numpy as np import torch import torch.nn as nn -import torch.optim as optim import torch.nn.functional as F +import torch.optim as optim import tyro -from rich.console import Console +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs, broadcast from datasets import load_dataset +from rich.console import Console from rich.pretty import pprint from torch.utils.data import DataLoader, IterableDataset from torch.utils.tensorboard import SummaryWriter @@ -22,6 +22,7 @@ from lm_human_preference_details.datamod import DATASET + @dataclass class LabelHParams: type: str = None @@ -50,7 +51,7 @@ class TaskHParams: @dataclass class Args: # common args - exp_name: str = os.path.basename(__file__)[:-len(".py")] + exp_name: str = os.path.basename(__file__)[: 
-len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" @@ -160,7 +161,9 @@ def forward(self, **kwargs): # a pytorch dataset class MyDataset(IterableDataset): - def __init__(self, generator, tokenizer, query_length, start_text=None, end_text=None, query_prefix="", query_suffix="", seed=None): + def __init__( + self, generator, tokenizer, query_length, start_text=None, end_text=None, query_prefix="", query_suffix="", seed=None + ): self.generator = generator self.tokenizer = tokenizer self.query_length = query_length @@ -175,7 +178,6 @@ def __init__(self, generator, tokenizer, query_length, start_text=None, end_text self.query_prefix_tokens = torch.LongTensor(tokenizer.encode(query_prefix)) self.query_suffix_tokens = torch.LongTensor(tokenizer.encode(query_suffix)) - def __iter__(self): for text in self.generator("train", self.seed, shuffle=True): tokens = self.tokenizer.encode(text) @@ -206,10 +208,7 @@ def __iter__(self): def left_padding_to_right_padding(query, pad_id): # got to convert to right padding, otherwise `transformers` has weird issues # even with `position_ids` - return torch.tensor([ - [pad_id]*(row==pad_id).sum() + [x for x in row if x != pad_id] - for row in query - ]) + return torch.tensor([[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query]) def ceil_div(a, b): @@ -221,7 +220,7 @@ def generate(pretrained_model, queries, tokenizer, generation_config): context_length = queries.shape[1] attention_mask = queries != tokenizer.pad_token_id input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 + input_ids[~attention_mask] = 0 # set padding tokens to 0 output = pretrained_model.generate( input_ids=input_ids, attention_mask=attention_mask, @@ -229,13 +228,13 @@ def generate(pretrained_model, queries, tokenizer, generation_config): generation_config=generation_config, return_dict_in_generate=True, ) - # restore padding tokens + # restore padding tokens return torch.cat((queries, output.sequences[:, context_length:]), dim=1) def get_reward(reward_model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = query_responses.clone() input_ids[~attention_mask] = 0 return reward_model( @@ -246,6 +245,7 @@ def get_reward(reward_model, query_responses, tokenizer): output_hidden_states=True, ) + def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_model, iter_dataloader, generation_config): with torch.no_grad(): # reset reward scales @@ -261,13 +261,13 @@ def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_mod queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) query_responses = generate(pretrained_model, queries, tokenizer, generation_config) sample_queries_responses.append(query_responses) - + # compute reward statistics rewards = [] for query_responses in sample_queries_responses: rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) rewards = torch.cat(rewards) - rewards= accelerator.gather(rewards) + rewards = accelerator.gather(rewards) mean, std = rewards.mean(), rewards.std() print(f"mean: {mean}, std: {std}") @@ -292,7 +292,7 @@ def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_mod for query_responses in sample_queries_responses: 
rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) rewards = torch.cat(rewards) - rewards= accelerator.gather(rewards) + rewards = accelerator.gather(rewards) mean, std = rewards.mean(), rewards.std() print(f"after mean: {mean}, after std: {std}") @@ -301,14 +301,16 @@ def train(args: Args): args.task.query_prefix = args.task.query_prefix.replace("\\n", "\n") args.task.query_suffix = args.task.query_suffix.replace("\\n", "\n") accelerator = Accelerator( - kwargs_handlers=[DistributedDataParallelKwargs(broadcast_buffers=False)] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + kwargs_handlers=[ + DistributedDataParallelKwargs(broadcast_buffers=False) + ] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 ) args.world_size = accelerator.num_processes args.batch_size = int(args.local_batch_size * args.world_size) console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer + writer = SimpleNamespace() # dummy writer writer.add_scalar = lambda x, y, z: None if accelerator.is_main_process: if args.track: @@ -343,7 +345,9 @@ def train(args: Args): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) untrained_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model)).to(device) reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model)).to(device) - reward_model.pretrained_model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + reward_model.pretrained_model.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) reward_model.pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) dataset = MyDataset( @@ -379,7 +383,16 @@ def train(args: Args): print("before====", reward_model.module.reward_gain.data) if args.normalize_before: - normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config) + normalize( + args, + accelerator, + device, + tokenizer, + accelerator.unwrap_model(reward_model).pretrained_model, + reward_model, + iter_dataloader, + generation_config, + ) print("after====", reward_model.module.reward_gain.data) print("===training reward model===") @@ -393,7 +406,7 @@ def train(args: Args): global_step += 1 end = start + args.batch_size b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing lr = (1 - start / args.labels.num_train) * args.lr optimizer.param_groups[0]["lr"] = lr mb_data = label[b_inds] @@ -402,10 +415,7 @@ def train(args: Args): print("mb_query.shape", mb_query.shape) mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) - for i in range(args.labels.num_labels) - ] + mb_responses = [torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels)] # hack: deal with openai's padding token # 
assert (mb_query == tokenizer.pad_token_id).sum() == 0 mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id @@ -417,9 +427,7 @@ def train(args: Args): for i in range(args.labels.num_labels): query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append( - reward.squeeze() - ) + predicted_rewards.append(reward.squeeze()) predicted_rewards = torch.stack( predicted_rewards, dim=1 ) # shape (batch_size, num_labels), basically a reward prediction for each label @@ -438,17 +446,19 @@ def train(args: Args): queries = data["input_ids"].to(device) context_length = queries.shape[1] queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate(accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config) + query_responses = generate( + accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config + ) responses = query_responses[:, context_length:] output, reward = get_reward(reward_model, query_responses, tokenizer) - logits = output.logits[:,context_length-1:-1] + logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) output, _ = get_reward(untrained_model, query_responses, tokenizer) - logits = output.logits[:,context_length-1:-1] + logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature all_logprobs = F.log_softmax(logits, dim=-1) ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) @@ -470,15 +480,14 @@ def train(args: Args): for start in range(args.labels.num_train, len(label), args.batch_size): end = start + args.batch_size b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing mb_data = label[b_inds] # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) mb_query = torch.from_numpy(np.stack(mb_data["query"])) mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) - for i in range(args.labels.num_labels) + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) ] # hack: deal with openai's padding token # assert (mb_query == tokenizer.pad_token_id).sum() == 0 @@ -495,9 +504,7 @@ def train(args: Args): print(tokenizer.decode(mb_responses[i], skip_special_tokens=True)) breakpoint() reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append( - reward.squeeze() - ) + predicted_rewards.append(reward.squeeze()) predicted_rewards = torch.stack( predicted_rewards, dim=1 ) # shape (batch_size, num_labels), basically a reward prediction for each label @@ -510,7 +517,16 @@ def train(args: Args): torch.cuda.empty_cache() if args.normalize_after: - normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config) + normalize( + args, + accelerator, + device, + tokenizer, + accelerator.unwrap_model(reward_model).pretrained_model, + 
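Both the policy and reference passes in this hunk follow the same pattern: slice the logits that correspond to the response tokens, divide by the sampling temperature, and gather the log-probability of each token that was actually sampled. A small self-contained sketch of that step with made-up shapes:

    import torch
    import torch.nn.functional as F

    temperature = 0.7
    logits = torch.randn(2, 5, 100)            # (batch, response_length, vocab)
    responses = torch.randint(0, 100, (2, 5))  # sampled token ids

    logits = logits / temperature              # some hunks add a small epsilon to guard against temperature == 0
    all_logprobs = F.log_softmax(logits, dim=-1)
    # log-probability of the token actually sampled at each position
    logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
    print(logprobs.shape)                      # torch.Size([2, 5])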
reward_model, + iter_dataloader, + generation_config, + ) # save model if args.save_path: diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_debug.py b/lm_human_preference_details/summarization/train_reward_accelerate_debug.py index e4811b1..9a8a4ec 100644 --- a/lm_human_preference_details/summarization/train_reward_accelerate_debug.py +++ b/lm_human_preference_details/summarization/train_reward_accelerate_debug.py @@ -5,16 +5,16 @@ from types import SimpleNamespace from typing import Optional -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs, broadcast import numpy as np import torch import torch.nn as nn -import torch.optim as optim import torch.nn.functional as F +import torch.optim as optim import tyro -from rich.console import Console +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs, broadcast from datasets import load_dataset +from rich.console import Console from rich.pretty import pprint from torch.utils.data import DataLoader, IterableDataset from torch.utils.tensorboard import SummaryWriter @@ -22,6 +22,7 @@ from lm_human_preference_details.datamod import DATASET + @dataclass class LabelHParams: type: str = None @@ -50,7 +51,7 @@ class TaskHParams: @dataclass class Args: # common args - exp_name: str = os.path.basename(__file__)[:-len(".py")] + exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" @@ -201,10 +202,7 @@ def __iter__(self): def left_padding_to_right_padding(query, pad_id): # got to convert to right padding, otherwise `transformers` has weird issues # even with `position_ids` - return torch.tensor([ - [pad_id]*(row==pad_id).sum() + [x for x in row if x != pad_id] - for row in query - ]) + return torch.tensor([[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query]) def ceil_div(a, b): @@ -216,7 +214,7 @@ def generate(pretrained_model, queries, tokenizer, generation_config): context_length = queries.shape[1] attention_mask = queries != tokenizer.pad_token_id input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 + input_ids[~attention_mask] = 0 # set padding tokens to 0 output = pretrained_model.generate( input_ids=input_ids, attention_mask=attention_mask, @@ -224,13 +222,13 @@ def generate(pretrained_model, queries, tokenizer, generation_config): generation_config=generation_config, return_dict_in_generate=True, ) - # restore padding tokens + # restore padding tokens return torch.cat((queries, output.sequences[:, context_length:]), dim=1) def get_reward(reward_model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = query_responses.clone() input_ids[~attention_mask] = 0 return reward_model( @@ -241,7 +239,19 @@ def get_reward(reward_model, query_responses, tokenizer): output_hidden_states=True, ) -def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_model, iter_dataloader, generation_config, query_prefix_tokens, query_suffix_tokens): + +def normalize( + args, + accelerator, + device, + tokenizer, + pretrained_model, + reward_model, + iter_dataloader, + generation_config, + query_prefix_tokens, + query_suffix_tokens, +): with torch.no_grad(): # reset reward scales 
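get_reward in these files derives the attention mask from the pad token, computes position ids as an exclusive cumulative sum of that mask, and zeroes the pad positions in input_ids before calling the model. A standalone sketch of just that preprocessing step (the pad id is chosen arbitrarily):

    import torch

    pad_token_id = 50257
    query_responses = torch.tensor([[pad_token_id, pad_token_id, 11, 12, 13]])

    attention_mask = query_responses != pad_token_id
    # exclusive cumsum: the first real token gets position 0, pad positions also stay at 0
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = query_responses.clone()
    input_ids[~attention_mask] = 0  # harmless, since the attention mask hides these positions anyway

    print(position_ids)  # tensor([[0, 0, 0, 1, 2]])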
reward_model.module.reward_gain.data.fill_(1.0) @@ -257,13 +267,13 @@ def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_mod queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) query_responses = generate(pretrained_model, queries, tokenizer, generation_config) sample_queries_responses.append(query_responses) - + # compute reward statistics rewards = [] for query_responses in sample_queries_responses: rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) rewards = torch.cat(rewards) - rewards= accelerator.gather(rewards) + rewards = accelerator.gather(rewards) mean, std = rewards.mean(), rewards.std() print(f"mean: {mean}, std: {std}") @@ -289,7 +299,7 @@ def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_mod for query_responses in sample_queries_responses: rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) rewards = torch.cat(rewards) - rewards= accelerator.gather(rewards) + rewards = accelerator.gather(rewards) mean, std = rewards.mean(), rewards.std() print(f"after mean: {mean}, after std: {std}") @@ -304,14 +314,16 @@ def train(args: Args): args.task.query_prefix = args.task.query_prefix.replace("\\n", "\n") args.task.query_suffix = args.task.query_suffix.replace("\\n", "\n") accelerator = Accelerator( - kwargs_handlers=[DistributedDataParallelKwargs(broadcast_buffers=False)] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 + kwargs_handlers=[ + DistributedDataParallelKwargs(broadcast_buffers=False) + ] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 ) args.world_size = accelerator.num_processes args.batch_size = int(args.local_batch_size * args.world_size) console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer + writer = SimpleNamespace() # dummy writer writer.add_scalar = lambda x, y, z: None if accelerator.is_main_process: if args.track: @@ -347,9 +359,15 @@ def train(args: Args): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) query_prefix_tokens = torch.LongTensor(tokenizer.encode(args.task.query_prefix)) query_suffix_tokens = torch.LongTensor(tokenizer.encode(args.task.query_suffix)) - untrained_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True)).to(device) - reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True)).to(device) - reward_model.pretrained_model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + untrained_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True) + ).to(device) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True) + ).to(device) + reward_model.pretrained_model.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) reward_model.pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) dataset = MyDataset( @@ -382,7 +400,18 @@ def train(args: Args): print("before====", reward_model.module.reward_gain.data) if args.normalize_before: - 
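normalize resets reward_gain to 1, samples rollouts, and measures the mean and standard deviation of the raw rewards before and after adjusting the head. The actual gain/bias update is not visible in this hunk, so the following is only a hedged sketch of the usual affine rescaling toward a target mean and standard deviation; solve_gain_bias and the target values are placeholders:

    import torch

    def solve_gain_bias(rewards: torch.Tensor, target_mean: float = 0.0, target_std: float = 1.0):
        # choose gain, bias so that gain * reward + bias has the target statistics
        mean, std = rewards.mean(), rewards.std()
        gain = target_std / std
        bias = target_mean - gain * mean
        return gain, bias

    rewards = torch.randn(1024) * 3.2 + 1.5   # stand-in for gathered rollout rewards
    gain, bias = solve_gain_bias(rewards)
    normalized = gain * rewards + bias
    print(normalized.mean().item(), normalized.std().item())  # roughly 0.0 and 1.0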
normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config, query_prefix_tokens, query_suffix_tokens) + normalize( + args, + accelerator, + device, + tokenizer, + accelerator.unwrap_model(reward_model).pretrained_model, + reward_model, + iter_dataloader, + generation_config, + query_prefix_tokens, + query_suffix_tokens, + ) print("after====", reward_model.module.reward_gain.data) print("===training reward model===") @@ -396,7 +425,7 @@ def train(args: Args): global_step += 1 end = start + args.batch_size b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing lr = (1 - start / args.labels.num_train) * args.lr optimizer.param_groups[0]["lr"] = lr mb_data = label[b_inds] @@ -405,10 +434,7 @@ def train(args: Args): mb_query = format_query(query_prefix_tokens, mb_query, query_suffix_tokens) mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) - for i in range(args.labels.num_labels) - ] + mb_responses = [torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels)] # hack: deal with openai's padding token # assert (mb_query == tokenizer.pad_token_id).sum() == 0 mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id @@ -420,9 +446,7 @@ def train(args: Args): for i in range(args.labels.num_labels): query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append( - reward.squeeze() - ) + predicted_rewards.append(reward.squeeze()) predicted_rewards = torch.stack( predicted_rewards, dim=1 ) # shape (batch_size, num_labels), basically a reward prediction for each label @@ -442,18 +466,19 @@ def train(args: Args): queries = format_query(query_prefix_tokens, queries, query_suffix_tokens) context_length = queries.shape[1] queries = left_padding_to_right_padding(queries, tokenizer.pad_token_id).to(device) - query_responses = generate(accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config) + query_responses = generate( + accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config + ) responses = query_responses[:, context_length:] - output, reward = get_reward(reward_model, query_responses, tokenizer) - logits = output.logits[:,context_length-1:-1] + logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) output, _ = get_reward(untrained_model, query_responses, tokenizer) - logits = output.logits[:,context_length-1:-1] + logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature all_logprobs = F.log_softmax(logits, dim=-1) ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) @@ -475,7 +500,7 @@ def train(args: Args): for start in range(args.labels.num_train, len(label), args.batch_size): end = start + args.batch_size b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index::accelerator.num_processes] # multi-GPU slicing + b_inds = 
b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing mb_data = label[b_inds] # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) mb_query = torch.from_numpy(np.stack(mb_data["query"])) @@ -483,8 +508,7 @@ def train(args: Args): mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) - for i in range(args.labels.num_labels) + torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) ] # hack: deal with openai's padding token # assert (mb_query == tokenizer.pad_token_id).sum() == 0 @@ -497,9 +521,7 @@ def train(args: Args): for i in range(args.labels.num_labels): query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append( - reward.squeeze() - ) + predicted_rewards.append(reward.squeeze()) predicted_rewards = torch.stack( predicted_rewards, dim=1 ) # shape (batch_size, num_labels), basically a reward prediction for each label @@ -512,7 +534,18 @@ def train(args: Args): torch.cuda.empty_cache() if args.normalize_after: - normalize(args, accelerator, device, tokenizer, accelerator.unwrap_model(reward_model).pretrained_model, reward_model, iter_dataloader, generation_config, query_prefix_tokens, query_suffix_tokens) + normalize( + args, + accelerator, + device, + tokenizer, + accelerator.unwrap_model(reward_model).pretrained_model, + reward_model, + iter_dataloader, + generation_config, + query_prefix_tokens, + query_suffix_tokens, + ) # save model if args.save_path: diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py index e52f3ee..c04927b 100644 --- a/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py +++ b/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py @@ -591,7 +591,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -674,7 +677,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # torch.save(accelerator.unwrap_model(reward_model).state_dict(), "models/correct_reward.pt") raise - print("===training policy===") global_step = 0 stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) @@ -695,7 +697,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well reference_responses = data["reference_response"].to(device) query_reference_responses = torch.cat((queries, reference_responses), dim=1) queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) - query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to(device) + query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to( + device + ) query_responses = generate( 
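The training loops here shuffle the label indices once, broadcast them so every process sees the same order, and then give each process the strided slice b_inds_all[process_index::num_processes]. A tiny sketch of how that slicing shards one global batch across workers (the sizes are illustrative):

    import numpy as np

    batch_size, num_processes = 8, 2
    all_inds = np.random.permutation(32)   # one shared shuffle of the whole label set
    b_inds_all = all_inds[0:batch_size]    # this step's global batch

    shards = [b_inds_all[rank::num_processes] for rank in range(num_processes)]
    for rank, shard in enumerate(shards):
        print(f"rank {rank} handles {len(shard)} examples:", shard)
    # every index in b_inds_all lands in exactly one shard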
accelerator.unwrap_model(policy).lm_backbone, queries, @@ -708,7 +712,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, full_values = forward(policy, query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -716,7 +720,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output, _ = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -763,7 +767,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well torch.zeros([1], dtype=last_reference_response_indices.dtype, device=query_reference_responses.device), ) reference_scores = reference_scores[:, :, 0].gather(1, last_reference_response_indices.unsqueeze(1)).view(-1) - + print(reference_scores.mean()) # normalization again scores = scores - reference_scores diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py index e9c49f8..4623a57 100644 --- a/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py +++ b/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py @@ -14,7 +14,6 @@ import transformers import tyro from accelerate import Accelerator -from accelerate.state import AcceleratorState from accelerate.utils import DistributedDataParallelKwargs, broadcast from datasets import load_dataset from rich.console import Console @@ -470,17 +469,16 @@ def evaluate(args, accelerator, device, reward_model, validation_label): mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] predicted_rewards = [] for i in range(args.labels.num_labels): query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) score, _ = get_reward_complete(reward_model, query_responses, args) predicted_rewards.append(score) predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label + predicted_rewards, dim=1 + ) # shape (batch_size, num_labels), basically a reward prediction for each label accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() test_accuracies.append(accuracy) test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() @@ -587,7 +585,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + f" 
{x['summary']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -603,7 +604,7 @@ def process_query_data(x): validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) reward_model, optimizer, dataloader, scheduler = accelerator.prepare(reward_model, optimizer, dataloader, scheduler) - iter_dataloader = iter(dataloader) + iter(dataloader) generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, @@ -648,10 +649,16 @@ def process_response_data(x): return { **process_query(x["info"], encoder=tokenizer, hparams=patch_h), "response0_token": tokenizer.encode( - f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][0]['text']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, ), "response1_token": tokenizer.encode( - f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][1]['text']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, ), } @@ -665,7 +672,7 @@ def process_response_data(x): # ensure that all processes have the same shuffled indices all_inds = broadcast(torch.tensor(all_inds, device=device), 0) all_inds = all_inds.cpu().numpy() - + for (global_step, start) in enumerate(range(0, args.labels.num_train, args.batch_size)): # # linear rate annealing # lr = (1 - start / args.labels.num_train) * args.lr @@ -675,14 +682,15 @@ def process_response_data(x): b_inds_all = all_inds[start:end] b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") - if accelerator.is_main_process: pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) + if accelerator.is_main_process: + pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) losses = torch.zeros((args.gradient_accumulation_steps,), device=device) accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) gradient_accumulation_step = 0 @@ -702,7 +710,7 @@ def process_response_data(x): torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten(0, 1) predicted_rewards, score_all = get_reward_complete(reward_model, query_responses, tokenizer) breakpoint() @@ -730,7 +738,7 @@ def process_response_data(x): accelerator.print("train/accuracy", train_accuracy) # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == num_updates - 1: # first and last update + if global_step == num_updates - 1: # first and last update dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) writer.add_scalar("dev_validation/accuracy", 
dev_validation_accuracy, global_step) accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) @@ -772,7 +780,6 @@ def process_response_data(x): wandb.finish() - if __name__ == "__main__": args = tyro.cli(Args) train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py index cfbd58a..a199c85 100644 --- a/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py +++ b/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py @@ -9,7 +9,6 @@ import pandas as pd import torch import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim import transformers import tyro @@ -28,7 +27,7 @@ ) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler +from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler from lm_human_preference_details.data import process_query @@ -478,9 +477,8 @@ def evaluate(args, accelerator, device, reward_model, validation_label): mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] predicted_reward = [] rewards = [] for i in range(args.labels.num_labels): @@ -600,12 +598,11 @@ def train(args: Args): ) if args.deepspeed: - import deepspeed + pass deepspeed_states = AcceleratorState().deepspeed_plugin deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) if args.normalize_before: dataset = load_dataset(args.task.query_dataset, split="train") @@ -615,7 +612,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -629,7 +629,7 @@ def process_query_data(x): validation_dataset = validation_dataset.shuffle(seed=local_seed) validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) dataloader = accelerator.prepare(dataloader) - iter_dataloader = iter(dataloader) + iter(dataloader) print("===Normalize reward model *before* training===") print( "before normalization. 
" @@ -664,10 +664,16 @@ def process_response_data(x): return { **process_query(x["info"], encoder=tokenizer, hparams=patch_h), "response0_token": tokenizer.encode( - f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][0]['text']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, ), "response1_token": tokenizer.encode( - f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][1]['text']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, ), } @@ -689,14 +695,15 @@ def process_response_data(x): b_inds_all = all_inds[start:end] b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") - if accelerator.is_main_process: pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) + if accelerator.is_main_process: + pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) losses = torch.zeros((args.gradient_accumulation_steps,), device=device) accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) @@ -715,13 +722,18 @@ def process_response_data(x): mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten( + 0, 1 + ) predicted_reward, reward = get_reward_complete(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, len(mb_responses)) # TODO check shape for no gradienta ccumulation steps - + predicted_reward = predicted_reward.view( + -1, len(mb_responses) + ) # TODO check shape for no gradienta ccumulation steps + # print(tokenizer.decode(mb_query[0])) # print(tokenizer.decode(mb_responses[0][0])) # print(tokenizer.decode(mb_responses[1][0])) @@ -774,7 +786,7 @@ def process_response_data(x): accelerator.print("train/accuracy", train_accuracy) # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == args.num_updates - 1: # first and last update + if global_step == args.num_updates - 1: # first and last update dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) diff --git a/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py b/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py index 
31fb58b..0ba4cb8 100644 --- a/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py +++ b/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py @@ -6,12 +6,12 @@ from types import SimpleNamespace from typing import List, Optional +import evaluate import numpy as np import pandas as pd import torch import torch.optim as optim import tyro -import evaluate from accelerate import Accelerator from datasets import load_dataset from rich.console import Console @@ -38,7 +38,7 @@ class SFTHParams: lr: float = 6.35e-5 eps: float = 1e-5 total_episodes: tyro.conf.Suppress[int] = None - local_batch_size:tyro.conf.Suppress[int] = None + local_batch_size: tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None mini_batch_size: tyro.conf.Suppress[int] = None world_size: tyro.conf.Suppress[int] = None @@ -395,11 +395,14 @@ def train(args: Args): optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) def process_query_data(x): - pad_summary_w_leading_space = " " + x['summary'] + pad_summary_w_leading_space = " " + x["summary"] return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - pad_summary_w_leading_space, padding="max_length", max_length=args.task.response_length, truncation=True, + pad_summary_w_leading_space, + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -468,9 +471,11 @@ def process_query_data(x): test_queries = test_data["query_token"].to(device) test_reference_responses = test_data["reference_response"].to(device) test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) - generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) + generated_responses = generate( + accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config + ) accelerator.print(update, test_idx) - + all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) all_decode_test_reference_responses = tokenizer.batch_decode( @@ -479,7 +484,9 @@ def process_query_data(x): all_decode_test_responses = [ x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) ] - rouge_score = rouge.compute(predictions=all_decode_test_responses, references=all_decode_test_reference_responses) + rouge_score = rouge.compute( + predictions=all_decode_test_responses, references=all_decode_test_reference_responses + ) rouge_scores["rouge1"].append(rouge_score["rouge1"]) rouge_scores["rouge2"].append(rouge_score["rouge2"]) rouge_scores["rougeL"].append(rouge_score["rougeL"]) @@ -498,7 +505,7 @@ def process_query_data(x): print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) except Exception as e: print(e) - + for k, v in rouge_scores.items(): rouge_metric = torch.tensor(v, device=device) rouge_metric = accelerator.gather(rouge_metric) @@ -516,6 +523,7 @@ def process_query_data(x): policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) tokenizer.save_pretrained(repo_id, push_to_hub=True) + if __name__ == "__main__": args = tyro.cli(Args) - train(args) \ No newline at end of file + train(args) diff --git a/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py 
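The SFT evaluation decodes the generated query+response sequences, strips the decoded query prefix from each one, and scores the remaining text against the decoded reference summaries with ROUGE. A minimal sketch of that post-processing using the evaluate library (the strings are toy data):

    import evaluate

    rouge = evaluate.load("rouge")

    decoded_queries = ["SUBREDDIT: r/test\n\nPOST: something happened\n\nTL;DR:"]
    decoded_query_responses = [decoded_queries[0] + " a short generated summary"]
    decoded_references = [" a short reference summary"]

    # drop the query prefix so only the generated continuation is scored
    decoded_responses = [qr[len(q):] for qr, q in zip(decoded_query_responses, decoded_queries)]
    scores = rouge.compute(predictions=decoded_responses, references=decoded_references)
    print(scores["rouge1"], scores["rougeL"])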
b/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py index 618a1e9..5f9b4a2 100644 --- a/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py +++ b/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py @@ -2,16 +2,17 @@ import os import random import time +from concurrent.futures import ProcessPoolExecutor from dataclasses import asdict, dataclass, field from types import SimpleNamespace from typing import List, Optional +import evaluate import numpy as np import pandas as pd import torch import torch.optim as optim import tyro -import evaluate from accelerate import Accelerator from datasets import load_dataset from rich.console import Console @@ -28,7 +29,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from lm_human_preference_details.data import process_query -from concurrent.futures import ProcessPoolExecutor @dataclass @@ -39,7 +39,7 @@ class SFTHParams: lr: float = 6.35e-5 eps: float = 1e-5 total_episodes: tyro.conf.Suppress[int] = None - local_batch_size:tyro.conf.Suppress[int] = None + local_batch_size: tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None mini_batch_size: tyro.conf.Suppress[int] = None world_size: tyro.conf.Suppress[int] = None @@ -138,12 +138,8 @@ def calculate_rouge( tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) - all_decode_test_reference_responses = tokenizer.batch_decode( - test_reference_responses, skip_special_tokens=True - ) - all_decode_test_responses = [ - x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) - ] + all_decode_test_reference_responses = tokenizer.batch_decode(test_reference_responses, skip_special_tokens=True) + all_decode_test_responses = [x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries)] rouge = evaluate.load("rouge") return rouge.compute(predictions=predictions, references=references) @@ -417,7 +413,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}", padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -442,7 +441,7 @@ def process_query_data(x): top_p=1.0, do_sample=True, ) - executor = ProcessPoolExecutor() + ProcessPoolExecutor() # rouge = evaluate.load("rouge") print("===training policy===") @@ -486,11 +485,10 @@ def process_query_data(x): for test_idx, test_data in enumerate(test_dataloader): with torch.no_grad(): test_queries = test_data["query_token"].to(device) - test_reference_responses = test_data["reference_response"] + test_data["reference_response"] # test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) - generated_responses = generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) + generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) accelerator.print(update, test_idx) - # futures.append( # executor.submit( @@ -516,7 +514,7 @@ def process_query_data(x): # except 
Exception as e: # print(e) - rouge_scores = [f.result() for f in futures] # list of dicts + rouge_scores = [f.result() for f in futures] # list of dicts rouge_scores = {k: np.mean([x[k] for x in rouge_scores]) for k in rouge_scores[0].keys()} for k, v in rouge_scores.items(): rouge_metric = torch.tensor(v, device=device) @@ -535,6 +533,7 @@ def process_query_data(x): policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) tokenizer.save_pretrained(repo_id, push_to_hub=True) + if __name__ == "__main__": args = tyro.cli(Args) train(args) diff --git a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py index 0ffcd00..ce44ca9 100644 --- a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py +++ b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py @@ -617,7 +617,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, full_values = forward(policy, query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -625,7 +625,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output, _ = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -759,7 +759,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(policy, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) diff --git a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py index 90df20d..e2293d3 100644 --- a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py @@ -32,6 +32,7 @@ INVALID_LOGPROB = 1.0 + @dataclass class AdaptiveKLParams: target: float = 6.0 @@ -378,10 +379,10 @@ def __init__(self, policy, critic) -> None: super().__init__() self.policy = policy self.critic = critic - + def forward(self, **kwargs): return self.policy(**kwargs), self.critic(**kwargs) - + def right_padding_to_left_padding(tokens, pad_id): """Convert from right padding to left padding.""" @@ -523,9 +524,7 @@ def forward(policy, query_responses, tokenizer): reward_model = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) ) - critic = AutoModelForCausalLMWithRewardHead( - 
AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) + critic = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) if args.rewards.trained_model: reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) @@ -537,9 +536,7 @@ def forward(policy, query_responses, tokenizer): policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) print(f"loaded pretrained policy from {args.sft_model_path}") - policy.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding model = PolicyAndValueWrapper(policy, critic) if args.use_tensorflow_adam: @@ -553,7 +550,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -608,8 +608,12 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well print(sample_validation_queries.shape) sample_validation_queries = right_padding_to_left_padding(sample_validation_queries, tokenizer.pad_token_id) sample_validation_reference_response = sample_validation["reference_response"] - sample_validation_query_reference_responses = torch.cat((sample_validation_queries, sample_validation_reference_response), dim=1) - sample_validation_reference_scores = get_reward_complete(reward_model, sample_validation_query_reference_responses, tokenizer) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + sample_validation_reference_scores = get_reward_complete( + reward_model, sample_validation_query_reference_responses, tokenizer + ) # breakpoint() iter_dataloader = iter(repeat_generator()) @@ -677,7 +681,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well postprocessed_responses: `XXXXTXX` -> `XXXXTPP` postprocessed_query_responses: `PPXXX,XXXXTPP` scores: ↑ # corresponding to this `X` token - + """ queries = data["query_token"].to(device) reference_responses = data["reference_response"].to(device) @@ -714,17 +718,17 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well torch.full_like(sample_validation_responses, tokenizer.pad_token_id), sample_validation_responses, ) - postprocessed_sample_validation_query_responses = torch.cat((sample_validation_queries, postprocessed_sample_validation_responses), 1) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) del truncate_token_mask, truncate_after_or_token_mask, truncate_mask torch.cuda.empty_cache() - - output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) full_values = 
get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer)[1] values = full_values[:, context_length - 1 : -1].squeeze(-1) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -732,7 +736,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -778,20 +782,27 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well attention_mask = qr != tokenizer.pad_token_id position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(qr, ~attention_mask, 0) - output = reward_model.lm_backbone(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, return_dict=True, output_hidden_states=True) - last_reward_latents = output.hidden_states[-1] # TODO: investigate whether it should be output.hidden_states[0] or output.hidden_states[-1] + output = reward_model.lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + last_reward_latents = output.hidden_states[ + -1 + ] # TODO: investigate whether it should be output.hidden_states[0] or output.hidden_states[-1] reward = reward_model.scalar_head(last_reward_latents) - print(postprocessed_query_responses[0:5,537:]) - print(rew.squeeze(-1)[0:5,537:]) + print(postprocessed_query_responses[0:5, 537:]) + print(rew.squeeze(-1)[0:5, 537:]) print(scores) breakpoint() - reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) # note that we do not truncate the validation responses validation_score = get_reward_complete(reward_model, postprocessed_sample_validation_query_responses, tokenizer) - + # carperAI-style score normaliation accelerator.print("before score", scores, scores.mean()) accelerator.print("reference_scores", reference_scores, reference_scores.mean()) @@ -818,9 +829,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries) - all_sample_validation_query_responses = tokenizer.batch_decode( - sample_validation_query_responses - ) + all_sample_validation_query_responses = tokenizer.batch_decode(sample_validation_query_responses) all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( postprocessed_sample_validation_query_responses ) @@ -828,11 +837,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well x[len(y) :] for x, y in zip(all_sample_validation_query_responses, all_decode_validation_queries) ] all_sample_validation_postprocessed_responses = [ - x[len(y) :] for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + x[len(y) :] + for x, y in 
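The debugging block above runs the reward model's backbone directly, takes the final hidden state, and applies the scalar head, which yields one scalar per token; the per-sequence score used elsewhere in these files is the scalar gathered at the index of the last response token. A standalone sketch of pulling a per-sequence score out of per-token scalars at the last non-pad position, assuming right padding (tensor names are illustrative):

    import torch

    pad_token_id = 50257
    per_token_rewards = torch.randn(2, 6, 1)   # stand-in for scalar_head(hidden_states[-1])
    query_responses = torch.tensor([
        [11, 12, 13, 14, pad_token_id, pad_token_id],
        [21, 22, 23, 24, 25, 26],
    ])

    non_pad = query_responses != pad_token_id
    last_indices = non_pad.long().sum(1) - 1   # last non-pad index per row (valid for right padding)
    scores = per_token_rewards[:, :, 0].gather(1, last_indices.unsqueeze(1)).view(-1)
    print(last_indices)   # tensor([3, 5])
    print(scores.shape)   # torch.Size([2])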
zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) ] - all_sample_validation_reference_responses = tokenizer.batch_decode( - sample_validation_reference_response - ) + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) all_sample_validation_df = pd.DataFrame( { "query": all_decode_validation_queries, @@ -846,7 +854,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process and args.track: wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) print_rich_table("stuff", all_sample_validation_df[:4], console) - + except Exception as e: print(e) del ( @@ -904,9 +912,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - + logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -1017,5 +1025,5 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well tokenizer.save_pretrained(repo_id, push_to_hub=True) # if __name__ == "__main__": -# args = tyro.cli(Args) +# args = tyro.cli(Args) # train(args) diff --git a/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py b/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py index 443cc3b..9ff6abe 100644 --- a/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py @@ -28,7 +28,7 @@ ) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler +from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler from lm_human_preference_details.data import process_query @@ -478,9 +478,8 @@ def evaluate(args, accelerator, device, reward_model, validation_label): mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + ] predicted_reward = [] for i in range(args.labels.num_labels): query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) @@ -596,12 +595,11 @@ def train(args: Args): ) if args.deepspeed: - import deepspeed + pass deepspeed_states = AcceleratorState().deepspeed_plugin deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) if args.normalize_before: dataset = load_dataset(args.task.query_dataset, split="train") @@ -611,7 +609,10 @@ def process_query_data(x): return { **process_query(x, encoder=tokenizer, hparams=patch_h), "reference_response": tokenizer.encode( - f" 
{x['summary']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True, + f" {x['summary']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, # with an extra leading space to account for the space between the query and response ), } @@ -625,7 +626,7 @@ def process_query_data(x): validation_dataset = validation_dataset.shuffle(seed=local_seed) validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) dataloader = accelerator.prepare(dataloader) - iter_dataloader = iter(dataloader) + iter(dataloader) print("===Normalize reward model *before* training===") print( "before normalization. " @@ -660,10 +661,16 @@ def process_response_data(x): return { **process_query(x["info"], encoder=tokenizer, hparams=patch_h), "response0_token": tokenizer.encode( - f" {x['summaries'][0]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][0]['text']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, ), "response1_token": tokenizer.encode( - f" {x['summaries'][1]['text']}<|endoftext|>", padding="max_length", max_length=args.task.response_length, truncation=True + f" {x['summaries'][1]['text']}<|endoftext|>", + padding="max_length", + max_length=args.task.response_length, + truncation=True, ), } @@ -685,14 +692,15 @@ def process_response_data(x): b_inds_all = all_inds[start:end] b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") - if accelerator.is_main_process: pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) + if accelerator.is_main_process: + pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) losses = torch.zeros((args.gradient_accumulation_steps,), device=device) accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) @@ -711,13 +719,18 @@ def process_response_data(x): mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten( + 0, 1 + ) predicted_reward, reward = get_reward_complete(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, len(mb_responses)) # TODO check shape for no gradienta ccumulation steps - + predicted_reward = predicted_reward.view( + -1, len(mb_responses) + ) # TODO check shape for no gradienta ccumulation steps + # print(tokenizer.decode(mb_query[0])) # print(tokenizer.decode(mb_responses[0][0])) # print(tokenizer.decode(mb_responses[1][0])) @@ -764,7 +777,7 @@ def process_response_data(x): 
accelerator.print("train/accuracy", train_accuracy) # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == args.num_updates - 1: # first and last update + if global_step == args.num_updates - 1: # first and last update dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index cee1642..de9277b 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -1,17 +1,18 @@ -from datasets import load_dataset from dataclasses import dataclass from typing import Dict, Optional, Union -from transformers import AutoTokenizer -from rich.pretty import pprint -import numpy as np +from datasets import load_dataset +from rich.pretty import pprint +from transformers import AutoTokenizer @dataclass class TaskQueryHParams: length: int = 512 dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily + format_str: Optional[ + str + ] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily truncate_field: Optional[str] = "post" truncate_text: Optional[str] = "\n" padding: Optional[Union[str, int]] = 50257 @@ -94,6 +95,7 @@ def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHPar oai_h.padding = [oai_h.padding] pprint(oai_h) dataset = load_dataset(oai_h.dataset) + def process_query_data(x): # with an extra leading space to account for the space between the query and response reference_response = f" {x['summary']}<|endoftext|>" @@ -101,18 +103,23 @@ def process_query_data(x): **process_query(x, encoder=tokenizer, hparams=oai_h), "reference_response": reference_response, "reference_response_token": tokenizer.encode( - reference_response, padding="max_length", max_length=max_response_length, truncation=True, + reference_response, + padding="max_length", + max_length=max_response_length, + truncation=True, ), } + dataset = dataset.map(process_query_data, load_from_cache_file=False) push_result = dataset.push_to_hub("vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing") print(push_result) label = load_dataset("openai/summarize_from_feedback", "comparisons") + def process_response_data(x): # with an extra leading space to account for the space between the query and response - response0 = x['summaries'][0]['text'] - response1 = x['summaries'][1]['text'] + response0 = x["summaries"][0]["text"] + response1 = x["summaries"][1]["text"] return { **process_query(x["info"], encoder=tokenizer, hparams=oai_h), "response0": response0, @@ -124,5 +131,6 @@ def process_response_data(x): response1, padding="max_length", max_length=max_response_length, truncation=True ), } + label = label.map(process_response_data, load_from_cache_file=False) push_result = label.push_to_hub("vwxyzjn/summarize_from_feedback_oai_preprocessing") diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index d284d66..18e50fa 100644 --- 
a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -28,9 +28,9 @@ from torch.utils.tensorboard import SummaryWriter from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - INVALID_LOGPROB = 1.0 + @dataclass class AdaptiveKLParams: target: float = 6.0 @@ -360,9 +360,9 @@ def __init__(self, lm_backbone): def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - latents = output.hidden_states[-1] # shape: [batch_size, length, hidden_size] - scalars = self.scalar_head(latents).squeeze(-1) # shape: [batch_size, length] - last_scalar = scalars[:, -1] # shape: [batch_size, 1] + latents = output.hidden_states[-1] # shape: [batch_size, length, hidden_size] + scalars = self.scalar_head(latents).squeeze(-1) # shape: [batch_size, length] + last_scalar = scalars[:, -1] # shape: [batch_size, 1] return scalars, last_scalar @@ -373,10 +373,10 @@ def __init__(self, policy, critic) -> None: super().__init__() self.policy = policy self.critic = critic - + def forward(self, **kwargs): return self.policy(**kwargs), self.critic(**kwargs) - + def shift_pad_id_left(tokens, pad_id): """Convert from right padding to left padding.""" @@ -386,6 +386,7 @@ def shift_pad_id_left(tokens, pad_id): device=tokens.device, ) + def shift_pad_id_left(data, pad_id): # Step 1: Create a boolean mask mask = (data == pad_id).long() @@ -456,6 +457,7 @@ def truncate_response(args, tokenizer, responses): postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) return postprocessed_responses + # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) @@ -513,9 +515,7 @@ def truncate_response(args, tokenizer, responses): reward_model = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) ) - critic = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) + critic = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) if args.rewards.trained_model: reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) @@ -527,9 +527,7 @@ def truncate_response(args, tokenizer, responses): policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) print(f"loaded pretrained policy from {args.sft_model_path}") - policy.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding model = PolicyAndValueWrapper(policy, critic) if args.optimizer == "tf_adam": @@ -587,9 +585,15 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well with torch.no_grad(): sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) - sample_validation_query_reference_responses = torch.cat((sample_validation_queries, sample_validation_reference_response), dim=1) - 
sample_validation_query_reference_responses = shift_pad_id_left(sample_validation_query_reference_responses, tokenizer.pad_token_id) - _, sample_validation_reference_scores = get_reward(reward_model, sample_validation_query_reference_responses, tokenizer) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + sample_validation_query_reference_responses = shift_pad_id_left( + sample_validation_query_reference_responses, tokenizer.pad_token_id + ) + _, sample_validation_reference_scores = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) iter_dataloader = iter(repeat_generator()) kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) @@ -624,7 +628,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well reference_responses = data["reference_response_token"].to(device) queries = shift_pad_id_left(queries, tokenizer.pad_token_id) query_reference_responses = torch.cat((queries, reference_responses), dim=1) - query_reference_responses= shift_pad_id_left(query_reference_responses, tokenizer.pad_token_id) + query_reference_responses = shift_pad_id_left(query_reference_responses, tokenizer.pad_token_id) query_responses = generate( accelerator.unwrap_model(model).policy, queries, @@ -643,14 +647,18 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ) sample_validation_responses = sample_validation_query_responses[:, context_length:] postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) - postprocessed_sample_validation_query_responses = torch.cat((sample_validation_queries, postprocessed_sample_validation_responses), 1) - postprocessed_sample_validation_query_responses = shift_pad_id_left(postprocessed_sample_validation_query_responses, tokenizer.pad_token_id) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + postprocessed_sample_validation_query_responses = shift_pad_id_left( + postprocessed_sample_validation_query_responses, tokenizer.pad_token_id + ) torch.cuda.empty_cache() # TODO: do I do this with query response or post-processed query response? 
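# A minimal shape sketch of the per-token log-prob computation used below: logits at
# position i predict token i + 1, which is why the code slices logits[:, context_length - 1 : -1]
# before scoring the response tokens, and temperature scaling is applied before the softmax.
# Shapes and values here are illustrative only.
import torch
import torch.nn.functional as F

batch, resp_len, vocab, temperature = 2, 5, 50257, 0.7
response_logits = torch.randn(batch, resp_len, vocab)   # logits aligned with the response positions
responses = torch.randint(vocab, (batch, resp_len))     # sampled response token ids

all_logprobs = F.log_softmax(response_logits / (temperature + 1e-7), dim=-1)
logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)  # (batch, resp_len)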
output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -658,7 +666,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -682,7 +690,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well _, reference_scores = get_reward(reward_model, query_reference_responses, tokenizer) _, validation_score = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) - + # carperAI-style score normaliation # accelerator.print("before score", scores, scores.mean()) # accelerator.print("reference_scores", reference_scores, reference_scores.mean()) @@ -709,18 +717,15 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) - all_sample_validation_responses = tokenizer.batch_decode( - postprocessed_sample_validation_responses - ) + all_sample_validation_responses = tokenizer.batch_decode(postprocessed_sample_validation_responses) all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( postprocessed_sample_validation_query_responses, skip_special_tokens=True ) all_sample_validation_postprocessed_responses = [ - x[len(y) :] for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) ] - all_sample_validation_reference_responses = tokenizer.batch_decode( - sample_validation_reference_response - ) + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) all_sample_validation_df = pd.DataFrame( { "query": all_decode_validation_queries, @@ -734,7 +739,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process and args.track: wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) print_rich_table("stuff", all_sample_validation_df[:4], console) - + except Exception as e: print(e) del ( @@ -781,9 +786,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well mb_logprobs = logprobs[micro_batch_inds] output, (vpred_temp, _) = forward(model, mb_query_responses, tokenizer) - + logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -893,5 +898,5 @@ def 
repeat_generator(): # TODO: ideally we shuffle the dataloader as well tokenizer.save_pretrained(repo_id, push_to_hub=True) # if __name__ == "__main__": -# args = tyro.cli(Args) +# args = tyro.cli(Args) # train(args) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index afb18d9..bb5d49d 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -11,7 +11,6 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from tqdm import tqdm import transformers import tyro from accelerate import Accelerator @@ -27,11 +26,9 @@ _get_value, _use_grad_for_differentiable, ) -from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler - -from lm_human_preference_details.data import process_query +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler @dataclass @@ -153,7 +150,6 @@ class Args: labels: LabelHParams = field(default_factory=LabelHParams) - def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: table = Table(show_lines=True) for column in df.columns: @@ -411,7 +407,7 @@ def evaluate(args, accelerator, tokenizer, device, reward_model, validation_labe torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten(0, 1) query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, len(mb_responses)) @@ -513,12 +509,11 @@ def train(args: Args): ) if args.deepspeed: - import deepspeed + pass deepspeed_states = AcceleratorState().deepspeed_plugin deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` @@ -543,14 +538,15 @@ def train(args: Args): end = start + args.batch_size b_inds_all = all_inds[start:end] b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - if accelerator.is_main_process: pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) + if accelerator.is_main_process: + pprint( + { + "global_step": global_step, + "start:end": f"{start}:{end}", + "b_inds_all": b_inds_all, + "b_inds": b_inds, + } + ) losses = torch.zeros((args.gradient_accumulation_steps,), device=device) accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) @@ -564,10 +560,13 @@ def train(args: Args): mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) mb_responses = [ - 
torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) + torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) + for i in range(args.labels.num_labels) ] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0,1)], dim=2).flatten(0, 1) + query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten( + 0, 1 + ) query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, len(mb_responses)) @@ -602,8 +601,8 @@ def train(args: Args): accelerator.print("train/accuracy", train_accuracy) # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == args.num_updates - 1: # first and last update - # if global_step == 1: + if global_step == args.num_updates - 1: # first and last update + # if global_step == 1: dev_validation_accuracy = evaluate(args, accelerator, tokenizer, device, reward_model, dev_validation_label) writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index cf16d90..d7100e9 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -6,19 +6,19 @@ from types import SimpleNamespace from typing import List, Literal, Optional +import evaluate import numpy as np import pandas as pd import torch import torch.optim as optim -from torch.nn import functional as F import tyro -import evaluate from accelerate import Accelerator from datasets import load_dataset from rich.console import Console from rich.pretty import pprint from rich.table import Table from torch import Tensor, optim +from torch.nn import functional as F from torch.optim.optimizer import ( _dispatch_sqrt, _get_value, @@ -26,7 +26,13 @@ ) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, + get_scheduler, +) @dataclass @@ -37,7 +43,7 @@ class SFTHParams: lr: float = 6.35e-5 eps: float = 1e-5 total_episodes: tyro.conf.Suppress[int] = None - local_batch_size:tyro.conf.Suppress[int] = None + local_batch_size: tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None mini_batch_size: tyro.conf.Suppress[int] = None world_size: tyro.conf.Suppress[int] = None @@ -126,8 +132,7 @@ class Args: # taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 def configure_dropout(model_config, dropout): if dropout is not None: - for key in ('dropout', 'attention_dropout', 'hidden_dropout', - 'activation_dropout'): + for key in ("dropout", "attention_dropout", "hidden_dropout", "activation_dropout"): if hasattr(model_config, key): print(f"Setting model_config.{key} to {dropout}") setattr(model_config, key, dropout) @@ -391,7 +396,7 @@ def 
forward(policy, query_responses, tokenizer): # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) model_config = AutoConfig.from_pretrained(args.base_model) - configure_dropout(model_config, 0.0) # disable dropout + configure_dropout(model_config, 0.0) # disable dropout policy = AutoConfig, AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True, config=model_config) policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding @@ -416,7 +421,9 @@ def forward(policy, query_responses, tokenizer): dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) validation_dataloader = DataLoader(validation_dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader, validation_dataloader, scheduler = accelerator.prepare(policy, optimizer, dataloader, validation_dataloader, scheduler) + policy, optimizer, dataloader, validation_dataloader, scheduler = accelerator.prepare( + policy, optimizer, dataloader, validation_dataloader, scheduler + ) iter_dataloader = iter(dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens @@ -475,22 +482,32 @@ def forward(policy, query_responses, tokenizer): validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) validation_queries = validation_data["query_token"].to(device, non_blocking=True) validation_queries = shift_pad_id_left(validation_queries, tokenizer.pad_token_id) - validation_query_reference_responses = torch.cat((validation_queries, validation_reference_responses), dim=1) + validation_query_reference_responses = torch.cat( + (validation_queries, validation_reference_responses), dim=1 + ) validation_output = forward(policy, validation_query_reference_responses, tokenizer) - validation_labels = validation_query_reference_responses.masked_fill(validation_query_reference_responses == tokenizer.pad_token_id, -1) + validation_labels = validation_query_reference_responses.masked_fill( + validation_query_reference_responses == tokenizer.pad_token_id, -1 + ) if args.sft.lm_loss_on_response_only: - validation_labels[:, :queries.shape[1]] = -1 + validation_labels[:, : queries.shape[1]] = -1 validation_lm_logits = validation_output.logits # hand-rolled transformer loss: Shift so that tokens < n predict n # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` validation_shift_logits = validation_lm_logits[..., :-1, :].contiguous() validation_shift_labels = validation_labels[..., 1:].contiguous() - validation_loss = F.cross_entropy(validation_shift_logits.view(-1, validation_shift_logits.size(-1)), validation_shift_labels.view(-1), ignore_index=-1) + validation_loss = F.cross_entropy( + validation_shift_logits.view(-1, validation_shift_logits.size(-1)), + validation_shift_labels.view(-1), + ignore_index=-1, + ) validation_loss = accelerator.gather(validation_loss) all_validation_losses.append(validation_loss) - generated_responses = generate(accelerator.unwrap_model(policy), validation_queries, tokenizer, generation_config) + 
generated_responses = generate( + accelerator.unwrap_model(policy), validation_queries, tokenizer, generation_config + ) decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries)) decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) decode_validation_reference_responses = tokenizer.batch_decode( @@ -499,13 +516,17 @@ def forward(policy, query_responses, tokenizer): decode_validation_responses = [ x[len(y) :] for x, y in zip(decode_validation_query_responses, decode_validation_queries) ] - rouge_score = rouge.compute(predictions=decode_validation_responses, references=decode_validation_reference_responses) + rouge_score = rouge.compute( + predictions=decode_validation_responses, references=decode_validation_reference_responses + ) rouge_scores["rouge1"].append(rouge_score["rouge1"]) rouge_scores["rouge2"].append(rouge_score["rouge2"]) rouge_scores["rougeL"].append(rouge_score["rougeL"]) all_decode_validation_queries.extend(decode_validation_queries) - accelerator.print("len(all_decode_validation_queries)", len(all_decode_validation_queries), decode_validation_responses) + accelerator.print( + "len(all_decode_validation_queries)", len(all_decode_validation_queries), decode_validation_responses + ) all_decode_validation_query_responses.extend(decode_validation_query_responses) all_decode_validation_responses.extend(decode_validation_responses) all_decode_validation_reference_responses.extend(decode_validation_reference_responses) @@ -526,7 +547,7 @@ def forward(policy, query_responses, tokenizer): print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) except Exception as e: print(e) - + for k, v in rouge_scores.items(): rouge_metric = torch.tensor(v, device=device) rouge_metric = accelerator.gather(rouge_metric) From 58db7b1251a1f6e6fc3c45574e92ac322a17e060 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 28 Oct 2023 19:17:01 -0400 Subject: [PATCH 20/62] pre-commit --- ...in_policy_accelerate_summarize_separate.py | 71 +++++++++++-------- .../train_reward_accelerate_summarize.py | 52 ++++++-------- .../train_sft_accelerate_summarize.py | 3 +- 3 files changed, 64 insertions(+), 62 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 18e50fa..ea365d6 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -26,7 +26,12 @@ ) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, +) INVALID_LOGPROB = 1.0 @@ -150,16 +155,13 @@ class Args: ppo: PpoHParams = field(default_factory=PpoHParams) -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
- """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout): + if dropout is not None: + for key in ("dropout", "attention_dropout", "hidden_dropout", "activation_dropout"): + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: @@ -378,15 +380,6 @@ def forward(self, **kwargs): return self.policy(**kwargs), self.critic(**kwargs) -def shift_pad_id_left(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - def shift_pad_id_left(data, pad_id): # Step 1: Create a boolean mask mask = (data == pad_id).long() @@ -450,6 +443,18 @@ def forward(policy, query_responses, tokenizer): ) +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + def truncate_response(args, tokenizer, responses): trunc_idxs = first_true_indices(responses == args.task.truncate_token).unsqueeze(-1) new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] @@ -512,10 +517,22 @@ def truncate_response(args, tokenizer, responses): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, 0.0) # disable dropout reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) ) - critic = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) if args.rewards.trained_model: reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) @@ -617,6 +634,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_losses_stats = torch.zeros(stats_shape, device=device) vf_clipfrac_stats = torch.zeros(stats_shape, device=device) entropies_stats = torch.zeros(stats_shape, device=device) + model.train() for update in range(1, args.ppo.num_updates + 1): global_step += 1 * args.ppo.batch_size frac = 1.0 - (update - 1.0) / args.ppo.num_updates @@ -692,10 +710,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as 
well _, validation_score = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) # carperAI-style score normaliation - # accelerator.print("before score", scores, scores.mean()) - # accelerator.print("reference_scores", reference_scores, reference_scores.mean()) scores = scores - reference_scores - # accelerator.print("after score", scores, scores.mean()) # 3. filter response. Ensure that the sample contains truncate_token # responses not passing that filter will receive a low (fixed) score @@ -887,11 +902,11 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well del kl, mean_kl, mean_entropy, mean_non_score_reward, scores # save model - if accelerator.is_main_process and args.save_path: + if args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(policy.state_dict(), args.save_path) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") - if args.upload_model: + if args.upload_model and accelerator.is_main_process: repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index bb5d49d..fcbba82 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -28,7 +28,7 @@ ) from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler @dataclass @@ -144,12 +144,19 @@ class Args: """Which scheduler to use""" warm_up_steps: int = 100 """Number of warm up steps for the scheduler""" - model_dot_train: bool = False - """Whether to call `model.train()`""" task: TaskHParams = field(default_factory=TaskHParams) labels: LabelHParams = field(default_factory=LabelHParams) +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout): + if dropout is not None: + for key in ("dropout", "attention_dropout", "hidden_dropout", "activation_dropout"): + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: table = Table(show_lines=True) for column in df.columns: @@ -357,21 +364,6 @@ def exact_div(a, b): return q -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = torch.masked_fill(queries, ~attention_mask, 0) - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
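# A small sketch of the padding convention these forward/reward helpers rely on: sequences are
# left-padded, the attention mask marks real tokens, pad positions are overwritten with a harmless
# id (they are masked out anyway), and position ids restart at 0 on the first real token via an
# exclusive cumsum. The pad id and token values below are illustrative.
import torch

pad_id = 50257
query_responses = torch.tensor([[pad_id, pad_id, 11, 12, 13],
                                [pad_id, 21, 22, 23, 24]])
attention_mask = query_responses != pad_id
position_ids = attention_mask.cumsum(1) - attention_mask.long()    # e.g. row 0 -> [0, 0, 0, 1, 2]
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)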
- generation_config=generation_config, - return_dict_in_generate=True, - ) - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - def get_reward(reward_model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum @@ -414,8 +406,7 @@ def evaluate(args, accelerator, tokenizer, device, reward_model, validation_labe accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() test_accuracies.append(accuracy) test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - if args.model_dot_train: - reward_model.train() + reward_model.train() return test_accuracy @@ -471,8 +462,14 @@ def train(args: Args): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, 0.0) # disable dropout reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) ) # freeze the first 70% of layers @@ -486,10 +483,6 @@ def train(args: Args): if args.sft_model_path: reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) print(f"loaded SFT model from {args.sft_model_path}") - reward_model.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding # make sure the `lm_head` or `embed_out` does not require gradients, otherwise # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): @@ -500,7 +493,6 @@ def train(args: Args): optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) elif args.optimizer == "adamw": optimizer = optim.AdamW(reward_model.parameters(), lr=args.lr, eps=args.eps) - # TODO: use AdamW scheduler = get_scheduler( args.scheduler, optimizer=optimizer, @@ -509,8 +501,6 @@ def train(args: Args): ) if args.deepspeed: - pass - deepspeed_states = AcceleratorState().deepspeed_plugin deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size @@ -525,8 +515,7 @@ def train(args: Args): accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) accelerator.print("===training reward model===") num_train = (args.labels.num_train // args.batch_size) * args.batch_size - if args.model_dot_train: - reward_model.train() + reward_model.train() for epoch in range(args.num_epochs): all_inds = np.random.permutation(args.labels.num_train) # ensure that all processes have the same shuffled indices @@ -618,8 +607,7 @@ def train(args: Args): # save model if args.save_path: os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) - accelerator.save_model(reward_model, args.save_path) + accelerator.save_model(reward_model, args.save_path, max_shard_size="1000GB") if accelerator.is_main_process and args.track: wandb.finish() diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py 
b/lm_human_preference_details/train_sft_accelerate_summarize.py index d7100e9..e24a008 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -397,7 +397,7 @@ def forward(policy, query_responses, tokenizer): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) model_config = AutoConfig.from_pretrained(args.base_model) configure_dropout(model_config, 0.0) # disable dropout - policy = AutoConfig, AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True, config=model_config) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically @@ -408,7 +408,6 @@ def forward(policy, query_responses, tokenizer): optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) elif args.optimizer == "adamw": optimizer = optim.AdamW(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - # TODO: use AdamW scheduler = get_scheduler( args.scheduler, optimizer=optimizer, From c9183405fd05be1e6bf0177aaaaa8b49a0f2bf75 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 28 Oct 2023 19:17:51 -0400 Subject: [PATCH 21/62] remove unnecessary stuff --- .pre-commit-config.yaml | 2 +- .../train_policy_accelerate_summarize_separate.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a300b2..d44ff3a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,5 +35,5 @@ repos: hooks: - id: codespell args: - - --ignore-words-list=nd,reacher,thist,ths,magent,ba + - --ignore-words-list=nd,reacher,thist,ths,magent,ba,rouge - --skip=docs/css/termynal.css,docs/js/termynal.js \ No newline at end of file diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index ea365d6..ac86b6a 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -33,8 +33,6 @@ GenerationConfig, ) -INVALID_LOGPROB = 1.0 - @dataclass class AdaptiveKLParams: @@ -700,8 +698,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well full_values, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) padding_mask = postprocessed_responses == tokenizer.pad_token_id - # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) # values = torch.masked_fill(values, padding_mask, 0) rew, scores = get_reward(reward_model, postprocessed_query_responses, tokenizer) @@ -806,7 +802,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) vpred = vpred_temp[:, 
context_length - 1 : -1] # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) vpredclipped = torch.clamp( From ca76b55e148cde41c2f516f88328d1acf2559736 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 29 Oct 2023 00:46:23 +0000 Subject: [PATCH 22/62] dropout proper setting --- ...in_policy_accelerate_summarize_separate.py | 63 +++++++++++-------- .../train_reward_accelerate_summarize.py | 10 ++- .../train_sft_accelerate_summarize.py | 21 +++---- 3 files changed, 52 insertions(+), 42 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index ac86b6a..099d69d 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -34,6 +34,9 @@ ) +INVALID_LOGPROB = 1.0 + + @dataclass class AdaptiveKLParams: target: float = 6.0 @@ -45,7 +48,7 @@ class RewardHParams: kl_coef: float = 0.15 use_adaptive_kl: bool = True adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/gpt2medium_last_index_reward/pytorch_model.bin" + trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None @@ -138,6 +141,8 @@ class Args: base_model: str = "gpt2" """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 10 @@ -154,9 +159,9 @@ class Args: # taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 -def configure_dropout(model_config, dropout): +def configure_dropout(model_config, dropout_layer_keys, dropout): if dropout is not None: - for key in ("dropout", "attention_dropout", "hidden_dropout", "activation_dropout"): + for key in dropout_layer_keys: if hasattr(model_config, key): print(f"Setting model_config.{key} to {dropout}") setattr(model_config, key, dropout) @@ -516,7 +521,9 @@ def truncate_response(args, tokenizer, responses): # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) model_config = AutoConfig.from_pretrained(args.base_model) - configure_dropout(model_config, 0.0) # disable dropout + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) reward_model = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained( args.base_model, @@ -536,8 +543,8 @@ def truncate_response(args, tokenizer, responses): critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) print(f"loaded pretrained reward model from {args.rewards.trained_model}") # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, 
config=model_config, trust_remote_code=True) if args.sft_model_path: policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) @@ -626,12 +633,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well print("===training policy===") global_step = 0 stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) - approxkls_stats = torch.zeros(stats_shape, device=device) - clipfracs_stats = torch.zeros(stats_shape, device=device) - pg_losses_stats = torch.zeros(stats_shape, device=device) - vf_losses_stats = torch.zeros(stats_shape, device=device) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropies_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) model.train() for update in range(1, args.ppo.num_updates + 1): global_step += 1 * args.ppo.batch_size @@ -698,6 +706,8 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well full_values, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) padding_mask = postprocessed_responses == tokenizer.pad_token_id + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) # values = torch.masked_fill(values, padding_mask, 0) rew, scores = get_reward(reward_model, postprocessed_query_responses, tokenizer) @@ -797,10 +807,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well mb_logprobs = logprobs[micro_batch_inds] output, (vpred_temp, _) = forward(model, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) vpred = vpred_temp[:, context_length - 1 : -1] # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) @@ -827,12 +837,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) approxkl = 0.5 * (logprobs_diff**2).mean() with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, 
gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 if accelerator.is_main_process: @@ -840,13 +851,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl.item(), + approxkl_stats[:ppo_epoch_idx+1].mean().item(), "pg_loss", - pg_loss.item(), + pg_loss_stats[:ppo_epoch_idx+1].mean().item(), "pg_clipfrac", - pg_clipfrac.item(), + pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), "ratio", - ratio.mean().item(), + ratio_stats[:ppo_epoch_idx+1].mean().item(), ) with torch.no_grad(): @@ -872,12 +883,12 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) - writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) - writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) - writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index fcbba82..22a8976 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -92,6 +92,8 @@ class Args: base_model: str = "gpt2" """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing" @@ -149,9 +151,9 @@ class Args: # taken from 
https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 -def configure_dropout(model_config, dropout): +def configure_dropout(model_config, dropout_layer_keys, dropout): if dropout is not None: - for key in ("dropout", "attention_dropout", "hidden_dropout", "activation_dropout"): + for key in dropout_layer_keys: if hasattr(model_config, key): print(f"Setting model_config.{key} to {dropout}") setattr(model_config, key, dropout) @@ -463,7 +465,9 @@ def train(args: Args): # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) model_config = AutoConfig.from_pretrained(args.base_model) - configure_dropout(model_config, 0.0) # disable dropout + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) reward_model = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained( args.base_model, diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index e24a008..4d1a112 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -113,6 +113,8 @@ class Args: base_model: str = "gpt2" """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 220 @@ -130,9 +132,9 @@ class Args: # taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 -def configure_dropout(model_config, dropout): +def configure_dropout(model_config, dropout_layer_keys, dropout): if dropout is not None: - for key in ("dropout", "attention_dropout", "hidden_dropout", "activation_dropout"): + for key in dropout_layer_keys: if hasattr(model_config, key): print(f"Setting model_config.{key} to {dropout}") setattr(model_config, key, dropout) @@ -396,12 +398,12 @@ def forward(policy, query_responses, tokenizer): # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) model_config = AutoConfig.from_pretrained(args.base_model) - configure_dropout(model_config, 0.0) # disable dropout + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details if args.optimizer == "tf_adam": optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) elif args.optimizer == "adam": @@ -480,17 +482,15 @@ def forward(policy, 
query_responses, tokenizer): with torch.no_grad(): validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) validation_queries = validation_data["query_token"].to(device, non_blocking=True) - validation_queries = shift_pad_id_left(validation_queries, tokenizer.pad_token_id) validation_query_reference_responses = torch.cat( (validation_queries, validation_reference_responses), dim=1 ) + validation_query_reference_responses = shift_pad_id_left(validation_query_reference_responses, tokenizer.pad_token_id) validation_output = forward(policy, validation_query_reference_responses, tokenizer) validation_labels = validation_query_reference_responses.masked_fill( validation_query_reference_responses == tokenizer.pad_token_id, -1 ) - if args.sft.lm_loss_on_response_only: - validation_labels[:, : queries.shape[1]] = -1 validation_lm_logits = validation_output.logits # hand-rolled transformer loss: Shift so that tokens < n predict n # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` @@ -523,14 +523,9 @@ def forward(policy, query_responses, tokenizer): rouge_scores["rougeL"].append(rouge_score["rougeL"]) all_decode_validation_queries.extend(decode_validation_queries) - accelerator.print( - "len(all_decode_validation_queries)", len(all_decode_validation_queries), decode_validation_responses - ) all_decode_validation_query_responses.extend(decode_validation_query_responses) all_decode_validation_responses.extend(decode_validation_responses) all_decode_validation_reference_responses.extend(decode_validation_reference_responses) - if validation_idx == 10: - break try: all_df = pd.DataFrame( From 89f208f7bc033a7e5b482bdf88b1a48331afe759 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 29 Oct 2023 02:40:49 +0000 Subject: [PATCH 23/62] make it work with gpt-large --- lm_human_preference_details/train_policy.sh | 46 +++++++++++++++++++ ...in_policy_accelerate_summarize_separate.py | 11 +++-- 2 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 lm_human_preference_details/train_policy.sh diff --git a/lm_human_preference_details/train_policy.sh b/lm_human_preference_details/train_policy.sh new file mode 100644 index 0000000..12faf0c --- /dev/null +++ b/lm_human_preference_details/train_policy.sh @@ -0,0 +1,46 @@ +# generate random seed and model paths +# set seed if not found in env +if [ -z "$SEED" ]; then + SEED=$RANDOM +fi +# SEED=1 +REWARD_MODEL_PATH=models/gpt2-large_reward_model_$SEED +SFT_MODEL_PATH=models/gpt2-large_sft_model_$SEED +POLICY_MODEL_PATH=models/gpt2-large_policy_model_$SEED +poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/train_sft_accelerate_summarize.py \ + --base_model=gpt2-large \ + --deepspeed \ + --track \ + --upload_model \ + --save_path=$SFT_MODEL_PATH \ + --seed=$SEED \ + +poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/train_reward_accelerate_summarize.py \ + --base_model=gpt2-large \ + --no_normalize_before --no_normalize_after \ + --local_batch_size=8 \ + --gradient_accumulation_steps=8 \ + --labels.num_train=92832 \ + --deepspeed \ + --track \ + --sft_model_path=$SFT_MODEL_PATH/pytorch_model.bin \ + --seed=$SEED \ + --save_path=$REWARD_MODEL_PATH \ + +poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/train_policy_accelerate_summarize_separate.py \ + --base_model=gpt2-large \ + --rewards.no_use_adaptive_kl \ + --rewards.kl_coef=0.05 \ + 
--ppo.gradient_accumulation_steps=64 \ + --ppo.lr=1.5e-5 \ + --seed=3 \ + --task.temperature=0.7 \ + --deepspeed \ + --track \ + --upload_model \ + --sft_model_path=$SFT_MODEL_PATH/pytorch_model.bin \ + --rewards.trained_model=$REWARD_MODEL_PATH/pytorch_model.bin \ + --save_path=$POLICY_MODEL_PATH \ \ No newline at end of file diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 099d69d..63cc0b7 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -706,9 +706,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well full_values, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) padding_mask = postprocessed_responses == tokenizer.pad_token_id - # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - # values = torch.masked_fill(values, padding_mask, 0) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + values = torch.masked_fill(values, padding_mask, 0) rew, scores = get_reward(reward_model, postprocessed_query_responses, tokenizer) @@ -810,10 +810,11 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) - # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) vpred = vpred_temp[:, context_length - 1 : -1] - # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + mb_return = torch.masked_fill(mb_return, padding_mask[micro_batch_inds], 0) # should not have a gradient effect vpredclipped = torch.clamp( vpred, mb_values - args.ppo.cliprange_value, From 9302a5bcf020b41eead937154aed1f0f24dcb063 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 14 Nov 2023 20:52:00 +0000 Subject: [PATCH 24/62] push changes --- lm_human_preference_details/tldr_dataset.py | 4 +- .../train_reward_accelerate_summarize.py | 130 ++++++++---------- .../train_sft_accelerate_summarize.py | 18 +-- 3 files changed, 62 insertions(+), 90 deletions(-) diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index de9277b..c25dd2e 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -118,8 +118,8 @@ def process_query_data(x): def process_response_data(x): # with an extra leading space to account for the space between the query and response - response0 = x["summaries"][0]["text"] - response1 = x["summaries"][1]["text"] + response0 = f" {x['summaries'][0]['text']}<|endoftext|>" + response1 = f" {x['summaries'][1]['text']}<|endoftext|>" return { **process_query(x["info"], encoder=tokenizer, hparams=oai_h), "response0": response0, diff --git 
a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index 22a8976..a4b55d3 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -27,6 +27,7 @@ _use_grad_for_differentiable, ) from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler @@ -104,6 +105,8 @@ class Args: """gradient accumulation steps""" local_micro_batch_size: tyro.conf.Suppress[int] = None """per rank micro batch size""" + local_eval_batch_size: int = 8 + """per rank eval batch size""" lr: float = 0.00005 """the learning rate""" eps: float = 1e-5 @@ -142,9 +145,9 @@ class Args: """Where to save the model""" optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" """Which optimizer to use""" - scheduler: str = "constant_with_warmup" + scheduler: str = "cosine" """Which scheduler to use""" - warm_up_steps: int = 100 + warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" task: TaskHParams = field(default_factory=TaskHParams) labels: LabelHParams = field(default_factory=LabelHParams) @@ -379,40 +382,30 @@ def get_reward(reward_model, query_responses, tokenizer): ) -def evaluate(args, accelerator, tokenizer, device, reward_model, validation_label): +def evaluate(args, accelerator, tokenizer, reward_model, dataloader): reward_model.eval() with torch.no_grad(): # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) - test_accuracies = [] - eval_len = len(validation_label) - len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full - new_all_inds = np.arange(len_labels) - for start in tqdm(range(0, len_labels, args.batch_size)): - end = start + args.batch_size - b_inds_all = new_all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = validation_label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) - ] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten(0, 1) - query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) - predicted_reward = get_reward(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, len(mb_responses)) - accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() + accuracies = [] + for data in tqdm(dataloader): + mb_query = data["query_token"] + mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) + mb_best = data["choice"] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) 
+ query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) + query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) + predicted_reward = get_reward(reward_model, query_responses, tokenizer) + predicted_reward = predicted_reward.view(-1, args.labels.num_labels) + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + accuracies.append(accuracy) + accuracy = accelerator.gather(torch.stack(accuracies).mean()).mean().item() reward_model.train() - return test_accuracy + return accuracy -def train(args: Args): +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) accelerator = Accelerator( kwargs_handlers=[ DistributedDataParallelKwargs( @@ -508,13 +501,29 @@ def train(args: Args): deepspeed_states = AcceleratorState().deepspeed_plugin deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) - - # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` label = load_dataset(args.label_dataset, "comparisons", split="train") + label = label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token"]) + label = label.shuffle(seed=local_seed) + dataloader = DataLoader(label, batch_size=args.local_micro_batch_size) + reward_model, optimizer, dataloader, scheduler = accelerator.prepare(reward_model, optimizer, dataloader, scheduler) + iter_dataloader = iter(dataloader) validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") - dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") - eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") + + batch_names = set() + for item in tqdm(validation_label): + batch_names.add(item["batch"]) + batch_names = sorted(list(batch_names)) + pprint(batch_names) + batches = [validation_label.filter( + lambda x: x["batch"] == batch_name + ).with_format( + "torch", + columns=["query_token", "choice", "response0_token", "response1_token"] + ) for batch_name in batch_names + ] + batch_dataloaders = [DataLoader(batch, batch_size=args.local_eval_batch_size) for batch in batches] + batch_dataloaders = [accelerator.prepare(dataloader) for dataloader in batch_dataloaders] + accelerator.print("Num labels found in source:", len(label)) accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) accelerator.print("===training reward model===") @@ -528,41 +537,22 @@ def train(args: Args): accelerator.print(f"epoch: {epoch}") for (epoch_global_step, start) in enumerate(range(0, num_train, args.batch_size)): global_step = epoch * args.num_updates + epoch_global_step - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - if accelerator.is_main_process: - pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) losses = torch.zeros((args.gradient_accumulation_steps,), device=device) accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) gradient_accumulation_step = 0 - for micro_batch_start in range(0, args.local_batch_size, 
args.local_micro_batch_size): + for gradient_accumulation_step in range(args.gradient_accumulation_steps): with accelerator.accumulate(reward_model): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten( - 0, 1 - ) + data = next(iter_dataloader) + mb_query = data["query_token"] + mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) + mb_best = data["choice"] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) + query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, len(mb_responses)) + predicted_reward = predicted_reward.view(-1, args.labels.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) @@ -582,8 +572,6 @@ def train(args: Args): accuracies[gradient_accumulation_step] = accuracy reward_preferreds[gradient_accumulation_step] = reward_preferred.mean() reward_rejecteds[gradient_accumulation_step] = reward_rejected.mean() - gradient_accumulation_step += 1 - train_accuracy = accelerator.gather(accuracies).mean().item() writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) writer.add_scalar("train/accuracy", train_accuracy, global_step) @@ -595,16 +583,10 @@ def train(args: Args): # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: if global_step == args.num_updates - 1: # first and last update - # if global_step == 1: - dev_validation_accuracy = evaluate(args, accelerator, tokenizer, device, reward_model, dev_validation_label) - writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) - accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, tokenizer, device, reward_model, eval_validation_label) - writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) - accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, tokenizer, device, reward_model, label) - writer.add_scalar("train_full/accuracy", eval_validation_accuracy, global_step) - accelerator.print("train_full/accuracy", eval_validation_accuracy, global_step) + for batch_name, batch_dataloader in zip(batch_names, batch_dataloaders): + batch_accuracy = evaluate(args, accelerator, tokenizer, reward_model, batch_dataloader) + writer.add_scalar(f"eval/accuracy/{batch_name}", batch_accuracy, global_step) + accelerator.print(f"eval/accuracy/{batch_name}", batch_accuracy, global_step) torch.cuda.empty_cache() 
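For reference, the pairwise objective optimized in the reward-model training loop above reduces to a minimal sketch like the one below (illustrative; it assumes `predicted_reward` is the `(batch, num_labels)` output of `get_reward` reshaped to one row per comparison, two candidate summaries per query, and `choice` indexing the human-preferred summary, mirroring the `args.logsigmoid` branch in the diff):

import torch
import torch.nn.functional as F

def preference_loss(predicted_reward: torch.Tensor, choice: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch of the loss in the hunk above: predicted_reward is
    # (batch, 2) and choice holds the index (0 or 1) of the preferred summary.
    reward_preferred = predicted_reward.gather(1, choice.view(-1, 1)).view(-1)
    reward_rejected = predicted_reward.gather(1, (1 - choice).view(-1, 1)).view(-1)
    # Bradley-Terry style objective: maximize log sigmoid(r_preferred - r_rejected)
    return -F.logsigmoid(reward_preferred - reward_rejected).mean()

The accuracy logged alongside it is simply `(predicted_reward.argmax(1) == choice).float().mean()`, i.e. how often the higher-scoring summary matches the human label.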
diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 4d1a112..77306df 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -12,6 +12,7 @@ import torch import torch.optim as optim import tyro +from tqdm import tqdm from accelerate import Accelerator from datasets import load_dataset from rich.console import Console @@ -75,18 +76,6 @@ class TaskHParams: temperature: float = 0.01 -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - @dataclass class Args: # common args @@ -478,7 +467,7 @@ def forward(policy, query_responses, tokenizer): all_decode_validation_responses = [] all_decode_validation_reference_responses = [] all_validation_losses = [] - for validation_idx, validation_data in enumerate(validation_dataloader): + for validation_idx, validation_data in tqdm(enumerate(validation_dataloader)): with torch.no_grad(): validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) validation_queries = validation_data["query_token"].to(device, non_blocking=True) @@ -526,7 +515,8 @@ def forward(policy, query_responses, tokenizer): all_decode_validation_query_responses.extend(decode_validation_query_responses) all_decode_validation_responses.extend(decode_validation_responses) all_decode_validation_reference_responses.extend(decode_validation_reference_responses) - + # if validation_idx == 10: + # break try: all_df = pd.DataFrame( { From 94acf3878390ef3bed1760eacc8578f67ad3f827 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 17 Nov 2023 14:04:17 +0000 Subject: [PATCH 25/62] SFT seemed to work finally https://wandb.ai/costa-huang/tldr_summarize/reports/RougeL--Vmlldzo1OTk2NTUw --- .../train_sft_accelerate_summarize.py | 210 ++++++++++-------- 1 file changed, 116 insertions(+), 94 deletions(-) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 77306df..eea6feb 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -85,7 +85,7 @@ class Args: """seed of the experiment""" track: bool = False """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" + wandb_project_name: str = "tldr_summarize" """the wandb's project name""" wandb_entity: Optional[str] = None """the entity (team) of wandb's project""" @@ -298,6 +298,23 @@ def shift_pad_id_left(data, pad_id): return shifted_data +def right_padding_to_left_padding(tokens, pad_id): + """Convert from right padding to left padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], + device=tokens.device, + ) + +def left_padding_to_right_padding(tokens, pad_id): + """Convert from left padding to right padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[x for x in row if x != pad_id] + [pad_id] * (row == pad_id).sum() for row in tokens], + device=tokens.device, + ) + + def ceil_div(a, 
b): return (a - 1) // b + 1 @@ -326,12 +343,12 @@ def generate(lm_backbone, queries, tokenizer, generation_config): def forward(policy, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return policy( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, + # position_ids=position_ids, return_dict=True, ) @@ -342,13 +359,14 @@ def forward(policy, query_responses, tokenizer): accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) args.sft.world_size = accelerator.num_processes args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps + args.sft.micro_batch_size = int(args.sft.local_micro_batch_size * args.sft.world_size) args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) dataset = load_dataset(args.task.query_dataset, split="train") validation_dataset = load_dataset(args.task.query_dataset, split="validation") accelerator.print("The number of samples in dataset", len(dataset)) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) args.sft.total_episodes = len(dataset) - args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size + args.sft.num_updates = args.sft.total_episodes // args.sft.local_batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -403,7 +421,7 @@ def forward(policy, query_responses, tokenizer): args.scheduler, optimizer=optimizer, num_warmup_steps=args.warm_up_steps, - num_training_steps=args.sft.num_updates // args.sft.gradient_accumulation_steps, + num_training_steps=args.sft.num_updates, ) dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) @@ -411,10 +429,10 @@ def forward(policy, query_responses, tokenizer): dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) validation_dataloader = DataLoader(validation_dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader, validation_dataloader, scheduler = accelerator.prepare( - policy, optimizer, dataloader, validation_dataloader, scheduler + policy, optimizer, dataloader, scheduler = accelerator.prepare( + policy, optimizer, dataloader, scheduler ) - iter_dataloader = iter(dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens generation_config = GenerationConfig( @@ -432,14 +450,15 @@ def forward(policy, query_responses, tokenizer): loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) gradient_accumulation_idx = 0 policy.train() - for update in range(1, args.sft.num_updates + 1): - global_step += args.sft.batch_size - accelerator.print(f"update {update}, global_step {global_step}") - data = next(iter_dataloader) + # for update in range(1, args.sft.num_updates + 1): + update = 0 + for data in dataloader: + update += 1 + global_step += args.sft.micro_batch_size reference_responses = data["reference_response_token"].to(device, non_blocking=True) queries = data["query_token"].to(device, non_blocking=True) query_responses = torch.cat((queries, reference_responses), dim=1) - query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) + query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) with accelerator.accumulate(policy): output = forward(policy, query_responses, tokenizer) # mask out gradient effects on response padding tokens @@ -458,87 +477,90 @@ def forward(policy, query_responses, tokenizer): gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) - writer.add_scalar("lr", optimizer.param_groups[0]["lr"], update) - if update == 1 or update == args.sft.num_updates - 1: - policy.eval() - rouge_scores = collections.defaultdict(list) - all_decode_validation_queries = [] - all_decode_validation_query_responses = [] - all_decode_validation_responses = [] - all_decode_validation_reference_responses = [] - all_validation_losses = [] - for validation_idx, validation_data in tqdm(enumerate(validation_dataloader)): - with torch.no_grad(): - validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) - validation_queries = validation_data["query_token"].to(device, non_blocking=True) - validation_query_reference_responses = torch.cat( - (validation_queries, validation_reference_responses), dim=1 - ) - validation_query_reference_responses = shift_pad_id_left(validation_query_reference_responses, tokenizer.pad_token_id) - - validation_output = forward(policy, validation_query_reference_responses, tokenizer) - validation_labels = validation_query_reference_responses.masked_fill( - validation_query_reference_responses == tokenizer.pad_token_id, -1 - ) - validation_lm_logits = validation_output.logits - # hand-rolled transformer loss: Shift so that tokens < n predict n - # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` - validation_shift_logits = validation_lm_logits[..., :-1, :].contiguous() - validation_shift_labels = validation_labels[..., 1:].contiguous() - validation_loss = F.cross_entropy( - validation_shift_logits.view(-1, validation_shift_logits.size(-1)), - validation_shift_labels.view(-1), - ignore_index=-1, - ) - validation_loss = accelerator.gather(validation_loss) - all_validation_losses.append(validation_loss) - - generated_responses = generate( - accelerator.unwrap_model(policy), validation_queries, tokenizer, generation_config - ) - decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries)) - decode_validation_query_responses = 
tokenizer.batch_decode(accelerator.gather(generated_responses)) - decode_validation_reference_responses = tokenizer.batch_decode( - accelerator.gather(validation_reference_responses) - ) - decode_validation_responses = [ - x[len(y) :] for x, y in zip(decode_validation_query_responses, decode_validation_queries) - ] - rouge_score = rouge.compute( - predictions=decode_validation_responses, references=decode_validation_reference_responses - ) - rouge_scores["rouge1"].append(rouge_score["rouge1"]) - rouge_scores["rouge2"].append(rouge_score["rouge2"]) - rouge_scores["rougeL"].append(rouge_score["rougeL"]) - - all_decode_validation_queries.extend(decode_validation_queries) - all_decode_validation_query_responses.extend(decode_validation_query_responses) - all_decode_validation_responses.extend(decode_validation_responses) - all_decode_validation_reference_responses.extend(decode_validation_reference_responses) - # if validation_idx == 10: - # break - try: - all_df = pd.DataFrame( - { - "query": all_decode_validation_queries, - "response": all_decode_validation_responses, - "reference": all_decode_validation_reference_responses, - } - ) - accelerator.print(all_df) - if accelerator.is_main_process and args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) - except Exception as e: - print(e) - - for k, v in rouge_scores.items(): - rouge_metric = torch.tensor(v, device=device) - rouge_metric = accelerator.gather(rouge_metric) - writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) - accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") - writer.add_scalar("validation_loss", torch.stack(all_validation_losses).mean().item(), update) - policy.train() + writer.add_scalar("lr", scheduler.get_last_lr()[0], update) + accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") + # if update == args.sft.num_updates - 1: # update == 1 or + policy.eval() + rouge_scores = collections.defaultdict(list) + all_decode_validation_queries = [] + all_decode_validation_query_responses = [] + all_decode_validation_responses = [] + all_decode_validation_reference_responses = [] + all_validation_losses = [] + for validation_idx, validation_data in tqdm(enumerate(validation_dataloader)): + with torch.no_grad(): + validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) + validation_queries = validation_data["query_token"].to(device, non_blocking=True) + # validation_queries = right_padding_to_left_padding(validation_queries, tokenizer.pad_token_id) # not necessary + validation_query_reference_responses = torch.cat( + (validation_queries, validation_reference_responses), dim=1 + ) + validation_query_reference_responses = left_padding_to_right_padding(validation_query_reference_responses, tokenizer.pad_token_id) + + validation_output = forward(policy, validation_query_reference_responses, tokenizer) + validation_labels = validation_query_reference_responses.masked_fill( + validation_query_reference_responses == tokenizer.pad_token_id, -1 + ) + validation_lm_logits = validation_output.logits + # hand-rolled transformer loss: Shift so that tokens < n predict n + # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` + validation_shift_logits = validation_lm_logits[..., :-1, :].contiguous() + validation_shift_labels = validation_labels[..., 1:].contiguous() + 
validation_loss = F.cross_entropy( + validation_shift_logits.view(-1, validation_shift_logits.size(-1)), + validation_shift_labels.view(-1), + ignore_index=-1, + ) + validation_loss = accelerator.gather(validation_loss) + all_validation_losses.append(validation_loss) + + generated_responses = generate( + accelerator.unwrap_model(policy), + validation_queries, + tokenizer, + generation_config, + ) + decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries)) + decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) + decode_validation_reference_responses = tokenizer.batch_decode( + accelerator.gather(validation_reference_responses) + ) + decode_validation_responses = tokenizer.batch_decode(accelerator.gather(generated_responses[:, -args.task.response_length:])) + rouge_score = rouge.compute( + predictions=decode_validation_responses, references=decode_validation_reference_responses + ) + rouge_scores["rouge1"].append(rouge_score["rouge1"]) + rouge_scores["rouge2"].append(rouge_score["rouge2"]) + rouge_scores["rougeL"].append(rouge_score["rougeL"]) + + all_decode_validation_queries.extend(decode_validation_queries) + all_decode_validation_query_responses.extend(decode_validation_query_responses) + all_decode_validation_responses.extend(decode_validation_responses) + all_decode_validation_reference_responses.extend(decode_validation_reference_responses) + # if validation_idx == 10: + # break + try: + all_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_decode_validation_responses, + "reference": all_decode_validation_reference_responses, + } + ) + accelerator.print(all_df) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) + except Exception as e: + print(e) + + for k, v in rouge_scores.items(): + rouge_metric = torch.tensor(v, device=device) + rouge_metric = accelerator.gather(rouge_metric) + writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) + accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") + writer.add_scalar("validation_loss", torch.stack(all_validation_losses).mean().item(), update) + policy.train() # save model if args.save_path: From b7f6876e2e27d5423051375973445c3057c31940 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 22 Nov 2023 02:10:51 +0000 Subject: [PATCH 26/62] reducing lr helps; refactor loop logic https://wandb.ai/costa-huang/cleanRL/runs/zj3wgvcs --- .../train_reward_accelerate_summarize.py | 206 +++++++++--------- 1 file changed, 102 insertions(+), 104 deletions(-) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index a4b55d3..ddcc817 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -1,3 +1,4 @@ +from collections import defaultdict import os import random import time @@ -15,7 +16,7 @@ import tyro from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import DistributedDataParallelKwargs, broadcast +from accelerate.utils import DistributedDataParallelKwargs from datasets import load_dataset from rich.console import Console from rich.pretty import pprint @@ -99,7 +100,7 @@ class Args: """Whether 
to use deepspeed to train the model""" label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing" """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 4 + local_batch_size: int = 8 """per rank batch size""" gradient_accumulation_steps: int = 1 """gradient accumulation steps""" @@ -107,7 +108,7 @@ class Args: """per rank micro batch size""" local_eval_batch_size: int = 8 """per rank eval batch size""" - lr: float = 0.00005 + lr: float = 5e-6 """the learning rate""" eps: float = 1e-5 """the epsilon for AdamW""" @@ -339,23 +340,17 @@ def __init__(self, lm_backbone): def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - last_reward_latents = reward_latents[:, -1, :] - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) + reward = self.scalar_head(output.hidden_states[-1]) return reward -def shift_pad_id_left(data, pad_id): - # Step 1: Create a boolean mask - mask = (data == pad_id).long() - # Step 3: Use argsort on the inverted boolean mask to get sorted indices - sorted_indices = torch.argsort(~mask, axis=1) - # Step 4: Use advanced indexing to rearrange the elements - rows_range = torch.arange(data.shape[0], device=data.device) - shifted_data = data[rows_range[:, None], sorted_indices] - return shifted_data +def left_padding_to_right_padding(tokens, pad_id): + """Convert from left padding to right padding.""" + assert tokens.ndim == 2 + return torch.tensor( + [[x for x in row if x != pad_id] + [pad_id] * (row == pad_id).sum() for row in tokens], + device=tokens.device, + ) def ceil_div(a, b): @@ -371,15 +366,19 @@ def exact_div(a, b): def get_reward(reward_model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return reward_model( + reward_logits = reward_model( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, return_dict=True, output_hidden_states=True, ) + sequence_lengths = ( + torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + query_responses.device + ) + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths] def evaluate(args, accelerator, tokenizer, reward_model, dataloader): @@ -387,20 +386,28 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): with torch.no_grad(): # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
accuracies = [] + accuracy_splits = defaultdict(list) + accuracy_batches = defaultdict(list) + accuracy_confidences = defaultdict(list) for data in tqdm(dataloader): mb_query = data["query_token"] mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) mb_best = data["choice"] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) - query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) + query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.labels.num_labels) - accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() - accuracies.append(accuracy) - accuracy = accelerator.gather(torch.stack(accuracies).mean()).mean().item() + accuracy = (predicted_reward.argmax(1) == mb_best).float() + accuracies.append(accuracy.mean()) + for batch, split, confidence, acc in zip(data["batch"], data["split"], data["extra"]["confidence"], accuracy): + acc_item = acc.item() + accuracy_splits[split].append(acc_item) + accuracy_batches[batch].append(acc_item) + accuracy_confidences[int(confidence)].append(acc_item) + accuracies = accelerator.gather(torch.stack(accuracies).mean()).mean().item() reward_model.train() - return accuracy + return accuracies, accuracy_batches, accuracy_splits, accuracy_confidences # def train(args: Args): @@ -419,7 +426,8 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): args.batch_size = int(args.local_batch_size * args.world_size) args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) - args.num_updates = args.labels.num_train // args.batch_size + args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) + args.num_updates = args.labels.num_train // args.local_batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -468,7 +476,6 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): trust_remote_code=True, ) ) - # freeze the first 70% of layers if args.trainable_param_percentage < 1.0: layers = reward_model.lm_backbone.transformer.h @@ -502,91 +509,82 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size label = load_dataset(args.label_dataset, "comparisons", split="train") - label = label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token"]) label = label.shuffle(seed=local_seed) + label = label.select(range(args.labels.num_train)) + label = label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"]) dataloader = DataLoader(label, batch_size=args.local_micro_batch_size) reward_model, optimizer, dataloader, scheduler = accelerator.prepare(reward_model, optimizer, dataloader, scheduler) - iter_dataloader = iter(dataloader) validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") + validation_label = validation_label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra"]) + validation_dataloader = 
DataLoader(validation_label, batch_size=args.local_eval_batch_size) + validation_dataloader = accelerator.prepare(validation_dataloader) - batch_names = set() - for item in tqdm(validation_label): - batch_names.add(item["batch"]) - batch_names = sorted(list(batch_names)) - pprint(batch_names) - batches = [validation_label.filter( - lambda x: x["batch"] == batch_name - ).with_format( - "torch", - columns=["query_token", "choice", "response0_token", "response1_token"] - ) for batch_name in batch_names - ] - batch_dataloaders = [DataLoader(batch, batch_size=args.local_eval_batch_size) for batch in batches] - batch_dataloaders = [accelerator.prepare(dataloader) for dataloader in batch_dataloaders] - - accelerator.print("Num labels found in source:", len(label)) - accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) accelerator.print("===training reward model===") - num_train = (args.labels.num_train // args.batch_size) * args.batch_size + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_model.train() + gradient_accumulation_idx = 0 + global_step = 0 + update = 0 for epoch in range(args.num_epochs): - all_inds = np.random.permutation(args.labels.num_train) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() accelerator.print(f"epoch: {epoch}") - for (epoch_global_step, start) in enumerate(range(0, num_train, args.batch_size)): - global_step = epoch * args.num_updates + epoch_global_step - losses = torch.zeros((args.gradient_accumulation_steps,), device=device) - accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) - reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) - reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) - gradient_accumulation_step = 0 - for gradient_accumulation_step in range(args.gradient_accumulation_steps): - with accelerator.accumulate(reward_model): - data = next(iter_dataloader) - mb_query = data["query_token"] - mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) - mb_best = data["choice"] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) - query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) - query_responses = shift_pad_id_left(query_responses, tokenizer.pad_token_id) - predicted_reward = get_reward(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, args.labels.num_labels) - accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() - reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) - reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) - if args.logsigmoid: - loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() - else: - loss = F.cross_entropy(predicted_reward, mb_best) - accelerator.backward(loss) - # for k, v in reward_model.named_parameters(): - # if v.requires_grad: - # if v.grad is None: - # print(f"found unused param: {k}") - optimizer.step() # accelerate handles gradient accumulation automatically - optimizer.zero_grad() - 
scheduler.step() - losses[gradient_accumulation_step] = loss - accuracies[gradient_accumulation_step] = accuracy - reward_preferreds[gradient_accumulation_step] = reward_preferred.mean() - reward_rejecteds[gradient_accumulation_step] = reward_rejected.mean() - train_accuracy = accelerator.gather(accuracies).mean().item() - writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) - writer.add_scalar("train/accuracy", train_accuracy, global_step) - writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) - writer.add_scalar("train/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) - lr = scheduler.get_last_lr() - writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) - accelerator.print("train/accuracy", train_accuracy) - - # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == args.num_updates - 1: # first and last update - for batch_name, batch_dataloader in zip(batch_names, batch_dataloaders): - batch_accuracy = evaluate(args, accelerator, tokenizer, reward_model, batch_dataloader) - writer.add_scalar(f"eval/accuracy/{batch_name}", batch_accuracy, global_step) - accelerator.print(f"eval/accuracy/{batch_name}", batch_accuracy, global_step) + for data in dataloader: + update += 1 + global_step += args.micro_batch_size + mb_query = data["query_token"] + mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) + mb_best = data["choice"] + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) + query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) + query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) + with accelerator.accumulate(reward_model): + predicted_reward = get_reward(reward_model, query_responses, tokenizer) + predicted_reward = predicted_reward.view(-1, args.labels.num_labels) + accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() + reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) + reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) + if args.logsigmoid: + loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() + else: + loss = F.cross_entropy(predicted_reward, mb_best) + accelerator.backward(loss) + # for k, v in reward_model.named_parameters(): + # if v.requires_grad: + # if v.grad is None: + # print(f"found unused param: {k}") + optimizer.step() # accelerate handles gradient accumulation automatically + optimizer.zero_grad() + scheduler.step() + losses[gradient_accumulation_idx] = loss + accuracies[gradient_accumulation_idx] = accuracy + reward_preferreds[gradient_accumulation_idx] = reward_preferred.mean() + reward_rejecteds[gradient_accumulation_idx] = reward_rejected.mean() + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps + if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: + train_accuracy = accelerator.gather(accuracies).mean().item() + writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/accuracy", train_accuracy, global_step) + writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) + writer.add_scalar("train/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) + 
writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step) + accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {update=}") + # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: + + accuracies, accuracy_batches, accuracy_splits, accuracy_confidences = evaluate(args, accelerator, tokenizer, reward_model, validation_dataloader) + for split, accs in accuracy_splits.items(): + writer.add_scalar(f"eval/accuracy/{split}", np.mean(accs), global_step) + accelerator.print(f"{split} accuracy: {np.mean(accs)}") + for batch, accs in accuracy_batches.items(): + writer.add_scalar(f"eval/accuracy/{batch}", np.mean(accs), global_step) + accelerator.print(f"{batch} accuracy: {np.mean(accs)}") + for confs, accs in accuracy_confidences.items(): + writer.add_scalar(f"eval/confidence/{confs}", np.mean(accs), global_step) + accelerator.print(f"{confs} confidence: {np.mean(accs)}") + writer.add_scalar("eval/accuracy", accuracies, global_step) + accelerator.print(f"eval accuracy: {accuracies}") + torch.cuda.empty_cache() @@ -601,4 +599,4 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): if __name__ == "__main__": args = tyro.cli(Args) - train(args) + # train(args) From 5622b63edd112c20aab05240d0e9039e7d6cf743 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 23 Nov 2023 02:00:06 +0000 Subject: [PATCH 27/62] push changes --- lm_human_preference_details/tldr_dataset.py | 31 ++++++--- .../train_policy_pythia.sh | 51 ++++++++++++++ .../train_reward_accelerate_summarize.py | 66 +++++++++++-------- .../train_sft_accelerate_summarize.py | 19 +----- 4 files changed, 113 insertions(+), 54 deletions(-) create mode 100644 lm_human_preference_details/train_policy_pythia.sh diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index c25dd2e..5f8d8e3 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -4,12 +4,18 @@ from datasets import load_dataset from rich.pretty import pprint from transformers import AutoTokenizer +import tyro + + +@dataclass +class Args: + base_model: str = "gpt2" # EleutherAI/pythia-160m + max_response_length: int = 48 @dataclass class TaskQueryHParams: length: int = 512 - dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" format_str: Optional[ str ] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily @@ -85,16 +91,16 @@ def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHPar if __name__ == "__main__": - tokenizer = AutoTokenizer.from_pretrained("gpt2") + args = tyro.cli(Args) + tokenizer = AutoTokenizer.from_pretrained(args.base_model) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - max_response_length = 48 oai_h = TaskQueryHParams() if isinstance(oai_h.padding, str): oai_h.padding = tokenizer.encode(oai_h.padding) else: oai_h.padding = [oai_h.padding] pprint(oai_h) - dataset = load_dataset(oai_h.dataset) + dataset = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered") def process_query_data(x): # with an extra leading space to account for the space between the query and response @@ -105,14 +111,13 @@ def process_query_data(x): "reference_response_token": tokenizer.encode( reference_response, padding="max_length", - max_length=max_response_length, + max_length=args.max_response_length, truncation=True, ), } dataset = dataset.map(process_query_data, 
load_from_cache_file=False) - push_result = dataset.push_to_hub("vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing") - print(push_result) + dataset.push_to_hub(f"vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_response_length}") label = load_dataset("openai/summarize_from_feedback", "comparisons") @@ -120,17 +125,23 @@ def process_response_data(x): # with an extra leading space to account for the space between the query and response response0 = f" {x['summaries'][0]['text']}<|endoftext|>" response1 = f" {x['summaries'][1]['text']}<|endoftext|>" + response0_policy = x["summaries"][0]["policy"] + response1_policy = x["summaries"][1]["policy"] + policies = "--".join(sorted([response0_policy, response1_policy])) return { **process_query(x["info"], encoder=tokenizer, hparams=oai_h), "response0": response0, "response0_token": tokenizer.encode( - response0, padding="max_length", max_length=max_response_length, truncation=True + response0, padding="max_length", max_length=args.max_response_length, truncation=True ), "response1": response1, "response1_token": tokenizer.encode( - response1, padding="max_length", max_length=max_response_length, truncation=True + response1, padding="max_length", max_length=args.max_response_length, truncation=True ), + "response0_policy": response0_policy, + "response1_policy": response1_policy, + "policies": policies, } label = label.map(process_response_data, load_from_cache_file=False) - push_result = label.push_to_hub("vwxyzjn/summarize_from_feedback_oai_preprocessing") + label.push_to_hub(f"vwxyzjn/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_response_length}") diff --git a/lm_human_preference_details/train_policy_pythia.sh b/lm_human_preference_details/train_policy_pythia.sh new file mode 100644 index 0000000..babd4fc --- /dev/null +++ b/lm_human_preference_details/train_policy_pythia.sh @@ -0,0 +1,51 @@ +# generate random seed and model paths +# set seed if not found in env +if [ -z "$SEED" ]; then + SEED=$RANDOM +fi +if [ -z "$MODEL" ]; then + MODEL=EleutherAI/pythia-1b-deduped +fi +# SEED=3131 +# MODEL=EleutherAI/pythia-1b-deduped +REWARD_MODEL_PATH=models/$MODEL/reward_model_$SEED +SFT_MODEL_PATH=models/$MODEL/sft_model_$SEED +POLICY_MODEL_PATH=models/$MODEL/policy_model_$SEED +poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/train_sft_accelerate_summarize.py \ + --task.query_dataset=vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_48 \ + --base_model=$MODEL \ + --deepspeed \ + --track \ + --upload_model \ + --save_path=$SFT_MODEL_PATH \ + --seed=$SEED \ + +poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/train_reward_accelerate_summarize.py \ + --label_dataset=vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_48 \ + --base_model=$MODEL \ + --no_normalize_before --no_normalize_after \ + --local_batch_size=8 \ + --gradient_accumulation_steps=8 \ + --labels.num_train=92832 \ + --deepspeed \ + --track \ + --sft_model_path=$SFT_MODEL_PATH/pytorch_model.bin \ + --save_path=$REWARD_MODEL_PATH \ + --seed=$SEED \ + +# poetry run accelerate launch --config_file deepspeed.yaml \ +# lm_human_preference_details/train_policy_accelerate_summarize_separate.py \ +# --base_model=$MODEL \ +# --rewards.no_use_adaptive_kl \ +# --rewards.kl_coef=0.05 \ +# --ppo.gradient_accumulation_steps=64 \ +# --ppo.lr=1.5e-5 \ +# --task.temperature=0.7 \ 
+# --deepspeed \ +# --track \ +# --sft_model_path=$SFT_MODEL_PATH/pytorch_model.bin \ +# --rewards.trained_model=$REWARD_MODEL_PATH/pytorch_model.bin \ +# --seed=$SEED \ +# --save_path=$POLICY_MODEL_PATH \ \ No newline at end of file diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index ddcc817..0ddf16c 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -16,7 +16,7 @@ import tyro from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import DistributedDataParallelKwargs +from accelerate.utils import DistributedDataParallelKwargs, gather_object from datasets import load_dataset from rich.console import Console from rich.pretty import pprint @@ -384,11 +384,7 @@ def get_reward(reward_model, query_responses, tokenizer): def evaluate(args, accelerator, tokenizer, reward_model, dataloader): reward_model.eval() with torch.no_grad(): - # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) - accuracies = [] - accuracy_splits = defaultdict(list) - accuracy_batches = defaultdict(list) - accuracy_confidences = defaultdict(list) + items = defaultdict(list) for data in tqdm(dataloader): mb_query = data["query_token"] mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) @@ -399,15 +395,24 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.labels.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float() - accuracies.append(accuracy.mean()) - for batch, split, confidence, acc in zip(data["batch"], data["split"], data["extra"]["confidence"], accuracy): - acc_item = acc.item() - accuracy_splits[split].append(acc_item) - accuracy_batches[batch].append(acc_item) - accuracy_confidences[int(confidence)].append(acc_item) - accuracies = accelerator.gather(torch.stack(accuracies).mean()).mean().item() + + for k in data: + data[k] = gather_object(data[k]) + for i in range(len(accuracy)): + items["query"].append(tokenizer.decode(data["query_token"][i], skip_special_tokens=True)) + items["response0"].append(tokenizer.decode(data["response0_token"][i])) + items["response1"].append(tokenizer.decode(data["response1_token"][i])) + items["batch"].append(data["batch"][i]) + items["split"].append(data["split"][i]) + items["confidence"].append(data["extra.confidence"][i].item()) + items["choice"].append(data["choice"][i].item()) + items["policies"].append(data["policies"][i]) + items["response0_policy"].append(data["response0_policy"][i]) + items["response1_policy"].append(data["response1_policy"][i]) + items["accuracy"].append(accuracy[i].item()) + breakpoint() reward_model.train() - return accuracies, accuracy_batches, accuracy_splits, accuracy_confidences + return pd.DataFrame(items) # def train(args: Args): @@ -514,8 +519,8 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): label = label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"]) dataloader = DataLoader(label, batch_size=args.local_micro_batch_size) reward_model, optimizer, dataloader, scheduler = accelerator.prepare(reward_model, optimizer, dataloader, scheduler) - 
validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") - validation_label = validation_label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra"]) + validation_label = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() + validation_label = validation_label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra.confidence", "response0_policy", "response1_policy", "policies"]) validation_dataloader = DataLoader(validation_label, batch_size=args.local_eval_batch_size) validation_dataloader = accelerator.prepare(validation_dataloader) @@ -572,18 +577,23 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {update=}") # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - accuracies, accuracy_batches, accuracy_splits, accuracy_confidences = evaluate(args, accelerator, tokenizer, reward_model, validation_dataloader) - for split, accs in accuracy_splits.items(): - writer.add_scalar(f"eval/accuracy/{split}", np.mean(accs), global_step) - accelerator.print(f"{split} accuracy: {np.mean(accs)}") - for batch, accs in accuracy_batches.items(): - writer.add_scalar(f"eval/accuracy/{batch}", np.mean(accs), global_step) - accelerator.print(f"{batch} accuracy: {np.mean(accs)}") - for confs, accs in accuracy_confidences.items(): - writer.add_scalar(f"eval/confidence/{confs}", np.mean(accs), global_step) - accelerator.print(f"{confs} confidence: {np.mean(accs)}") - writer.add_scalar("eval/accuracy", accuracies, global_step) - accelerator.print(f"eval accuracy: {accuracies}") + evaluate_df = evaluate(args, accelerator, tokenizer, reward_model, validation_dataloader) + for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): + writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step) + accelerator.print(f"{split} accuracy: {row['accuracy']}") + for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): + writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step) + accelerator.print(f"{batch} accuracy: {row['accuracy']}") + for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): + writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step) + accelerator.print(f"{confi} confidence: {row['accuracy']}") + writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step) + accelerator.print(f"eval accuracy: {evaluate_df['accuracy'].mean()}") + if accelerator.is_main_process: + os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) torch.cuda.empty_cache() diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index eea6feb..e88729c 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -44,11 +44,10 @@ class SFTHParams: lr: float = 6.35e-5 eps: float = 1e-5 total_episodes: tyro.conf.Suppress[int] = None + micro_batch_size: tyro.conf.Suppress[int] = None local_batch_size: tyro.conf.Suppress[int] = 
None batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None num_updates: tyro.conf.Suppress[int] = None @@ -287,17 +286,6 @@ def step(self, closure=None): return loss -def shift_pad_id_left(data, pad_id): - # Step 1: Create a boolean mask - mask = (data == pad_id).long() - # Step 3: Use argsort on the inverted boolean mask to get sorted indices - sorted_indices = torch.argsort(~mask, axis=1) - # Step 4: Use advanced indexing to rearrange the elements - rows_range = torch.arange(data.shape[0], device=data.device) - shifted_data = data[rows_range[:, None], sorted_indices] - return shifted_data - - def right_padding_to_left_padding(tokens, pad_id): """Convert from right padding to left padding.""" assert tokens.ndim == 2 @@ -446,11 +434,10 @@ def forward(policy, query_responses, tokenizer): rouge = evaluate.load("rouge") print("===training policy===") - global_step = 0 loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) - gradient_accumulation_idx = 0 policy.train() - # for update in range(1, args.sft.num_updates + 1): + gradient_accumulation_idx = 0 + global_step = 0 update = 0 for data in dataloader: update += 1 From b8c5ffcebef949685768f3cd795f6005eb533b4f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 28 Nov 2023 14:46:55 +0000 Subject: [PATCH 28/62] push changes --- .../train_reward_accelerate_summarize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index 0ddf16c..d23fb10 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -81,7 +81,7 @@ class Args: """seed of the experiment""" track: bool = False """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" + wandb_project_name: str = "tldr_summarize" """the wandb's project name""" wandb_entity: Optional[str] = None """the entity (team) of wandb's project""" @@ -410,7 +410,6 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): items["response0_policy"].append(data["response0_policy"][i]) items["response1_policy"].append(data["response1_policy"][i]) items["accuracy"].append(accuracy[i].item()) - breakpoint() reward_model.train() return pd.DataFrame(items) From ef6d2b1816ab5de87bbc10116aadd9f684ea6e19 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 28 Nov 2023 20:52:49 +0000 Subject: [PATCH 29/62] deal with >48 tokens in rm dataset --- lm_human_preference_details/tldr_dataset.py | 138 +++++++++++++++--- .../train_policy_pythia.sh | 10 +- .../train_reward_accelerate_summarize.py | 42 ++---- .../train_sft_accelerate_summarize.py | 26 +--- 4 files changed, 140 insertions(+), 76 deletions(-) diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index 5f8d8e3..bac3363 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -1,16 +1,31 @@ from dataclasses import dataclass -from typing import Dict, Optional, Union +import os +from typing import Dict, Optional from datasets import load_dataset from rich.pretty import pprint from transformers import AutoTokenizer import tyro - - +import multiprocessing +import matplotlib.pyplot as plt +import pandas as pd +from 
huggingface_hub import HfApi +api = HfApi() + + +""" +poetry run python lm_human_preference_details/tldr_dataset.py +poetry run python lm_human_preference_details/tldr_dataset.py \ + --base-model=EleutherAI/pythia-160m \ + --max-sft-response-length=53 \ + --max-rm-response-length=169 +""" @dataclass class Args: base_model: str = "gpt2" # EleutherAI/pythia-160m - max_response_length: int = 48 + max_sft_response_length: int = 48 # 53 + max_rm_response_length: int = 153 # 169 + hf_entity: str = None @dataclass @@ -21,7 +36,7 @@ class TaskQueryHParams: ] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily truncate_field: Optional[str] = "post" truncate_text: Optional[str] = "\n" - padding: Optional[Union[str, int]] = 50257 + padding: Optional[str] = " " # empty spaces pad_side: Optional[str] = "left" @@ -92,18 +107,23 @@ def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHPar if __name__ == "__main__": args = tyro.cli(Args) + if args.hf_entity is None: + args.hf_entity = api.whoami()["name"] + assert isinstance(args.hf_entity, str) tokenizer = AutoTokenizer.from_pretrained(args.base_model) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) oai_h = TaskQueryHParams() if isinstance(oai_h.padding, str): oai_h.padding = tokenizer.encode(oai_h.padding) else: - oai_h.padding = [oai_h.padding] + oai_h.padding = tokenizer.pad_token_id pprint(oai_h) - dataset = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered") + sft_ds = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered") def process_query_data(x): - # with an extra leading space to account for the space between the query and response + # the `x['summary']` in `vwxyzjn/summarize_from_feedback_tldr_3_filtered` + # DOES NOT HAVE a leading space so we are adding the leading space and + # `<|endoftext|>` token reference_response = f" {x['summary']}<|endoftext|>" return { **process_query(x, encoder=tokenizer, hparams=oai_h), @@ -111,20 +131,22 @@ def process_query_data(x): "reference_response_token": tokenizer.encode( reference_response, padding="max_length", - max_length=args.max_response_length, + max_length=args.max_sft_response_length, truncation=True, ), + "reference_response_token_len": len(tokenizer.encode(reference_response)), } - dataset = dataset.map(process_query_data, load_from_cache_file=False) - dataset.push_to_hub(f"vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_response_length}") + sft_ds = sft_ds.map(process_query_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) + sft_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}") - label = load_dataset("openai/summarize_from_feedback", "comparisons") + label_ds = load_dataset("openai/summarize_from_feedback", "comparisons") def process_response_data(x): - # with an extra leading space to account for the space between the query and response - response0 = f" {x['summaries'][0]['text']}<|endoftext|>" - response1 = f" {x['summaries'][1]['text']}<|endoftext|>" + # the `x['summaries'][0]['text']` in `openai/summarize_from_feedback` `comaprisons` + # DOES HAVE a leading space so we are just adding the `<|endoftext|>` token + response0 = f"{x['summaries'][0]['text']}<|endoftext|>" + response1 = f"{x['summaries'][1]['text']}<|endoftext|>" response0_policy = x["summaries"][0]["policy"] 
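# ---------------------------------------------------------------------------
# Illustrative sketch (not taken from the patch above): the comments in this
# hunk note that the SFT references need a leading space added by hand, while
# the `comparisons` summaries already carry one, so both end up shaped like
# " <summary><|endoftext|>" before being padded to a fixed length. The snippet
# below demonstrates that convention with the public `gpt2` tokenizer; the
# sample strings are invented, and max_length=48 mirrors the default
# `max_sft_response_length` above.
from transformers import AutoTokenizer

_tok = AutoTokenizer.from_pretrained("gpt2")
_tok.add_special_tokens({"pad_token": "[PAD]"})      # same trick as this script

_sft_summary = "A short TL;DR."                      # hypothetical x['summary'] (no leading space)
_rm_summary = " A short TL;DR."                      # hypothetical x['summaries'][i]['text'] (has one)
_sft_text = f" {_sft_summary}<|endoftext|>"          # add the space and the EOS for SFT data
_rm_text = f"{_rm_summary}<|endoftext|>"             # only append the EOS for comparison data

_ids = _tok.encode(_sft_text, padding="max_length", max_length=48, truncation=True)
print(len(_ids), _ids[-1] == _tok.pad_token_id)      # 48 True -> fixed length, right-padded
# ---------------------------------------------------------------------------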
response1_policy = x["summaries"][1]["policy"] policies = "--".join(sorted([response0_policy, response1_policy])) @@ -132,16 +154,94 @@ def process_response_data(x): **process_query(x["info"], encoder=tokenizer, hparams=oai_h), "response0": response0, "response0_token": tokenizer.encode( - response0, padding="max_length", max_length=args.max_response_length, truncation=True + response0, padding="max_length", max_length=args.max_rm_response_length, truncation=True ), + "response0_token_len": len(tokenizer.encode(response0)), "response1": response1, "response1_token": tokenizer.encode( - response1, padding="max_length", max_length=args.max_response_length, truncation=True + response1, padding="max_length", max_length=args.max_rm_response_length, truncation=True ), + "response1_token_len": len(tokenizer.encode(response1)), "response0_policy": response0_policy, "response1_policy": response1_policy, "policies": policies, } - label = label.map(process_response_data, load_from_cache_file=False) - label.push_to_hub(f"vwxyzjn/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_response_length}") + label_ds = label_ds.map(process_response_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) + label_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}") + + os.makedirs("dataset_visuals", exist_ok=True) + # visualize token length distribution + num_subplots = len(sft_ds) + len(label_ds) * 2 + print(f"{num_subplots=}") + fig, axs = plt.subplots(3, 3, figsize=(16, 16)) + axs = axs.flatten() + for i, key in enumerate(sft_ds.keys()): + df = sft_ds[key].to_pandas() + axs[i].hist(df["reference_response_token_len"], bins=100) + axs[i].set_title(f"{key} split: reference response token length\nmax_length={max(df['reference_response_token_len'])}") + offset = len(sft_ds) + for i, key in enumerate(label_ds.keys()): + df = label_ds[key].to_pandas() + axs[2*i + offset].hist(df["response0_token_len"], bins=100) + axs[2*i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") + axs[2*i + offset + 1].hist(df["response1_token_len"], bins=100) + axs[2*i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") + fig.suptitle(f"{args.base_model} Tokenizer: Token length distribution") + fig.tight_layout() + fig.savefig("dataset_visuals/token_len.png") + + # visualize confidence distribution + fig, axs = plt.subplots(len(label_ds), 1, figsize=(8, 8)) + axs = axs.flatten() + label_ds = label_ds.flatten() + for i, key in enumerate(label_ds.keys()): + df = label_ds[key].to_pandas() + axs[i].hist(df["extra.confidence"]) + axs[i].set_title(f"{key} split: confidence distribution") + fig.suptitle("Confidence distribution") + fig.tight_layout() + fig.savefig("dataset_visuals/confidence.png") + + # visualize policies used + fig, axs = plt.subplots(1, len(label_ds), figsize=(8, 12)) + axs = axs.flatten() + label_ds = label_ds.flatten() + for i, key in enumerate(label_ds.keys()): + df = label_ds[key].to_pandas() + cat = pd.concat([df["response0_policy"], df["response1_policy"]], axis=0) + cat.hist(ax=axs[i], xrot=90, orientation="horizontal") + axs[i].set_title(f"{key} split: policy distribution") + fig.suptitle("Policy distribution") + fig.tight_layout() + fig.savefig("dataset_visuals/policies.png") + + # visualize compairson distribution + fig, axs = plt.subplots(1, len(label_ds), figsize=(24, 
30)) + axs = axs.flatten() + label_ds = label_ds.flatten() + for i, key in enumerate(label_ds.keys()): + df = label_ds[key].to_pandas() + df["policies"].hist(ax=axs[i], xrot=90, orientation="horizontal") + axs[i].set_title(f"{key} split: policy comparison distribution") + fig.suptitle("Policy comparison distribution") + fig.tight_layout() + fig.savefig("dataset_visuals/policy_comparisons.png") + + # upload the `dataset_visuals` + + api.upload_folder( + folder_path="dataset_visuals", + path_in_repo="dataset_visuals", + repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}", + repo_type="dataset", + ) + # upload current file + print(f"{__file__=}") + api.upload_file( + path_or_fileobj=__file__, + path_in_repo="create_dataset.py", + repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}", + repo_type="dataset", + ) + diff --git a/lm_human_preference_details/train_policy_pythia.sh b/lm_human_preference_details/train_policy_pythia.sh index babd4fc..b0c43ab 100644 --- a/lm_human_preference_details/train_policy_pythia.sh +++ b/lm_human_preference_details/train_policy_pythia.sh @@ -6,6 +6,9 @@ fi if [ -z "$MODEL" ]; then MODEL=EleutherAI/pythia-1b-deduped fi +if [ -z "$LR" ]; then + LR=1.5e-5 +fi # SEED=3131 # MODEL=EleutherAI/pythia-1b-deduped REWARD_MODEL_PATH=models/$MODEL/reward_model_$SEED @@ -13,18 +16,19 @@ SFT_MODEL_PATH=models/$MODEL/sft_model_$SEED POLICY_MODEL_PATH=models/$MODEL/policy_model_$SEED poetry run accelerate launch --config_file deepspeed.yaml \ lm_human_preference_details/train_sft_accelerate_summarize.py \ - --task.query_dataset=vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_48 \ + --task.query_dataset=vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53 \ --base_model=$MODEL \ + --sft.lr=$LR \ --deepspeed \ --track \ - --upload_model \ --save_path=$SFT_MODEL_PATH \ --seed=$SEED \ poetry run accelerate launch --config_file deepspeed.yaml \ lm_human_preference_details/train_reward_accelerate_summarize.py \ - --label_dataset=vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_48 \ + --label_dataset=vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169 \ --base_model=$MODEL \ + --lr=$LR \ --no_normalize_before --no_normalize_after \ --local_batch_size=8 \ --gradient_accumulation_steps=8 \ diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index d23fb10..e096f8b 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -41,25 +41,6 @@ class LabelHParams: source: str = None -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # LM params - temperature: float = 0.7 - - # a patch @dataclass class TaskQueryHParams: @@ -92,13 +73,13 @@ class Args: load_from_cache_file: bool = False """Whether to load 
data from the local cache file in `dataset.map`""" - base_model: str = "gpt2" + base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" - label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing" + label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169" """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" local_batch_size: int = 8 """per rank batch size""" @@ -150,7 +131,6 @@ class Args: """Which scheduler to use""" warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" - task: TaskHParams = field(default_factory=TaskHParams) labels: LabelHParams = field(default_factory=LabelHParams) @@ -344,13 +324,13 @@ def forward(self, **kwargs): return reward -def left_padding_to_right_padding(tokens, pad_id): - """Convert from left padding to right padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[x for x in row if x != pad_id] + [pad_id] * (row == pad_id).sum() for row in tokens], - device=tokens.device, - ) +# def left_padding_to_right_padding(tokens, pad_id): +# """Convert from left padding to right padding.""" +# assert tokens.ndim == 2 +# return torch.tensor( +# [[x for x in row if x != pad_id] + [pad_id] * (row == pad_id).sum() for row in tokens], +# device=tokens.device, +# ) def ceil_div(a, b): @@ -391,7 +371,7 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): mb_best = data["choice"] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) - query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) + # query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.labels.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float() @@ -542,7 +522,7 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): mb_best = data["choice"] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) - query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) + # query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) with accelerator.accumulate(reward_model): predicted_reward = get_reward(reward_model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.labels.num_labels) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index e88729c..1f4e252 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -55,7 +55,7 @@ class SFTHParams: class TaskHParams: # Query params query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing" + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" query_format_str: Optional[str] = 
"SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" query_truncate_field: Optional[str] = "post" @@ -64,7 +64,7 @@ class TaskHParams: query_pad_side: Optional[str] = "left" # Response params - response_length: int = 48 + response_length: int = 53 # Truncate response after the first occurrence of this token at or after index after when sampling. truncate_token: int = 50256 # EOS token @@ -99,7 +99,7 @@ class Args: hf_entity: str = "" "the user or org name of the model repository from the Hugging Face Hub" - base_model: str = "gpt2" + base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) """Which layers to apply dropout to""" @@ -286,23 +286,6 @@ def step(self, closure=None): return loss -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - -def left_padding_to_right_padding(tokens, pad_id): - """Convert from left padding to right padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[x for x in row if x != pad_id] + [pad_id] * (row == pad_id).sum() for row in tokens], - device=tokens.device, - ) - - def ceil_div(a, b): return (a - 1) // b + 1 @@ -445,7 +428,6 @@ def forward(policy, query_responses, tokenizer): reference_responses = data["reference_response_token"].to(device, non_blocking=True) queries = data["query_token"].to(device, non_blocking=True) query_responses = torch.cat((queries, reference_responses), dim=1) - query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) with accelerator.accumulate(policy): output = forward(policy, query_responses, tokenizer) # mask out gradient effects on response padding tokens @@ -478,11 +460,9 @@ def forward(policy, query_responses, tokenizer): with torch.no_grad(): validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) validation_queries = validation_data["query_token"].to(device, non_blocking=True) - # validation_queries = right_padding_to_left_padding(validation_queries, tokenizer.pad_token_id) # not necessary validation_query_reference_responses = torch.cat( (validation_queries, validation_reference_responses), dim=1 ) - validation_query_reference_responses = left_padding_to_right_padding(validation_query_reference_responses, tokenizer.pad_token_id) validation_output = forward(policy, validation_query_reference_responses, tokenizer) validation_labels = validation_query_reference_responses.masked_fill( From 46e00ca3c8a03c1864b466a0ce9fa9160da65136 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 30 Nov 2023 22:52:48 +0000 Subject: [PATCH 30/62] cache a debugging 25 token generated script --- ...in_policy_accelerate_summarize_separate.py | 274 +++-- ...n_policy_accelerate_summarize_separate1.py | 1024 +++++++++++++++++ 2 files changed, 1215 insertions(+), 83 deletions(-) create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate1.py diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 63cc0b7..6849bc3 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ 
b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -59,7 +59,7 @@ class PpoHParams: local_mini_batch_size: tyro.conf.Suppress[int] = None batch_size: tyro.conf.Suppress[int] = None mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 + gradient_accumulation_steps: int = 64 """gradient accumulation steps""" local_micro_batch_size: tyro.conf.Suppress[int] = None """per rank micro batch size""" @@ -83,7 +83,7 @@ class PpoHParams: class TaskHParams: # Query params query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing" + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" query_truncate_field: Optional[str] = "post" @@ -92,10 +92,11 @@ class TaskHParams: query_pad_side: Optional[str] = "left" # Response params - response_length: int = 48 + response_length: int = 53 # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token + truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None truncate_after: int = 16 penalty_reward_value: int = -1 @@ -124,7 +125,7 @@ class Args: """seed of the experiment""" track: bool = False """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" + wandb_project_name: str = "tldr_summarize" """the wandb's project name""" wandb_entity: Optional[str] = None """the entity (team) of wandb's project""" @@ -139,7 +140,7 @@ class Args: hf_entity: str = "" "the user or org name of the model repository from the Hugging Face Hub" - base_model: str = "gpt2" + base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) """Which layers to apply dropout to""" @@ -365,10 +366,8 @@ def __init__(self, lm_backbone): def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - latents = output.hidden_states[-1] # shape: [batch_size, length, hidden_size] - scalars = self.scalar_head(latents).squeeze(-1) # shape: [batch_size, length] - last_scalar = scalars[:, -1] # shape: [batch_size, 1] - return scalars, last_scalar + reward = self.scalar_head(output.hidden_states[-1]) + return reward # taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 @@ -383,17 +382,6 @@ def forward(self, **kwargs): return self.policy(**kwargs), self.critic(**kwargs) -def shift_pad_id_left(data, pad_id): - # Step 1: Create a boolean mask - mask = (data == pad_id).long() - # Step 3: Use argsort on the inverted boolean mask to get sorted indices - sorted_indices = torch.argsort(~mask, axis=1) - # Step 4: Use advanced indexing to rearrange the elements - rows_range = torch.arange(data.shape[0], device=data.device) - shifted_data = data[rows_range[:, None], sorted_indices] - return shifted_data - - def ceil_div(a, b): return (a - 1) // b + 1 @@ -420,52 +408,64 @@ def generate(lm_backbone, queries, tokenizer, generation_config): return torch.cat((queries, output.sequences[:, context_length:]), dim=1) +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the 
position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def masked_mean(x, mask): + return (x.sum(-1) / (~mask).sum(-1)).mean() + + def get_reward(reward_model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return reward_model( + reward_logits = reward_model( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, + # position_ids=position_ids, return_dict=True, output_hidden_states=True, ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # sequence_lengths1 = ( + # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + # query_responses.device + # ) + # print(f"======={sequence_lengths1=} {sequence_lengths=}") + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths def forward(policy, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return policy( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, + # position_ids=position_ids, return_dict=True, output_hidden_states=True, ) -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
- """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def truncate_response(args, tokenizer, responses): - trunc_idxs = first_true_indices(responses == args.task.truncate_token).unsqueeze(-1) - new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] - idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) - postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) - return postprocessed_responses - - # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) @@ -482,6 +482,15 @@ def truncate_response(args, tokenizer, responses): # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -513,13 +522,7 @@ def truncate_response(args, tokenizer, responses): np.random.seed(local_seed) torch.manual_seed(local_seed) torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + model_config = AutoConfig.from_pretrained(args.base_model) configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout if accelerator.is_main_process: @@ -605,15 +608,15 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well sample_validation = validation_dataset[local_sample_validation_inds] sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) with torch.no_grad(): - sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) sample_validation_query_reference_responses = torch.cat( (sample_validation_queries, sample_validation_reference_response), dim=1 ) - sample_validation_query_reference_responses = shift_pad_id_left( - sample_validation_query_reference_responses, tokenizer.pad_token_id - ) - _, sample_validation_reference_scores = get_reward( + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + _, sample_validation_reference_scores, _ = get_reward( reward_model, sample_validation_query_reference_responses, tokenizer ) @@ -640,6 +643,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_clipfrac_stats = torch.zeros(stats_shape, device=device) entropy_stats = torch.zeros(stats_shape, device=device) ratio_stats = torch.zeros(stats_shape, device=device) + model.train() for update in range(1, 
args.ppo.num_updates + 1): global_step += 1 * args.ppo.batch_size @@ -650,15 +654,70 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well with torch.no_grad(): queries = data["query_token"].to(device) reference_responses = data["reference_response_token"].to(device) - queries = shift_pad_id_left(queries, tokenizer.pad_token_id) + # queries = shift_pad_id_left(queries, tokenizer.pad_token_id) query_reference_responses = torch.cat((queries, reference_responses), dim=1) - query_reference_responses = shift_pad_id_left(query_reference_responses, tokenizer.pad_token_id) + # query_reference_responses = shift_pad_id_left(query_reference_responses, tokenizer.pad_token_id) query_responses = generate( accelerator.unwrap_model(model).policy, queries, tokenizer, generation_config, ) + if args.task.response_length != 53: + query_responses = torch.tensor([[ 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 6971, 7941, 1703, 37, + 1433, 27, 391, 16, 22842, 16458, 187, 187, 53, 43561, + 27, 3189, 544, 1348, 278, 62, 5816, 619, 806, 385, + 544, 1797, 269, 62, 846, 608, 2607, 273, 2740, 598, + 15, 187, 187, 15743, 27, 24387, 39714, 187, 6300, 15950, + 436, 1501, 562, 627, 816, 281, 1339, 352, 562, 273, + 619, 985, 15, 187, 2598, 309, 452, 644, 13597, 436, + 3226, 313, 2577, 806, 19609, 15, 309, 369, 617, 806, + 10, 323, 495, 1107, 15, 844, 574, 271, 13103, 673, + 285, 4536, 7227, 35267, 285, 37616, 15, 496, 253, 990, + 13, 352, 1904, 626, 789, 562, 15, 187, 42, 3260, + 309, 7636, 617, 285, 703, 7636, 479, 533, 1841, 816, + 1904, 626, 789, 562, 1955, 281, 1097, 4858, 4606, 15, + 187, 187, 2598, 352, 556, 644, 2761, 608, 2607, 15, + 309, 1694, 689, 253, 31056, 673, 273, 619, 1495, 534, + 369, 1501, 2740, 598, 273, 806, 374, 2607, 15, 209, + 187, 4125, 846, 608, 2607, 13, 309, 816, 2985, 617, + 15, 187, 42, 5476, 627, 11210, 626, 644, 247, 2014, + 835, 309, 6468, 626, 1869, 670, 617, 2568, 15, 23385, + 50276, 187, 42, 871, 309, 10095, 626, 3057, 617, 285, + 309, 1353, 3965, 2119, 703, 1912, 626, 3057, 479, 2057, + 534, 310, 323, 253, 1805, 15, 187, 1231, 6468, 626, + 13452, 323, 5046, 374, 2607, 32, 1633, 751, 326, 15, + 187, 43688, 13, 309, 816, 4571, 626, 6016, 352, 10542, + 285, 3261, 387, 776, 7963, 327, 619, 17899, 7963, 534, + 309, 1620, 755, 327, 15, 187, 1147, 369, 5322, 281, + 923, 617, 2454, 969, 285, 30774, 336, 253, 1711, 1897, + 15, 187, 1147, 369, 5322, 281, 923, 253, 9097, 359, + 1097, 2389, 1024, 3811, 342, 617, 2021, 15, 187, 34937, + 512, 608, 2607, 13, 619, 5249, 5055, 598, 15, 309, + 1694, 247, 14892, 209, 187, 36421, 598, 247, 2257, 273, + 2583, 285, 858, 1841, 1475, 253, 2419, 309, 6468, 626, + 644, 2104, 281, 3966, 3966, 15, 187, 1989, 309, 816, + 2985, 617, 15, 187, 42, 871, 703, 434, 2509, 973, + 13, 3164, 1805, 685, 1078, 15, 187, 2513, 352, 816, + 479, 32, 209, 187, 
25954, 6701, 323, 634, 673, 4361, + 436, 285, 11435, 634, 5701, 15, 187, 187, 14135, 28, + 4976, 27, 6365, 619, 806, 19609, 13, 9377, 598, 13, + 309, 2985, 617, 533, 1053, 626, 3057, 617, 285, 12371, + 604, 352, 434, 816, 479, 15, 0,]], device=device) context_length = queries.shape[1] responses = query_responses[:, context_length:] @@ -674,9 +733,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well postprocessed_sample_validation_query_responses = torch.cat( (sample_validation_queries, postprocessed_sample_validation_responses), 1 ) - postprocessed_sample_validation_query_responses = shift_pad_id_left( - postprocessed_sample_validation_query_responses, tokenizer.pad_token_id - ) + # postprocessed_sample_validation_query_responses = shift_pad_id_left( + # postprocessed_sample_validation_query_responses, tokenizer.pad_token_id + # ) torch.cuda.empty_cache() # TODO: do I do this with query response or post-processed query response? @@ -702,34 +761,47 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # 2. run reward model on the truncated responses postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = shift_pad_id_left(postprocessed_query_responses, tokenizer.pad_token_id) - full_values, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) + # postprocessed_query_responses = shift_pad_id_left(postprocessed_query_responses, tokenizer.pad_token_id) + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) padding_mask = postprocessed_responses == tokenizer.pad_token_id logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) values = torch.masked_fill(values, padding_mask, 0) - rew, scores = get_reward(reward_model, postprocessed_query_responses, tokenizer) + rew, scores, sequence_lengths = get_reward(reward_model, postprocessed_query_responses, tokenizer) - _, reference_scores = get_reward(reward_model, query_reference_responses, tokenizer) - _, validation_score = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + _, reference_scores, _ = get_reward(reward_model, query_reference_responses, tokenizer) + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) # carperAI-style score normaliation scores = scores - reference_scores - # 3. filter response. Ensure that the sample contains truncate_token + # 3. filter response. Ensure that the sample contains truncate_token_id # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) - scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + + + # TODO: reverse it back + # scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + + + torch.cuda.empty_cache() # 4. 
compute rewards kl = logprobs - ref_logprobs non_score_reward = -kl_ctl.value * kl rewards = non_score_reward.clone() - rewards[:, -1] += scores + # print(f"{sequence_lengths=}") + # breakpoint() + # rewards[:, -1] += scores + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = sequence_lengths - context_length + rewards[[actual_start, actual_end]] += scores # 5. whiten rewards if args.ppo.whiten_rewards: @@ -738,7 +810,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) - all_sample_validation_responses = tokenizer.batch_decode(postprocessed_sample_validation_responses) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( postprocessed_sample_validation_query_responses, skip_special_tokens=True ) @@ -759,7 +831,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ) if accelerator.is_main_process and args.track: wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) - print_rich_table("stuff", all_sample_validation_df[:4], console) + # print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: print(e) @@ -783,7 +855,17 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well advantages_reversed.append(lastgaelam) advantages = torch.stack(advantages_reversed[::-1], axis=1) returns = advantages + values - advantages = whiten(advantages) + + + + + # TODO: reverse it back + # advantages = whiten(advantages) + + + + + return_mean, return_var = returns.mean(), returns.var() value_mean, value_var = values.mean(), values.var() @@ -806,15 +888,14 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well mb_query_responses = query_responses[micro_batch_inds] mb_logprobs = logprobs[micro_batch_inds] - output, (vpred_temp, _) = forward(model, mb_query_responses, tokenizer) + output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) - vpred = vpred_temp[:, context_length - 1 : -1] vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) - mb_return = torch.masked_fill(mb_return, padding_mask[micro_batch_inds], 0) # should not have a gradient effect vpredclipped = torch.clamp( vpred, mb_values - args.ppo.cliprange_value, @@ -822,28 +903,55 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ) vf_losses1 = torch.square(vpred - mb_return) vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + vf_loss_max = torch.max(vf_losses1, vf_losses2) + + + # vf_loss = 0.5 * vf_loss_max.mean() + vf_loss = 0.5 * masked_mean(vf_loss_max, padding_mask[micro_batch_inds]) + + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), padding_mask[micro_batch_inds]) logprobs_diff 
= new_logprobs - mb_logprobs ratio = torch.exp(logprobs_diff) pg_losses = -mb_advantage * ratio pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + pg_loss_max = torch.max(pg_losses, pg_losses2) + # pg_loss = pg_loss_max.mean() + pg_loss = masked_mean(pg_loss_max, padding_mask[micro_batch_inds]) + pg_clipfrac = masked_mean((pg_losses2 > pg_losses).float(), padding_mask[micro_batch_inds]) loss = pg_loss + args.ppo.vf_coef * vf_loss accelerator.backward(loss) optimizer.step() optimizer.zero_grad() + # TODO: entropy does not handle padding tokens properly yet prob_dist = torch.nn.functional.softmax(logits, dim=-1) entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() + # approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl = 0.5 * masked_mean((logprobs_diff**2), padding_mask[micro_batch_inds]) + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # pprint({ + # "responses": responses, + # "values": values, + # "rewards": rewards, + # "scores": scores, + # "advantages": advantages, + # "ratio": ratio, + # "pg_losses": pg_losses, + # "approxkl": approxkl, + # "pg_loss": pg_loss, + # "pg_clipfrac": pg_clipfrac, + # "ratio": ratio.mean(), + # "vf_loss": vf_loss, + # "vf_clipfrac": vf_clipfrac, + # "entropy": masked_mean(entropy, padding_mask[micro_batch_inds]), + # }) with torch.no_grad(): approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = masked_mean(entropy, padding_mask[micro_batch_inds]) ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py new file mode 100644 index 0000000..927c6bc --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py @@ -0,0 +1,1024 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + 
AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, +) + + +INVALID_LOGPROB = 1.0 + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + kl_coef: float = 0.15 + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 64 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/ppo_policy" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize 
else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, 
shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + self.scalar_head = layer_init( + nn.Linear(lm_backbone.config.hidden_size, 1), + std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + ) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) + # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) + return reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
+ generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = ( + torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + query_responses.device + ) + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def masked_mean(x, mask): + return (x.sum(-1) / (~mask).sum(-1)).mean() + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + 
trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + if args.rewards.trained_model: + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + 
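For reference, a minimal standalone sketch (not part of the patch; toy token ids assumed) of the EOS-based truncation that `first_true_indices` / `truncate_response` above implement: everything after the first truncate token is overwritten with the pad token, and a row that never emits the truncate token is left untouched.

import torch

def first_true_indices(bools, dtype=torch.long):
    # index of the first True per row; returns the row length if a row has no True
    row_len = bools.size(-1)
    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
    return torch.min(zero_or_index, dim=-1).values

truncate_token_id, pad_token_id = 50256, 50257  # hypothetical ids, for illustration only
responses = torch.tensor([[11, 12, truncate_token_id, 13, 14],
                          [21, 22, 23, 24, 25]])  # second row has no truncate token
trunc_idxs = first_true_indices(responses == truncate_token_id).unsqueeze(-1)
idxs = torch.arange(responses.size(-1)).view(1, -1)
print(torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id))
# tensor([[   11,    12, 50256, 50257, 50257],
#         [   21,    22,    23,    24,    25]])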
validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "steps_per_print": 10, + # "zero_optimization": { + # "stage": stage, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": off_load_device + # } + # }, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) + with torch.no_grad(): + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + _, sample_validation_reference_scores, _ = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + + model.train() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + reference_responses = data["reference_response_token"].to(device) + # queries = shift_pad_id_left(queries, tokenizer.pad_token_id) + query_reference_responses = torch.cat((queries, reference_responses), dim=1) + # query_reference_responses = shift_pad_id_left(query_reference_responses, tokenizer.pad_token_id) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + if args.task.response_length != 53: + query_responses = torch.tensor([[ 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 6971, 7941, 1703, 37, + 1433, 27, 391, 16, 22842, 16458, 187, 187, 53, 43561, + 27, 3189, 544, 1348, 278, 62, 5816, 619, 806, 385, + 544, 1797, 269, 62, 846, 608, 2607, 273, 2740, 598, + 15, 187, 187, 15743, 27, 24387, 39714, 187, 6300, 15950, + 436, 1501, 562, 627, 816, 281, 1339, 352, 562, 273, + 619, 985, 15, 187, 2598, 309, 452, 644, 13597, 436, + 3226, 313, 2577, 806, 19609, 15, 309, 369, 617, 806, + 10, 323, 495, 1107, 15, 844, 574, 271, 13103, 673, + 285, 4536, 7227, 35267, 285, 37616, 15, 496, 253, 990, + 13, 352, 1904, 626, 789, 562, 15, 187, 42, 3260, + 309, 7636, 617, 285, 703, 7636, 479, 533, 1841, 816, + 1904, 626, 789, 562, 1955, 281, 1097, 4858, 4606, 15, + 187, 187, 2598, 352, 556, 644, 2761, 608, 2607, 15, + 309, 1694, 689, 253, 31056, 673, 273, 619, 1495, 534, + 369, 1501, 2740, 598, 273, 806, 374, 2607, 15, 209, + 187, 4125, 846, 608, 2607, 13, 309, 816, 2985, 617, + 15, 187, 42, 5476, 627, 11210, 626, 644, 247, 2014, + 835, 309, 6468, 626, 1869, 670, 617, 2568, 15, 23385, + 50276, 187, 42, 871, 309, 10095, 626, 3057, 617, 285, + 309, 1353, 3965, 2119, 
703, 1912, 626, 3057, 479, 2057, + 534, 310, 323, 253, 1805, 15, 187, 1231, 6468, 626, + 13452, 323, 5046, 374, 2607, 32, 1633, 751, 326, 15, + 187, 43688, 13, 309, 816, 4571, 626, 6016, 352, 10542, + 285, 3261, 387, 776, 7963, 327, 619, 17899, 7963, 534, + 309, 1620, 755, 327, 15, 187, 1147, 369, 5322, 281, + 923, 617, 2454, 969, 285, 30774, 336, 253, 1711, 1897, + 15, 187, 1147, 369, 5322, 281, 923, 253, 9097, 359, + 1097, 2389, 1024, 3811, 342, 617, 2021, 15, 187, 34937, + 512, 608, 2607, 13, 619, 5249, 5055, 598, 15, 309, + 1694, 247, 14892, 209, 187, 36421, 598, 247, 2257, 273, + 2583, 285, 858, 1841, 1475, 253, 2419, 309, 6468, 626, + 644, 2104, 281, 3966, 3966, 15, 187, 1989, 309, 816, + 2985, 617, 15, 187, 42, 871, 703, 434, 2509, 973, + 13, 3164, 1805, 685, 1078, 15, 187, 2513, 352, 816, + 479, 32, 209, 187, 25954, 6701, 323, 634, 673, 4361, + 436, 285, 11435, 634, 5701, 15, 187, 187, 14135, 28, + 4976, 27, 6365, 619, 806, 19609, 13, 9377, 598, 13, + 309, 2985, 617, 533, 1053, 626, 3057, 617, 285, 12371, + 604, 352, 434, 816, 479, 15, 0,]], device=device) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + # postprocessed_sample_validation_query_responses = shift_pad_id_left( + # postprocessed_sample_validation_query_responses, tokenizer.pad_token_id + # ) + torch.cuda.empty_cache() + + # TODO: do I do this with query response or post-processed query response? + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + 1e-7 + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + 1e-7 + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + # postprocessed_query_responses = shift_pad_id_left(postprocessed_query_responses, tokenizer.pad_token_id) + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + padding_mask = postprocessed_responses == tokenizer.pad_token_id + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + values = torch.masked_fill(values, padding_mask, 0) + + rew, scores, sequence_lengths = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, reference_scores, _ = get_reward(reward_model, query_reference_responses, tokenizer) + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # carperAI-style score normaliation + scores = scores - reference_scores + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + + + + # TODO: reverse it back + # scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + + + + torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses, skip_special_tokens=True + ) + all_sample_validation_postprocessed_responses = [ + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + # print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + del postprocessed_query_responses + torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + + + + + # TODO: reverse it back + # advantages = whiten(advantages) + + + + + + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + 1e-7 + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + + + vf_loss = 0.5 * vf_loss_max.mean() + # vf_loss = 0.5 * masked_mean(vf_loss_max, padding_mask[micro_batch_inds]) + + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), padding_mask[micro_batch_inds]) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = pg_loss_max.mean() + # pg_loss = masked_mean(pg_loss_max, padding_mask[micro_batch_inds]) + pg_clipfrac = masked_mean((pg_losses2 > pg_losses).float(), padding_mask[micro_batch_inds]) + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # approxkl = 0.5 * masked_mean((logprobs_diff**2), 
padding_mask[micro_batch_inds]) + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + pprint({ + "responses": responses, + "values": values, + "rewards": rewards, + "scores": scores, + "advantages": advantages, + "ratio": ratio, + "pg_losses": pg_losses, + "approxkl": approxkl, + "pg_loss": pg_loss, + "pg_clipfrac": pg_clipfrac, + "ratio": ratio.mean(), + "vf_loss": vf_loss, + "vf_clipfrac": vf_clipfrac, + "entropy": entropy.mean(), + }) + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + raise + # minibatch_idx += 1 + # if accelerator.is_main_process: + # console.print( + # f"ppo_epoch_idx", + # ppo_epoch_idx, + # "approxkl", + # approxkl_stats[:ppo_epoch_idx+1].mean().item(), + # "pg_loss", + # pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + # "pg_clipfrac", + # pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + # "ratio", + # ratio_stats[:ppo_epoch_idx+1].mean().item(), + # ) + + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + 
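As a sanity check on the clipped surrogate computed above (`pg_losses`, `pg_losses2`, `torch.max`), here is a standalone toy example, not part of the patch, with a cliprange of 0.2 assumed: the element-wise max caps how much the objective can improve once the ratio leaves the trust region, and `pg_clipfrac` reports the fraction of samples where clipping was active.

import torch

cliprange = 0.2  # assumed value for illustration
ratio = torch.tensor([1.5, 0.7, 1.0])        # stands in for torch.exp(new_logprobs - mb_logprobs)
advantage = torch.tensor([1.0, -1.0, 1.0])   # made-up advantages
pg_losses = -advantage * ratio
pg_losses2 = -advantage * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)   # pessimistic (clipped) per-sample loss
print(pg_loss_max)                               # tensor([-1.2000,  0.8000, -1.0000])
print((pg_losses2 > pg_losses).float().mean())   # clip fraction: tensor(0.6667)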
writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") + + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) From c4ebe5ecb25a3ae9ad089b7836692c63dfef88df Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 3 Dec 2023 15:01:02 +0000 Subject: [PATCH 31/62] seems successful! 
--- ...in_policy_accelerate_summarize_separate.py | 216 +++++++++--------- 1 file changed, 104 insertions(+), 112 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 6849bc3..acdcdd7 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -353,6 +353,57 @@ def whiten(values, shift_mean=True): return whitened +def masked_mean(x, mask): + return (x.sum(-1) / (~mask).sum(-1)).mean() + +def masked_var(x, mask): + return (x**2).sum(-1) / (~mask).sum(-1) - masked_mean(x, mask)**2 + + +def masked_whiten(values, mask, shift_mean=True): + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def masked_mean(values, mask, axis=None): + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + +def masked_var(values, mask, unbiased=True): + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask, False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + class AutoModelForCausalLMWithRewardHead(nn.Module): def __init__(self, lm_backbone): super().__init__() @@ -428,10 +479,6 @@ def truncate_response(args, tokenizer, responses): return postprocessed_responses -def masked_mean(x, mask): - return (x.sum(-1) / (~mask).sum(-1)).mean() - - def get_reward(reward_model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum @@ -643,8 +690,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_clipfrac_stats = torch.zeros(stats_shape, device=device) entropy_stats = torch.zeros(stats_shape, device=device) ratio_stats = torch.zeros(stats_shape, device=device) - - model.train() + model.eval() for update in range(1, args.ppo.num_updates + 1): global_step += 1 * args.ppo.batch_size frac = 1.0 - (update - 1.0) / args.ppo.num_updates @@ -663,61 +709,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well tokenizer, generation_config, ) - if args.task.response_length != 53: - query_responses = torch.tensor([[ 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 
209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 6971, 7941, 1703, 37, - 1433, 27, 391, 16, 22842, 16458, 187, 187, 53, 43561, - 27, 3189, 544, 1348, 278, 62, 5816, 619, 806, 385, - 544, 1797, 269, 62, 846, 608, 2607, 273, 2740, 598, - 15, 187, 187, 15743, 27, 24387, 39714, 187, 6300, 15950, - 436, 1501, 562, 627, 816, 281, 1339, 352, 562, 273, - 619, 985, 15, 187, 2598, 309, 452, 644, 13597, 436, - 3226, 313, 2577, 806, 19609, 15, 309, 369, 617, 806, - 10, 323, 495, 1107, 15, 844, 574, 271, 13103, 673, - 285, 4536, 7227, 35267, 285, 37616, 15, 496, 253, 990, - 13, 352, 1904, 626, 789, 562, 15, 187, 42, 3260, - 309, 7636, 617, 285, 703, 7636, 479, 533, 1841, 816, - 1904, 626, 789, 562, 1955, 281, 1097, 4858, 4606, 15, - 187, 187, 2598, 352, 556, 644, 2761, 608, 2607, 15, - 309, 1694, 689, 253, 31056, 673, 273, 619, 1495, 534, - 369, 1501, 2740, 598, 273, 806, 374, 2607, 15, 209, - 187, 4125, 846, 608, 2607, 13, 309, 816, 2985, 617, - 15, 187, 42, 5476, 627, 11210, 626, 644, 247, 2014, - 835, 309, 6468, 626, 1869, 670, 617, 2568, 15, 23385, - 50276, 187, 42, 871, 309, 10095, 626, 3057, 617, 285, - 309, 1353, 3965, 2119, 703, 1912, 626, 3057, 479, 2057, - 534, 310, 323, 253, 1805, 15, 187, 1231, 6468, 626, - 13452, 323, 5046, 374, 2607, 32, 1633, 751, 326, 15, - 187, 43688, 13, 309, 816, 4571, 626, 6016, 352, 10542, - 285, 3261, 387, 776, 7963, 327, 619, 17899, 7963, 534, - 309, 1620, 755, 327, 15, 187, 1147, 369, 5322, 281, - 923, 617, 2454, 969, 285, 30774, 336, 253, 1711, 1897, - 15, 187, 1147, 369, 5322, 281, 923, 253, 9097, 359, - 1097, 2389, 1024, 3811, 342, 617, 2021, 15, 187, 34937, - 512, 608, 2607, 13, 619, 5249, 5055, 598, 15, 309, - 1694, 247, 14892, 209, 187, 36421, 598, 247, 2257, 273, - 2583, 285, 858, 1841, 1475, 253, 2419, 309, 6468, 626, - 644, 2104, 281, 3966, 3966, 15, 187, 1989, 309, 816, - 2985, 617, 15, 187, 42, 871, 703, 434, 2509, 973, - 13, 3164, 1805, 685, 1078, 15, 187, 2513, 352, 816, - 479, 32, 209, 187, 25954, 6701, 323, 634, 673, 4361, - 436, 285, 11435, 634, 5701, 15, 187, 187, 14135, 28, - 4976, 27, 6365, 619, 806, 19609, 13, 9377, 598, 13, - 309, 2985, 617, 533, 1053, 626, 3057, 617, 285, 12371, - 604, 352, 434, 816, 479, 15, 0,]], device=device) context_length = queries.shape[1] responses = query_responses[:, context_length:] @@ -781,31 +772,26 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) - - - - # TODO: reverse it back - # scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - - - - + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") torch.cuda.empty_cache() # 4. 
compute rewards kl = logprobs - ref_logprobs + kl = torch.masked_fill(kl, padding_mask, 0) non_score_reward = -kl_ctl.value * kl rewards = non_score_reward.clone() - # print(f"{sequence_lengths=}") - # breakpoint() - # rewards[:, -1] += scores actual_start = torch.arange(rewards.size(0), device=rewards.device) actual_end = sequence_lengths - context_length rewards[[actual_start, actual_end]] += scores + mean_kl = kl.sum(1).mean() + accelerator.print(f"{mean_kl=}, {(logprobs - ref_logprobs).sum(1).mean()}") + # if update == 2: raise # 5. whiten rewards if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) + rewards = masked_whiten(rewards, mask=~padding_mask, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask, 0) if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: try: @@ -829,9 +815,11 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), } ) - if accelerator.is_main_process and args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) - # print_rich_table("stuff", all_sample_validation_df[:4], console) + if accelerator.is_main_process: + all_sample_validation_df.to_json(f"runs/{run_name}/table.json") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: print(e) @@ -856,19 +844,22 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well advantages = torch.stack(advantages_reversed[::-1], axis=1) returns = advantages + values - - - - # TODO: reverse it back - # advantages = whiten(advantages) - - - - + # TODO: reversed back + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) return_mean, return_var = returns.mean(), returns.var() value_mean, value_var = values.mean(), values.var() - + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + # pprint({ + # "rewards": rewards, + # "returns": returns, + # "advantages": advantages, + # }) + # breakpoint() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for ppo_epoch_idx in range(args.ppo.noptepochs): b_inds = np.random.permutation(args.ppo.local_batch_size) @@ -893,8 +884,8 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) vpredclipped = torch.clamp( vpred, @@ -907,17 +898,16 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # vf_loss = 0.5 * vf_loss_max.mean() - vf_loss = 0.5 * masked_mean(vf_loss_max, padding_mask[micro_batch_inds]) - - vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), padding_mask[micro_batch_inds]) + vf_loss = 0.5 * 
masked_mean(vf_loss_max, ~padding_mask[micro_batch_inds]) + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~padding_mask[micro_batch_inds]) logprobs_diff = new_logprobs - mb_logprobs ratio = torch.exp(logprobs_diff) pg_losses = -mb_advantage * ratio pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) pg_loss_max = torch.max(pg_losses, pg_losses2) # pg_loss = pg_loss_max.mean() - pg_loss = masked_mean(pg_loss_max, padding_mask[micro_batch_inds]) - pg_clipfrac = masked_mean((pg_losses2 > pg_losses).float(), padding_mask[micro_batch_inds]) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + pg_clipfrac = masked_mean((pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]) loss = pg_loss + args.ppo.vf_coef * vf_loss accelerator.backward(loss) optimizer.step() @@ -926,25 +916,27 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well prob_dist = torch.nn.functional.softmax(logits, dim=-1) entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) # approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl = 0.5 * masked_mean((logprobs_diff**2), padding_mask[micro_batch_inds]) + approxkl = 0.5 * masked_mean((logprobs_diff**2), ~padding_mask[micro_batch_inds]) # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # pprint({ - # "responses": responses, - # "values": values, - # "rewards": rewards, - # "scores": scores, - # "advantages": advantages, - # "ratio": ratio, - # "pg_losses": pg_losses, - # "approxkl": approxkl, - # "pg_loss": pg_loss, - # "pg_clipfrac": pg_clipfrac, - # "ratio": ratio.mean(), - # "vf_loss": vf_loss, - # "vf_clipfrac": vf_clipfrac, - # "entropy": masked_mean(entropy, padding_mask[micro_batch_inds]), - # }) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() with torch.no_grad(): approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac @@ -954,6 +946,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = masked_mean(entropy, padding_mask[micro_batch_inds]) ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 + minibatch_idx += 1 if accelerator.is_main_process: console.print( @@ -968,11 +961,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well "ratio", ratio_stats[:ppo_epoch_idx+1].mean().item(), ) - + # breakpoint() with torch.no_grad(): if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs mean_kl = kl.sum(1).mean() mean_entropy = (-logprobs).sum(1).mean() mean_non_score_reward = non_score_reward.sum(1).mean() From 11ed54610eef38e4fcb0b2048d77b048b6ae4c29 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: 
Tue, 5 Dec 2023 16:54:26 +0000 Subject: [PATCH 32/62] seems to work ok with 1B models --- ...in_policy_accelerate_summarize_separate.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index acdcdd7..085833b 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -595,6 +595,10 @@ def forward(policy, query_responses, tokenizer): # each class should have a separate pretrained model that do not share weights ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + # policy.gradient_checkpointing_enable() + # accelerator.print(policy) + # critic.lm_backbone.gradient_checkpointing_enable() + # accelerator.print(critic) if args.sft_model_path: policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) @@ -624,20 +628,23 @@ def forward(policy, query_responses, tokenizer): deepspeed_states = AcceleratorState().deepspeed_plugin # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + + offload = False eval_ds_config = { "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - # "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, "bf16": {"enabled": True}, "prescale_gradients": False, "wall_clock_breakdown": False, } + if offload: + eval_ds_config["zero_optimization"] = { + "stage": 3, + "stage3_param_persistence_threshold": 1e4, + "offload_param": { + "device": "cpu" + } + } + accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) reward_model.eval() ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) @@ -754,7 +761,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) # postprocessed_query_responses = shift_pad_id_left(postprocessed_query_responses, tokenizer.pad_token_id) full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, postprocessed_query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) + values = full_values[:, context_length:].squeeze(-1) padding_mask = postprocessed_responses == tokenizer.pad_token_id logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) @@ -764,6 +771,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well _, reference_scores, _ = get_reward(reward_model, query_reference_responses, tokenizer) _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + # raise # carperAI-style score normaliation scores = scores - reference_scores @@ -854,6 +862,7 @@ def repeat_generator(): # TODO: ideally we 
shuffle the dataloader as well writer.add_histogram("advantages", advantages[0].float(), global_step) accelerator.print("rewards====", rewards[0]) accelerator.print("advantages====", advantages[0]) + # raise # pprint({ # "rewards": rewards, # "returns": returns, @@ -885,7 +894,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = vpred_temp[:, context_length:].squeeze(-1) vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) vpredclipped = torch.clamp( vpred, From 277fd53d5a06d54bafb985e6825e50be0eb5a87c Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 6 Dec 2023 20:36:33 +0000 Subject: [PATCH 33/62] minor change --- .../train_policy_accelerate_summarize_separate.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 085833b..0e28c49 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -574,14 +574,14 @@ def forward(policy, query_responses, tokenizer): configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout if accelerator.is_main_process: pprint(model_config) - reward_model = AutoModelForCausalLMWithRewardHead( + critic = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained( args.base_model, config=model_config, trust_remote_code=True, ) ) - critic = AutoModelForCausalLMWithRewardHead( + reward_model = AutoModelForCausalLMWithRewardHead( AutoModelForCausalLM.from_pretrained( args.base_model, config=model_config, @@ -589,8 +589,8 @@ def forward(policy, query_responses, tokenizer): ) ) if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) print(f"loaded pretrained reward model from {args.rewards.trained_model}") # each class should have a separate pretrained model that do not share weights ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) @@ -852,7 +852,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well advantages = torch.stack(advantages_reversed[::-1], axis=1) returns = advantages + values - # TODO: reversed back advantages = masked_whiten(advantages, ~padding_mask) advantages = torch.masked_fill(advantages, padding_mask, 0) From 37a29633cad57ed2444878c61a693ef474d70542 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 8 Dec 2023 14:52:33 +0000 Subject: [PATCH 34/62] remove files --- benchmark/summarize.slurm_template | 18 - .../summarization/minimal_rm copy.py | 28 - .../summarization/minimal_rm.py | 45 - .../summarization/minisft.py | 289 ----- .../train_policy_accelerate copy 2.py | 836 ------------- .../train_policy_accelerate copy.py | 945 --------------- .../train_policy_accelerate_new.py | 952 --------------- .../train_policy_accelerate_old.py | 922 --------------- 
...in_policy_accelerate_summarize_ref_diff.py | 894 -------------- .../train_reward_accelerate copy.py | 736 ------------ .../train_reward_accelerate_debug copy.py | 542 --------- .../train_reward_accelerate_debug.py | 561 --------- ...train_reward_accelerate_summarize_debug.py | 981 ---------------- .../train_reward_accelerate_summarized.py | 785 ------------- .../train_reward_accelerate_summarizew.py | 836 ------------- .../train_sft_accelerate_summarize copy.py | 529 --------- ...train_sft_accelerate_summarize_executor.py | 539 --------- .../train_policy_accelerate_summarize.py | 870 -------------- ...in_policy_accelerate_summarize_separate.py | 1029 ----------------- .../train_reward_accelerate_summarize.py | 827 ------------- 20 files changed, 13164 deletions(-) delete mode 100644 benchmark/summarize.slurm_template delete mode 100644 lm_human_preference_details/summarization/minimal_rm copy.py delete mode 100644 lm_human_preference_details/summarization/minimal_rm.py delete mode 100644 lm_human_preference_details/summarization/minisft.py delete mode 100644 lm_human_preference_details/summarization/train_policy_accelerate copy 2.py delete mode 100644 lm_human_preference_details/summarization/train_policy_accelerate copy.py delete mode 100644 lm_human_preference_details/summarization/train_policy_accelerate_new.py delete mode 100644 lm_human_preference_details/summarization/train_policy_accelerate_old.py delete mode 100644 lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py delete mode 100644 lm_human_preference_details/summarization/train_reward_accelerate copy.py delete mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py delete mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_debug.py delete mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py delete mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_summarized.py delete mode 100644 lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py delete mode 100644 lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py delete mode 100644 lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py delete mode 100644 lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py delete mode 100644 lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py delete mode 100644 lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py diff --git a/benchmark/summarize.slurm_template b/benchmark/summarize.slurm_template deleted file mode 100644 index 035feb7..0000000 --- a/benchmark/summarize.slurm_template +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=lm_human_preference_details -#SBATCH --partition=production-cluster -#SBATCH --gpus-per-task={{gpus_per_task}} -#SBATCH --cpus-per-gpu={{cpus_per_gpu}} -#SBATCH --ntasks={{ntasks}} -#SBATCH --output=slurm/logs/%x_%j.out -#SBATCH --array={{array}} -#SBATCH --exclude=ip-26-0-149-199 -#SBATCH --exclusive - -{{nodes}} - -seeds={{seeds}} -seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]} - -echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed" -SEED=$seed srun {{command}} diff --git a/lm_human_preference_details/summarization/minimal_rm copy.py b/lm_human_preference_details/summarization/minimal_rm copy.py deleted file mode 100644 index 0b6ae67..0000000 --- 
a/lm_human_preference_details/summarization/minimal_rm copy.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import torch.nn as nn -from transformers import AutoModelForCausalLM, AutoTokenizer - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = nn.Linear(lm_backbone.config.hidden_size, 1) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - return output, reward - - -base_model = "gpt2" -tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left") -reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(base_model)) -mb_query = torch.randint(0, len(tokenizer), (1, 512)) -mb_responses = torch.randint(0, len(tokenizer), (1, 2, 80)) -mb_query_tiled = mb_query.unsqueeze(1).repeat(1, mb_responses.shape[1], 1) -query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) -_, score = reward_model(input_ids=query_responses, return_dict=True, output_hidden_states=True) -print(score.squeeze(2)) diff --git a/lm_human_preference_details/summarization/minimal_rm.py b/lm_human_preference_details/summarization/minimal_rm.py deleted file mode 100644 index 1c993d0..0000000 --- a/lm_human_preference_details/summarization/minimal_rm.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch -import torch.nn as nn -from transformers import AutoModelForCausalLM, AutoTokenizer - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = nn.Linear(lm_backbone.config.hidden_size, 1) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - return output, reward - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -base_model = "gpt2" -tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left") -tokenizer.add_special_tokens({"pad_token": "[PAD]"}) -reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(base_model)) -reward_model.train() -mb_query = torch.randint(0, len(tokenizer), (1, 10)) -mb_query[:, 0:4] = tokenizer.pad_token_id -mb_responses = torch.randint(0, len(tokenizer), (1, 2, 10)) -mb_query_tiled = mb_query.unsqueeze(1).repeat(1, mb_responses.shape[1], 1) -query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) -_, score_all = get_reward(reward_model, query_responses, tokenizer) -print(score_all.squeeze(2)) diff --git a/lm_human_preference_details/summarization/minisft.py b/lm_human_preference_details/summarization/minisft.py deleted file mode 100644 index 85b7cbd..0000000 --- a/lm_human_preference_details/summarization/minisft.py +++ /dev/null @@ -1,289 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import 
SimpleNamespace -from typing import Optional - -import numpy as np -import torch -import torch.optim as optim -import tyro -from accelerate import Accelerator -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from torch import optim -from torch.nn import functional as F -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer - -from lm_human_preference_details.data import process_query - - -@dataclass -class SFTHParams: - gradient_accumulation_steps: int = 1 - local_micro_batch_size: int = 16 - noptepochs: int = 1 - lr: float = 6.35e-5 - eps: float = 1e-5 - lm_loss_on_response_only: bool = False - total_episodes: tyro.conf.Suppress[int] = None - local_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.01 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - upload_model: bool = False - "whether to upload the saved model to huggingface" - hf_entity: str = "" - "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 220 - """How often to print sample output""" - save_path: str = "models/sft_policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - sft: SFTHParams = field(default_factory=SFTHParams) - - -def 
right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ) - - -if __name__ == "__main__": - args = tyro.cli(Args) - accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) - args.sft.world_size = accelerator.num_processes - args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps - args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - dataset = load_dataset(args.task.query_dataset, split="train") - test_dataset = load_dataset(args.task.query_dataset, split="test") - accelerator.print("The number of samples in dataset", len(dataset)) - accelerator.print("The number of samples in test_dataset", len(test_dataset)) - args.sft.total_episodes = len(dataset) - args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the 
padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to - policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - test_dataset = test_dataset.map(process_query_data) - test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) - test_dataset = test_dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) - test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) - iter_dataloader = iter(dataloader) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens - # generation_config = GenerationConfig( - # max_new_tokens=args.task.response_length, - # min_new_tokens=args.task.response_length, - # temperature=args.task.temperature, - # top_k=0.0, - # top_p=1.0, - # do_sample=True, - # ) - - print("===training policy===") - global_step = 0 - test_data = test_dataset[0:10] - test_data = {k: v.to(device) for k, v in test_data.items()} - - # Given parameters - eta_min = 0 - eta_max = 6.35e-5 - T_max = args.sft.num_updates - - for update in range(1, args.sft.num_updates + 1): - global_step += 1 * args.sft.batch_size - accelerator.print(f"update {update}, global_step {global_step}") - # frac = 1.0 - (update - 1.0) / args.sft.num_updates - # lrnow = frac * args.sft.lr - lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - query_responses = torch.cat((queries, reference_responses), dim=1) - with accelerator.accumulate(policy): - output = forward(policy, query_responses, tokenizer) - # mask out gradient effects on response padding tokens - labels = query_responses.masked_fill(query_responses == tokenizer.pad_token_id, -1) - if args.sft.lm_loss_on_response_only: - # mask out gradient effects on query tokens - labels[:, : queries.shape[1]] = -1 - lm_logits = output.logits - # hand-rolled transformer loss: Shift so that tokens < n predict n - # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-1) - raise - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() diff --git a/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py b/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py deleted file mode 100644 index 1b5943d..0000000 --- a/lm_human_preference_details/summarization/train_policy_accelerate copy 2.py +++ /dev/null @@ -1,836 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import DATASET - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - 
label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 64 - query_dataset: str = "books" - query_prefix: str = "" - query_suffix: str = "" - start_text: Optional[str] = None - end_text: Optional[str] = None - - # Response params - response_length: int = 24 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 13 - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - 
step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, 
unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - -# a pytorch dataset -class MyDataset(IterableDataset): - def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): - self.generator = generator - self.tokenizer = tokenizer - self.query_length = query_length - self.start_text = start_text - self.end_text = end_text - self.seed = seed - token_to_index = tokenizer.get_vocab() - self.start_token = token_to_index[start_text] if self.start_text else None - self.end_token = token_to_index[end_text] if self.end_text else None - - def __iter__(self): - for text in self.generator("train", self.seed, shuffle=True): - tokens = self.tokenizer.encode(text) - if self.start_token is not None: - try: - first_index = tokens.index(self.start_token) + 1 - if first_index < len(tokens): - tokens = tokens[first_index:] - except: - continue - tokens = tokens[: self.query_length] - if self.end_token is not None: - try: - last_index = len(tokens) - tokens[::-1].index(self.end_token) - tokens = tokens[:last_index] - except: - continue - output = self.tokenizer.pad( - {"input_ids": tokens}, - padding="max_length", - max_length=self.query_length, - return_tensors="pt", - return_attention_mask=True, - ) - yield output - - -def right_padding_to_left_padding(query, pad_id): - # Convert from right padding to left padding. - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], - device=query.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - output = reward_model.lm_backbone( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - reward = reward_model.scalar_head(output.hidden_states[-1]) - reward = reward_model.reward_gain * reward + reward_model.reward_bias - # but we only care about the reward of the last token - reward = reward[:, -1] - return reward - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - 
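# --- Illustrative sketch (hedged): the exclusive-cumsum position-id trick used in the
# --- `get_reward` / `forward` helpers above. With left padding, `attention_mask.cumsum(1)
# --- - attention_mask.long()` makes position ids start at 0 at the first real token.
# --- The pad id and token values below are made up purely for illustration.
import torch

pad_token_id = 50257  # hypothetical [PAD] id
query_responses = torch.tensor([[pad_token_id, pad_token_id, 11, 22, 33]])
attention_mask = query_responses != pad_token_id
position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)  # zero out pad tokens
print(position_ids)  # tensor([[0, 0, 0, 1, 2]])
print(input_ids)     # tensor([[ 0,  0, 11, 22, 33]])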
if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - policy.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = MyDataset( - DATASET[args.task.query_dataset], - tokenizer, - args.task.query_length, - seed=local_seed, - start_text=args.task.start_text, - end_text=args.task.end_text, - ) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - # "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) - ref_policy.eval() - else: - ref_policy = ref_policy.to(device) - reward_model = reward_model.to(device) - iter_dataloader = iter(dataloader) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - print("===training policy===") - global_step = 0 - stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) - approxkls_stats = torch.zeros(stats_shape, device=device) - clipfracs_stats = torch.zeros(stats_shape, device=device) - pg_losses_stats = torch.zeros(stats_shape, device=device) - vf_losses_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropies_stats = torch.zeros(stats_shape, device=device) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(policy).lm_backbone, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - output, full_values = forward(policy, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output, _ = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - # 2. run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding( - postprocessed_query_responses, tokenizer.pad_token_id - ) - scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() - - # 3. filter response. 
Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - - if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: - try: - all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) - all_postprocessed_query_responses = tokenizer.batch_decode( - postprocessed_query_responses, skip_special_tokens=True - ) - all_postprocessed_responses = [ - x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) - ] - - kl_sum = kl.sum(axis=1) - all_df = pd.DataFrame( - { - "query": all_decode_queries, - "response": all_postprocessed_responses, - "score": scores.float().cpu().numpy(), - "kl": kl_sum.float().cpu().numpy(), - "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table("stuff", all_df[:4], console) - except Exception as e: - print(e) - del ( - all_decode_queries, - all_postprocessed_query_responses, - all_postprocessed_responses, - kl_sum, - all_df, - ) - del postprocessed_query_responses - torch.cuda.empty_cache() - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = args.task.response_length - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = whiten(advantages) - return_mean, return_var = returns.mean(), returns.var() - value_mean, value_var = values.mean(), values.var() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): - with accelerator.accumulate(policy): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] - mb_advantage = advantages[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpredclipped = torch.clamp( - vpred, - mb_values - args.ppo.cliprange_value, - mb_values + args.ppo.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - loss = pg_loss + args.ppo.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - 
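# --- Illustrative sketch (hedged): the GAE computation from step 6 above, i.e.
# --- delta_t = r_t + gamma * V_{t+1} - V_t and A_t = delta_t + gamma * lam * A_{t+1},
# --- evaluated by walking the response tokens in reverse. Rewards and values below
# --- are made-up numbers for illustration only.
import torch

gamma, lam = 1.0, 0.95
rewards = torch.tensor([[0.0, 0.0, 1.5]])  # KL-shaped rewards; the score lands on the last token
values = torch.tensor([[0.2, 0.4, 0.1]])   # value-head predictions for the response tokens
gen_length = rewards.shape[1]
lastgaelam = 0
advantages_reversed = []
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
print(advantages)
print(returns)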
gradient_accumulation_idx += 1 - minibatch_idx += 1 - if accelerator.is_main_process: - console.print( - f"ppo_epoch_idx", - ppo_epoch_idx, - "approxkl", - approxkl.item(), - "pg_loss", - pg_loss.item(), - "pg_clipfrac", - pg_clipfrac.item(), - "ratio", - ratio.mean().item(), - ) - - with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` - writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - writer.add_scalar("objective/kl_coef", kl_ctl.value, update) - writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) - writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) - writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) - writer.add_scalar( - "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update - ) - writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) - writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) - writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) - writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) - writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) - writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) - writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) - writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) - writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) - writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) - writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) - writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) - writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) - writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) - writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) - writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) - writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) - writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) - writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) - 
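# --- Illustrative sketch (hedged): the clipped PPO policy and value losses used above.
# --- The probability ratio is only trusted inside [1 - cliprange, 1 + cliprange], and the
# --- value prediction is clipped around the old value estimate. All numbers are made up.
import torch

cliprange, cliprange_value, vf_coef = 0.2, 0.2, 0.1
mb_advantage = torch.tensor([1.0, -0.5, 2.0])
logprobs_diff = torch.tensor([0.3, -0.1, 0.0])  # new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()

mb_return = torch.tensor([1.0, 0.5, 1.2])
mb_values = torch.tensor([0.6, 0.7, 0.9])  # old value estimates
vpred = torch.tensor([1.1, 0.4, 0.8])      # new value predictions
vpredclipped = torch.clamp(vpred, mb_values - cliprange_value, mb_values + cliprange_value)
vf_loss = 0.5 * torch.max(torch.square(vpred - mb_return), torch.square(vpredclipped - mb_return)).mean()
loss = pg_loss + vf_coef * vf_loss
print(pg_loss.item(), pg_clipfrac.item(), vf_loss.item(), loss.item())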
writer.add_scalar("ppo/lr", lrnow, update) - writer.add_scalar("ppo/episode", global_step, update) - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(reward_model.state_dict(), args.save_path) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate copy.py b/lm_human_preference_details/summarization/train_policy_accelerate copy.py deleted file mode 100644 index e9e6d84..0000000 --- a/lm_human_preference_details/summarization/train_policy_accelerate copy.py +++ /dev/null @@ -1,945 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from rich.console import Console -from rich.pretty import pprint -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import DATASET - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 64 - query_dataset: str = "books" - query_prefix: str = "" - query_suffix: str = "" - start_text: Optional[str] = None - end_text: Optional[str] = None - - # Response params - response_length: int = 24 - - # Truncate response after the first occurrence of this token at or after index after when sampling. 
- truncate_token: int = 13 - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 0 - """How often to print sample output""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -from typing import List, Optional - -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - 
weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - -# a pytorch dataset -class MyDataset(IterableDataset): - def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): - self.generator = generator - self.tokenizer = tokenizer - self.query_length = query_length - self.start_text = start_text - self.end_text = end_text - self.seed = seed - token_to_index = tokenizer.get_vocab() - self.start_token = token_to_index[start_text] if self.start_text else None - 
self.end_token = token_to_index[end_text] if self.end_text else None - - def __iter__(self): - for text in self.generator("train", self.seed, shuffle=True): - tokens = self.tokenizer.encode(text) - if self.start_token is not None: - try: - first_index = tokens.index(self.start_token) + 1 - if first_index < len(tokens): - tokens = tokens[first_index:] - except: - continue - tokens = tokens[: self.query_length] - if self.end_token is not None: - try: - last_index = len(tokens) - tokens[::-1].index(self.end_token) - tokens = tokens[:last_index] - except: - continue - output = self.tokenizer.pad( - {"input_ids": tokens}, - padding="max_length", - max_length=self.query_length, - return_tensors="pt", - return_attention_mask=True, - ) - yield output - - -def right_padding_to_left_padding(query, pad_id): - # Convert from right padding to left padding. - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], - device=query.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - output = reward_model.lm_backbone( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - reward = reward_model.scalar_head(output.hidden_states[-1]) - reward = reward_model.reward_gain * reward + reward_model.reward_bias - # but we only care about the reward of the last token - reward = reward[:, -1] - return reward - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, 
args.ppo.gradient_accumulation_steps) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - policy.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = MyDataset( - DATASET[args.task.query_dataset], - tokenizer, - args.task.query_length, - seed=local_seed, - start_text=args.task.start_text, - end_text=args.task.end_text, - ) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - deepspeed_states.deepspeed_config["checkpoint"] = 
{"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) - ref_policy.eval() - else: - ref_policy = ref_policy.to(device) - reward_model = reward_model.to(device) - iter_dataloader = iter(dataloader) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - print("===training policy===") - global_step = 0 - approxkls_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - clipfracs_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - pg_losses_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - vf_losses_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - vf_clipfrac_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - entropies_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(policy).lm_backbone, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - output, full_values = forward(policy, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - - output4, _ = forward(policy, query_responses, tokenizer) - logits4 = output4.logits[:, context_length - 1 : -1] - logits4 /= args.task.temperature - all_logprobs4 = F.log_softmax(logits4, dim=-1) - logprobs4 = torch.gather(all_logprobs4, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - 
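# --- Illustrative sketch (hedged): the per-token logprob extraction used in the rollout
# --- above. Logits at the response positions are temperature-scaled, log-softmaxed, and
# --- the logprob of each sampled token is gathered. Shapes and values are made up; the
# --- random logits stand in for the sliced `output.logits`.
import torch
import torch.nn.functional as F

temperature = 0.7
batch, response_length, vocab_size = 2, 3, 5
logits = torch.randn(batch, response_length, vocab_size)
responses = torch.randint(0, vocab_size, (batch, response_length))
logits = logits / temperature
all_logprobs = F.log_softmax(logits, dim=-1)
logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
print(logprobs.shape)  # torch.Size([2, 3]) -- one logprob per generated token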
torch.cuda.empty_cache() - - ref_output, _ = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - # 2. run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding( - postprocessed_query_responses, tokenizer.pad_token_id - ) - scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() - - # 3. filter response. Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - try: - sample_kl = kl[0].sum().item() - postprocessed_responses = postprocessed_query_responses[:, context_length:] - console.print( - f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]\n[yellow]{tokenizer.decode(postprocessed_responses[0], skip_special_tokens=True)}[/]\n[blue](NO POST-PROCESSING){tokenizer.decode(responses[0], skip_special_tokens=True)}[/]\n[red]score: {scores[0]}, kl: {kl[0].sum().item()}, total reward: {scores[0] - kl_ctl.value * sample_kl} [/]" - ) - except Exception as e: - print(e) - del postprocessed_query_responses - torch.cuda.empty_cache() - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = args.task.response_length - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = whiten(advantages) - return_mean, return_var = returns.mean(), returns.var() - value_mean, value_var = values.mean(), values.var() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] - mb_advantage = advantages[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - output2, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits2 = output2.logits[:, context_length - 1 : -1] - logits2 /= args.task.temperature - new_all_logprobs2 = F.log_softmax(logits2, dim=-1) - new_logprobs2 = torch.gather(new_all_logprobs2, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - - with accelerator.accumulate(policy): - - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpredclipped = torch.clamp( - vpred, - mb_values - args.ppo.cliprange_value, - mb_values + args.ppo.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() - logprobs_diff = new_logprobs - mb_logprobs - pprint( - { - "new_logprobs": new_logprobs, - "new_logprobs2": new_logprobs2, - "mb_logprobs": mb_logprobs, - "mb_logprobs2": logprobs4[micro_batch_inds], - } - ) - ratio = torch.exp(logprobs_diff) - print(ratio.mean()) - breakpoint() - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - loss = pg_loss + args.ppo.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - pd = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, 
minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - if accelerator.is_main_process: - console.print( - f"ppo_epoch_idx", - ppo_epoch_idx, - "approxkl", - approxkl.item(), - "pg_loss", - pg_loss.item(), - "pg_clipfrac", - pg_clipfrac.item(), - "ratio", - ratio.mean().item(), - ) - - with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` - writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - writer.add_scalar("objective/kl_coef", kl_ctl.value, update) - writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) - writer.add_scalar( - "objective/entropy", - accelerator.gather(mean_entropy).mean().item(), - update, - ) - writer.add_scalar( - "objective/non_score_reward", - accelerator.gather(mean_non_score_reward).mean().item(), - update, - ) - writer.add_scalar( - "objective/score_total", - accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), - update, - ) - writer.add_scalar( - "objective/scores", - accelerator.gather(scores.mean()).mean().item(), - update, - ) - writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) - writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) - writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) - writer.add_scalar( - "ppo/policy/entropy", - accelerator.gather(entropy.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/approxkl", - accelerator.gather(approxkl).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/clipfrac", - accelerator.gather(pg_clipfrac).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/approxkl_avg", - accelerator.gather(approxkls_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/clipfrac_avg", - accelerator.gather(clipfracs_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/loss/policy_avg", - accelerator.gather(pg_losses_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/loss/value_avg", - accelerator.gather(vf_losses_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/clipfrac_avg", - accelerator.gather(vf_clipfrac_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/entropy_avg", - accelerator.gather(entropies_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/returns/mean", - accelerator.gather(return_mean).mean().item(), - update, - ) - writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) - writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) - writer.add_scalar( - "ppo/val/error", - accelerator.gather(vf_losses1.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/clipfrac", - accelerator.gather(vf_clipfrac).mean().item(), - update, - ) - 
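# NOTE: a self-contained sketch (toy tensors, not the script's rollouts) of the GAE-style
# advantage/return computation performed earlier in this update step: deltas are formed
# from the per-token rewards and the value head, then accumulated backwards with gamma
# and lam; the script afterwards whitens the resulting advantages.
import torch

gamma, lam = 1.0, 0.95                          # PpoHParams defaults
rewards = torch.tensor([[0.00, -0.01, 0.72]])   # per-token KL penalties, score added to the last token
values = torch.tensor([[0.10, 0.20, 0.30]])     # value-head predictions for the response tokens

lastgaelam = 0.0
advantages_reversed = []
gen_length = rewards.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)   # (1, gen_length)
returns = advantages + values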
writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) - writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) - writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) - writer.add_scalar( - "ppo/val/ratio_var", - accelerator.gather(ratio.mean()).var().item(), - update, - ) - writer.add_scalar( - "ppo/val/advantage", - accelerator.gather(advantages.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/advantage_var", - accelerator.gather(advantages.mean()).var().item(), - update, - ) - writer.add_scalar( - "ppo/val/num_eos_tokens", - (responses == tokenizer.eos_token_id).sum().item(), - update, - ) - writer.add_scalar("ppo/lr", lrnow, update) - writer.add_scalar("ppo/episode", global_step, update) - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(reward_model.state_dict(), args.save_path) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_new.py b/lm_human_preference_details/summarization/train_policy_accelerate_new.py deleted file mode 100644 index 3187431..0000000 --- a/lm_human_preference_details/summarization/train_policy_accelerate_new.py +++ /dev/null @@ -1,952 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from rich.console import Console -from rich.pretty import pprint -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import DATASET - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 64 - query_dataset: str = "books" - query_prefix: str = "" - query_suffix: str = "" - start_text: Optional[str] = None - 
end_text: Optional[str] = None - - # Response params - response_length: int = 24 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 13 - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 0 - """How often to print sample output""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -from typing import List, Optional - -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - 
fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - -# a pytorch dataset -class MyDataset(IterableDataset): - def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): - self.generator = generator - self.tokenizer = tokenizer - self.query_length = query_length - self.start_text = start_text - 
self.end_text = end_text - self.seed = seed - token_to_index = tokenizer.get_vocab() - self.start_token = token_to_index[start_text] if self.start_text else None - self.end_token = token_to_index[end_text] if self.end_text else None - - def __iter__(self): - for text in self.generator("train", self.seed, shuffle=True): - tokens = self.tokenizer.encode(text) - if self.start_token is not None: - try: - first_index = tokens.index(self.start_token) + 1 - if first_index < len(tokens): - tokens = tokens[first_index:] - except: - continue - tokens = tokens[: self.query_length] - if self.end_token is not None: - try: - last_index = len(tokens) - tokens[::-1].index(self.end_token) - tokens = tokens[:last_index] - except: - continue - output = self.tokenizer.pad( - {"input_ids": tokens}, - padding="max_length", - max_length=self.query_length, - return_tensors="pt", - return_attention_mask=True, - ) - yield output - - -def right_padding_to_left_padding(query, pad_id): - # Convert from right padding to left padding. - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], - device=query.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - output = reward_model.lm_backbone( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - reward = reward_model.scalar_head(output.hidden_states[-1]) - reward = reward_model.reward_gain * reward + reward_model.reward_bias - # but we only care about the reward of the last token - reward = reward[:, -1] - return reward - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - 
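# NOTE: a standalone illustration (made-up pad id and tokens) of the masking convention
# used by `generate`, `get_reward`, and `forward` above: padding positions are masked out,
# position ids are an exclusive cumulative sum of the attention mask so left-padded rows
# still start counting at 0, and pad tokens are replaced by 0 before the forward pass.
import torch

pad_token_id = 50257                              # illustrative; the script adds its own "[PAD]" token
query_responses = torch.tensor([[pad_token_id, pad_token_id, 11, 12, 13]])
attention_mask = query_responses != pad_token_id                   # [[False, False, True, True, True]]
position_ids = attention_mask.cumsum(1) - attention_mask.long()    # [[0, 0, 0, 1, 2]]
input_ids = query_responses.clone()
input_ids[~attention_mask] = 0                                     # zero out pads; the mask hides them anyway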
if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - policy.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = MyDataset( - DATASET[args.task.query_dataset], - tokenizer, - args.task.query_length, - seed=local_seed, - start_text=args.task.start_text, - end_text=args.task.end_text, - ) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) - ref_policy.eval() - else: - ref_policy = ref_policy.to(device) - reward_model = reward_model.to(device) - iter_dataloader = iter(dataloader) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - print("===training policy===") - global_step = 0 - approxkls_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - clipfracs_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - pg_losses_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - vf_losses_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - vf_clipfrac_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - entropies_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - ratio_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(policy).lm_backbone, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - output, full_values = forward(policy, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output, _ = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - # 2. run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding( - postprocessed_query_responses, tokenizer.pad_token_id - ) - scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() - - # 3. filter response. Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - try: - sample_kl = kl[0].sum().item() - postprocessed_responses = postprocessed_query_responses[:, context_length:] - console.print( - f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]\n[yellow]{tokenizer.decode(postprocessed_responses[0], skip_special_tokens=True)}[/]\n[blue](NO POST-PROCESSING){tokenizer.decode(responses[0], skip_special_tokens=True)}[/]\n[red]score: {scores[0]}, kl: {kl[0].sum().item()}, total reward: {scores[0] - kl_ctl.value * sample_kl} [/]" - ) - except Exception as e: - print(e) - del postprocessed_query_responses - torch.cuda.empty_cache() - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = args.task.response_length - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = whiten(advantages) - return_mean, return_var = returns.mean(), returns.var() - value_mean, value_var = values.mean(), values.var() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - re_calculated_logprobs = torch.zeros_like(logprobs) - re_calculated_values = torch.zeros_like(values) - for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - - # re-calculate logprobs and values for the first epoch, otherwise `bf16` will cause the logprobs to - # be much different because the logprobs are with a batch size of `local_batch_size` but the - # `new_logprobs` are with a batch size of `local_micro_batch_size` - if ppo_epoch_idx == 0: - with torch.no_grad(): - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - re_calculated_logprobs[micro_batch_inds] = new_logprobs - re_calculated_values[micro_batch_inds] = vpred - del output, logits, new_all_logprobs - mb_values = re_calculated_values[micro_batch_inds] - mb_logprobs = re_calculated_logprobs[micro_batch_inds] - - with accelerator.accumulate(policy): - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpredclipped = torch.clamp( - vpred, - mb_values - args.ppo.cliprange_value, - mb_values + args.ppo.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = 
torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - loss = pg_loss + args.ppo.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - pd = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - if accelerator.is_main_process: - console.print( - f"ppo_epoch_idx", - ppo_epoch_idx, - "approxkl", - approxkl.item(), - "pg_loss", - pg_loss.item(), - "pg_clipfrac", - pg_clipfrac.item(), - "ratio", - ratio.mean().item(), - ) - breakpoint() - if accelerator.is_main_process: - console.print("ratio_stats", ratio_stats.mean()) - breakpoint() - - with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` - writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - writer.add_scalar("objective/kl_coef", kl_ctl.value, update) - writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) - writer.add_scalar( - "objective/entropy", - accelerator.gather(mean_entropy).mean().item(), - update, - ) - writer.add_scalar( - "objective/non_score_reward", - accelerator.gather(mean_non_score_reward).mean().item(), - update, - ) - writer.add_scalar( - "objective/score_total", - accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), - update, - ) - writer.add_scalar( - "objective/scores", - accelerator.gather(scores.mean()).mean().item(), - update, - ) - writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) - writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) - writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) - writer.add_scalar( - "ppo/policy/entropy", - accelerator.gather(entropy.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/approxkl", - accelerator.gather(approxkl).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/clipfrac", - accelerator.gather(pg_clipfrac).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/approxkl_avg", - accelerator.gather(approxkls_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/clipfrac_avg", - accelerator.gather(clipfracs_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/loss/policy_avg", - accelerator.gather(pg_losses_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/loss/value_avg", - accelerator.gather(vf_losses_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/clipfrac_avg", - accelerator.gather(vf_clipfrac_stats).mean().item(), - update, - ) 
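# NOTE: a minimal sketch (toy tensors, not the script's minibatches) of the clipped PPO
# policy and value losses computed in the inner loop above, using the default
# cliprange=0.2, cliprange_value=0.2, and vf_coef=0.1 from PpoHParams.
import torch

cliprange, cliprange_value, vf_coef = 0.2, 0.2, 0.1
mb_advantage = torch.tensor([1.0, -0.5])
ratio = torch.tensor([1.3, 0.9])                  # exp(new_logprobs - mb_logprobs)
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()           # elementwise max, then mean

vpred = torch.tensor([0.9, 0.1])
mb_values = torch.tensor([0.5, 0.2])
mb_return = torch.tensor([1.0, 0.0])
vpredclipped = torch.clamp(vpred, mb_values - cliprange_value, mb_values + cliprange_value)
vf_loss = 0.5 * torch.max((vpred - mb_return) ** 2, (vpredclipped - mb_return) ** 2).mean()

loss = pg_loss + vf_coef * vf_loss                # the quantity passed to accelerator.backward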
- writer.add_scalar( - "ppo/policy/entropy_avg", - accelerator.gather(entropies_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/returns/mean", - accelerator.gather(return_mean).mean().item(), - update, - ) - writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) - writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) - writer.add_scalar( - "ppo/val/error", - accelerator.gather(vf_losses1.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/clipfrac", - accelerator.gather(vf_clipfrac).mean().item(), - update, - ) - writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) - writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) - writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) - writer.add_scalar( - "ppo/val/ratio_var", - accelerator.gather(ratio.mean()).var().item(), - update, - ) - writer.add_scalar( - "ppo/val/advantage", - accelerator.gather(advantages.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/advantage_var", - accelerator.gather(advantages.mean()).var().item(), - update, - ) - writer.add_scalar( - "ppo/val/num_eos_tokens", - (responses == tokenizer.eos_token_id).sum().item(), - update, - ) - writer.add_scalar("ppo/lr", lrnow, update) - writer.add_scalar("ppo/episode", global_step, update) - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(reward_model.state_dict(), args.save_path) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_old.py b/lm_human_preference_details/summarization/train_policy_accelerate_old.py deleted file mode 100644 index 1c9fc5a..0000000 --- a/lm_human_preference_details/summarization/train_policy_accelerate_old.py +++ /dev/null @@ -1,922 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from rich.console import Console -from rich.pretty import pprint -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import DATASET - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch 
size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 64 - query_dataset: str = "books" - query_prefix: str = "" - query_suffix: str = "" - start_text: Optional[str] = None - end_text: Optional[str] = None - - # Response params - response_length: int = 24 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 13 - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 0 - """How often to print sample output""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -from typing import List, Optional - -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam 
implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - 
super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - -# a pytorch dataset -class MyDataset(IterableDataset): - def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): - self.generator = generator - self.tokenizer = tokenizer - self.query_length = query_length - self.start_text = start_text - self.end_text = end_text - self.seed = seed - token_to_index = tokenizer.get_vocab() - self.start_token = token_to_index[start_text] if self.start_text else None - self.end_token = token_to_index[end_text] if self.end_text else None - - def __iter__(self): - for text in self.generator("train", self.seed, shuffle=True): - tokens = self.tokenizer.encode(text) - if self.start_token is not None: - try: - first_index = tokens.index(self.start_token) + 1 - if first_index < len(tokens): - tokens = tokens[first_index:] - except: - continue - tokens = tokens[: self.query_length] - if self.end_token is not None: - try: - last_index = len(tokens) - tokens[::-1].index(self.end_token) - tokens = tokens[:last_index] - except: - continue - output = self.tokenizer.pad( - {"input_ids": tokens}, - padding="max_length", - max_length=self.query_length, - return_tensors="pt", - return_attention_mask=True, - ) - yield output - - -def right_padding_to_left_padding(query, pad_id): - # Convert from right padding to left padding. - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query], - device=query.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - output = reward_model.lm_backbone( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - reward = reward_model.scalar_head(output.hidden_states[-1]) - reward = reward_model.reward_gain * reward + reward_model.reward_bias - # but we only care about the reward of the last token - reward = reward[:, -1] - return reward - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - 
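# NOTE: a worked example (assumed world size and gradient accumulation, PpoHParams defaults
# otherwise) of the batch-size bookkeeping done at the top of train() above; exact_div is
# restated here only so the snippet runs on its own.
def exact_div(a, b):
    q = a // b
    if a != q * b:
        raise ValueError(f"Inexact division: {a} / {b} = {a / b}")
    return q

world_size = 8                                      # e.g. 8 processes under accelerate
local_batch_size = 64                               # PpoHParams default
nminibatches = 1                                    # PpoHParams default
gradient_accumulation_steps = 4                     # illustrative

batch_size = int(local_batch_size * world_size)                        # 512 rollouts per PPO update
local_mini_batch_size = exact_div(local_batch_size, nminibatches)      # 64
local_micro_batch_size = exact_div(local_mini_batch_size,
                                   gradient_accumulation_steps)        # 16 per backward pass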
if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - policy.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = MyDataset( - DATASET[args.task.query_dataset], - tokenizer, - args.task.query_length, - seed=local_seed, - start_text=args.task.start_text, - end_text=args.task.end_text, - ) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) - ref_policy.eval() - else: - ref_policy = ref_policy.to(device) - reward_model = reward_model.to(device) - iter_dataloader = iter(dataloader) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - print("===training policy===") - global_step = 0 - approxkls_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - clipfracs_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - pg_losses_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - vf_losses_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - vf_clipfrac_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - entropies_stats = torch.zeros( - ( - args.ppo.noptepochs, - args.ppo.nminibatches, - args.ppo.gradient_accumulation_steps, - ), - device=device, - ) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - queries = right_padding_to_left_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(policy).lm_backbone, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - output, full_values = forward(policy, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output, _ = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - # 2. run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding( - postprocessed_query_responses, tokenizer.pad_token_id - ) - scores = get_reward(reward_model, postprocessed_query_responses, tokenizer).flatten() - - # 3. filter response. Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - try: - sample_kl = kl[0].sum().item() - postprocessed_responses = postprocessed_query_responses[:, context_length:] - console.print( - f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]\n[yellow]{tokenizer.decode(postprocessed_responses[0], skip_special_tokens=True)}[/]\n[blue](NO POST-PROCESSING){tokenizer.decode(responses[0], skip_special_tokens=True)}[/]\n[red]score: {scores[0]}, kl: {kl[0].sum().item()}, total reward: {scores[0] - kl_ctl.value * sample_kl} [/]" - ) - except Exception as e: - print(e) - del postprocessed_query_responses - torch.cuda.empty_cache() - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = args.task.response_length - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = whiten(advantages) - return_mean, return_var = returns.mean(), returns.var() - value_mean, value_var = values.mean(), values.var() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): - with accelerator.accumulate(policy): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] - mb_advantage = advantages[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpredclipped = torch.clamp( - vpred, - mb_values - args.ppo.cliprange_value, - mb_values + args.ppo.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - loss = pg_loss + args.ppo.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - pd = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - 
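[Editorial aside, not part of the patch] The advantage computation in step 6 above is standard GAE: deltas are accumulated backwards with discount `gamma` and smoothing `lam`, and returns are recovered as `advantages + values`. A small self-contained sketch of that recursion, using a hypothetical `gae` helper and toy inputs rather than the script's real tensors:

import torch

def gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float = 1.0, lam: float = 0.95):
    """rewards, values: [batch, T]. Returns (advantages, returns), both [batch, T]."""
    T = rewards.size(1)
    lastgaelam = torch.zeros(rewards.size(0))
    advantages_reversed = []
    for t in reversed(range(T)):
        nextvalues = values[:, t + 1] if t < T - 1 else torch.zeros_like(lastgaelam)
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    return advantages, advantages + values

# toy check: with gamma = lam = 1 the advantage at t is the reward-to-go minus the value
adv, ret = gae(torch.ones(1, 3), torch.zeros(1, 3), gamma=1.0, lam=1.0)
# adv -> [[3., 2., 1.]]; ret equals adv here because values are zero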
gradient_accumulation_idx += 1 - minibatch_idx += 1 - if accelerator.is_main_process: - console.print( - f"ppo_epoch_idx", - ppo_epoch_idx, - "approxkl", - approxkl.item(), - "pg_loss", - pg_loss.item(), - "pg_clipfrac", - pg_clipfrac.item(), - "ratio", - ratio.mean().item(), - ) - - with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` - writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - writer.add_scalar("objective/kl_coef", kl_ctl.value, update) - writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) - writer.add_scalar( - "objective/entropy", - accelerator.gather(mean_entropy).mean().item(), - update, - ) - writer.add_scalar( - "objective/non_score_reward", - accelerator.gather(mean_non_score_reward).mean().item(), - update, - ) - writer.add_scalar( - "objective/score_total", - accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), - update, - ) - writer.add_scalar( - "objective/scores", - accelerator.gather(scores.mean()).mean().item(), - update, - ) - writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) - writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) - writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) - writer.add_scalar( - "ppo/policy/entropy", - accelerator.gather(entropy.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/approxkl", - accelerator.gather(approxkl).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/clipfrac", - accelerator.gather(pg_clipfrac).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/approxkl_avg", - accelerator.gather(approxkls_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/clipfrac_avg", - accelerator.gather(clipfracs_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/loss/policy_avg", - accelerator.gather(pg_losses_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/loss/value_avg", - accelerator.gather(vf_losses_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/clipfrac_avg", - accelerator.gather(vf_clipfrac_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/policy/entropy_avg", - accelerator.gather(entropies_stats).mean().item(), - update, - ) - writer.add_scalar( - "ppo/returns/mean", - accelerator.gather(return_mean).mean().item(), - update, - ) - writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) - writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) - writer.add_scalar( - "ppo/val/error", - accelerator.gather(vf_losses1.mean()).mean().item(), - update, - ) - writer.add_scalar( - "ppo/val/clipfrac", - accelerator.gather(vf_clipfrac).mean().item(), - update, - ) - writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) - writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) - writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) - writer.add_scalar( - "ppo/val/ratio_var", - accelerator.gather(ratio.mean()).var().item(), - update, - ) - writer.add_scalar( - "ppo/val/advantage", - accelerator.gather(advantages.mean()).mean().item(), - update, - ) - writer.add_scalar( - 
"ppo/val/advantage_var", - accelerator.gather(advantages.mean()).var().item(), - update, - ) - writer.add_scalar( - "ppo/val/num_eos_tokens", - (responses == tokenizer.eos_token_id).sum().item(), - update, - ) - writer.add_scalar("ppo/lr", lrnow, update) - writer.add_scalar("ppo/episode", global_step, update) - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(reward_model.state_dict(), args.save_path) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py b/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py deleted file mode 100644 index 50aca9e..0000000 --- a/lm_human_preference_details/summarization/train_policy_accelerate_summarize_ref_diff.py +++ /dev/null @@ -1,894 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import process_query - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - use_adaptive_kl: bool = True - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated 
spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - upload_model: bool = False - "whether to upload the saved model to huggingface" - hf_entity: str = "" - "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - sft_model_path: str = "models/sft_policy.pt" - """Where to load the SFT model""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
- """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - 
exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - last_reward_latents = reward_latents - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - # shape: [batch_size, 1] - reward = self.reward_gain * reward + self.reward_bias - return output, reward - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, args): - attention_mask = query_responses != args.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead( - 
AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - if args.sft_model_path: - policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - ref_policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded pretrained policy from {args.sft_model_path}") - policy.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = load_dataset(args.task.query_dataset, split="train") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - # "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) - ref_policy.eval() - else: - ref_policy = ref_policy.to(device) - reward_model = reward_model.to(device) - - def repeat_generator(): # TODO: ideally we shuffle the dataloader as well - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: 
even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=(args.task.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - print("===training policy===") - global_step = 0 - stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) - approxkls_stats = torch.zeros(stats_shape, device=device) - clipfracs_stats = torch.zeros(stats_shape, device=device) - pg_losses_stats = torch.zeros(stats_shape, device=device) - vf_losses_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropies_stats = torch.zeros(stats_shape, device=device) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) - query_reference_responses = torch.cat((queries, reference_responses), dim=1) - queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) - query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to( - device - ) - query_responses = generate( - accelerator.unwrap_model(policy).lm_backbone, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - output, full_values = forward(policy, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output, _ = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature + 1e-7 - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - # 2. 
run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding( - postprocessed_query_responses, tokenizer.pad_token_id - ) - scores = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - scores = scores[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) - - reference_scores = get_reward(reward_model, query_reference_responses, tokenizer)[1] - last_reference_response_indices = first_true_indices(query_reference_responses == tokenizer.pad_token_id) - 1 - last_reference_response_indices = torch.max( - last_reference_response_indices, - torch.zeros([1], dtype=last_reference_response_indices.dtype, device=query_reference_responses.device), - ) - reference_scores = reference_scores[:, :, 0].gather(1, last_reference_response_indices.unsqueeze(1)).view(-1) - - # 3. filter response. Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - - if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: - try: - all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) - all_postprocessed_query_responses = tokenizer.batch_decode( - postprocessed_query_responses, skip_special_tokens=True - ) - all_postprocessed_responses = [ - x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) - ] - all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) - - kl_sum = kl.sum(axis=1) - all_df = pd.DataFrame( - { - "query": all_decode_queries, - "response": all_postprocessed_responses, - "reference_responses": all_reference_responses, - "score": scores.float().cpu().numpy(), - "reference_scores": reference_scores.float().cpu().numpy(), - "kl": kl_sum.float().cpu().numpy(), - "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table("stuff", all_df[:4], console) - except Exception as e: - print(e) - del ( - all_decode_queries, - all_postprocessed_query_responses, - all_postprocessed_responses, - kl_sum, - all_df, - ) - del postprocessed_query_responses - torch.cuda.empty_cache() - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = args.task.response_length - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = whiten(advantages) - return_mean, return_var = returns.mean(), returns.var() - value_mean, value_var = values.mean(), values.var() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): - with accelerator.accumulate(policy): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] - mb_advantage = advantages[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpredclipped = torch.clamp( - vpred, - mb_values - args.ppo.cliprange_value, - mb_values + args.ppo.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - loss = pg_loss + args.ppo.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - 
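[Editorial aside, not part of the patch] The inner micro-batch loop above computes the clipped PPO surrogate: the pessimistic `max` of the unclipped and clipped policy objectives, plus the clip fraction and a quadratic KL approximation used purely for logging. A minimal sketch with a hypothetical helper and toy tensors, not the script's actual batches:

import torch

def clipped_pg_loss(new_logprobs, old_logprobs, advantages, cliprange=0.2):
    logprobs_diff = new_logprobs - old_logprobs
    ratio = torch.exp(logprobs_diff)
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.max(pg_losses, pg_losses2).mean()          # pessimistic surrogate
    pg_clipfrac = (pg_losses2 > pg_losses).float().mean()      # fraction of clipped samples
    approxkl = 0.5 * (logprobs_diff ** 2).mean()                # logging-only KL estimate
    return pg_loss, pg_clipfrac, approxkl

# a ratio far outside [1 - cliprange, 1 + cliprange] is clipped and stops contributing gradient:
loss, clipfrac, kl = clipped_pg_loss(
    new_logprobs=torch.tensor([0.0, 1.0]),
    old_logprobs=torch.tensor([0.0, 0.0]),
    advantages=torch.tensor([1.0, 1.0]),
)
# here the second sample (ratio = e) is clipped, so clipfrac == 0.5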
gradient_accumulation_idx += 1 - minibatch_idx += 1 - if accelerator.is_main_process: - console.print( - f"ppo_epoch_idx", - ppo_epoch_idx, - "approxkl", - approxkl.item(), - "pg_loss", - pg_loss.item(), - "pg_clipfrac", - pg_clipfrac.item(), - "ratio", - ratio.mean().item(), - ) - - with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` - writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - writer.add_scalar("objective/kl_coef", kl_ctl.value, update) - writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) - writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) - writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) - writer.add_scalar( - "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update - ) - writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) - writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) - writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) - writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) - writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) - writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) - writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) - writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) - writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) - writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) - writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) - writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) - writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) - writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) - writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) - writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) - writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) - writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) - writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), 
update) - writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) - writer.add_scalar("ppo/lr", lrnow, update) - writer.add_scalar("ppo/episode", global_step, update) - if args.rewards.use_adaptive_kl: - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - - # save model - if accelerator.is_main_process and args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(policy.state_dict(), args.save_path) - - if args.upload_model: - repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) - tokenizer.save_pretrained(repo_id, push_to_hub=True) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate copy.py b/lm_human_preference_details/summarization/train_reward_accelerate copy.py deleted file mode 100644 index aae124c..0000000 --- a/lm_human_preference_details/summarization/train_reward_accelerate copy.py +++ /dev/null @@ -1,736 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs, broadcast -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import DATASET - - -@dataclass -class LabelHParams: - type: str = None - num_train: int = 4992 - num_labels: int = 4 - source: str = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 64 - query_dataset: str = "books" - query_prefix: str = "" - query_suffix: str = "" - start_text: Optional[str] = None - end_text: Optional[str] = None - - # Response params - response_length: int = 24 - - # LM params - temperature: float = 0.7 - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - label_dataset: str = "sentiment/offline_5k.json" - """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 4 - """per rank batch size""" - gradient_accumulation_steps: int = 1 - 
"""gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - lr: float = 0.00005 - """the learning rate""" - eps: float = 1e-5 - """the epsilon for AdamW""" - rollout_batch_size: int = 512 - """rollout batch size""" - world_size: tyro.conf.Suppress[int] = None - """the number of processes to use""" - batch_size: tyro.conf.Suppress[int] = None - """the batch size across all ranks""" - local_normalize_samples: int = 256 - """Samples used to estimate reward mean and std""" - normalize_samples: tyro.conf.Suppress[int] = None - """Samples used to estimate reward mean and std across all ranks""" - debug_normalize: int = 0 - """Samples used to check that normalization worked""" - normalize_before: bool = True - """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = True - """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - save_path: str = "models/reward.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - labels: LabelHParams = field(default_factory=LabelHParams) - - -OPENAI_PAD_TOKEN_ID = 50259 - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: 
float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - last_reward_latents = reward_latents[:, -1, :] - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - # shape: [batch_size, 1] - reward = self.reward_gain * reward + self.reward_bias - return output, reward - - -# Dataset for reward-model normalization -class NormalizationDataset(IterableDataset): - """A dataset for reward model normalization.""" - - def __init__(self, generator, tokenizer, query_length, seed, start_text=None, end_text=None): - self.generator = generator - self.tokenizer = tokenizer - self.query_length = query_length - self.start_text = start_text - self.end_text = end_text - self.seed = seed - token_to_index = tokenizer.get_vocab() - self.start_token = token_to_index[start_text] if self.start_text else None - self.end_token = token_to_index[end_text] if self.end_text else None - - def __iter__(self): - for text in self.generator("train", self.seed, shuffle=True): - tokens = self.tokenizer.encode(text) - if self.start_token is not None: - try: - first_index = tokens.index(self.start_token) + 1 - if first_index < len(tokens): - tokens = tokens[first_index:] - except: - continue - tokens = tokens[: self.query_length] - if self.end_token is not None: - try: - last_index = len(tokens) - tokens[::-1].index(self.end_token) - 
tokens = tokens[:last_index] - except: - continue - output = self.tokenizer.pad( - {"input_ids": tokens}, - padding="max_length", - max_length=self.query_length, - return_tensors="pt", - return_attention_mask=True, - ) - yield output - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, args, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != args.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, args): - attention_mask = query_responses != args.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def normalize( - args, - accelerator, - device, - lm_backbone, - reward_model, - iter_dataloader, - generation_config, -): - with torch.no_grad(): - # reset reward scales - accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) - accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) - - # sample queries and responses - n_batches = ceil_div(args.local_normalize_samples, args.rollout_batch_size) - sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) - queries = data["input_ids"].to(device) - queries = right_padding_to_left_padding(data["input_ids"], args.pad_token_id).to(device) - query_responses = generate(lm_backbone, queries, args, generation_config) - sample_queries_responses.append(query_responses) - - # compute reward statistics - rewards = [] - for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, args)[1]) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - # shape: [args.local_normalize_samples, 1] - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - accelerator.unwrap_model(reward_model).reward_gain.data = gain - accelerator.unwrap_model(reward_model).reward_bias.data = bias - - # validate normalization - n_batches = ceil_div(args.local_normalize_samples, args.rollout_batch_size) - sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) - queries = 
data["input_ids"].to(device) - queries = right_padding_to_left_padding(data["input_ids"], args.pad_token_id).to(device) - query_responses = generate(lm_backbone, queries, args, generation_config) - sample_queries_responses.append(query_responses) - rewards = [] - for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, args)[1]) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def train(args: Args): - accelerator = Accelerator( - kwargs_handlers=[ - DistributedDataParallelKwargs(broadcast_buffers=False) - ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) - args.world_size = accelerator.num_processes - args.batch_size = int(args.local_batch_size * args.world_size) - args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - local_seed = args.seed + accelerator.process_index * 100003 # Prime - device = accelerator.device - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - args.pad_token_id = tokenizer.pad_token_id - untrained_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ).to(device) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ).to(device) - untrained_model.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - untrained_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - reward_model.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) - else: - optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - normalization_dataset = NormalizationDataset( - DATASET[args.task.query_dataset], - tokenizer, - args.task.query_length, - seed=local_seed, - start_text=args.task.start_text, - end_text=args.task.end_text, - ) - normalization_dataloader = DataLoader(normalization_dataset, 
batch_size=args.rollout_batch_size) - reward_model.lm_backbone._set_gradient_checkpointing(True) - reward_model, optimizer, normalization_dataloader = accelerator.prepare(reward_model, optimizer, normalization_dataloader) - iter_normalization_dataloader = iter(normalization_dataloader) - - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - if args.normalize_before: - print("===Normalize reward model *before* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - args, - accelerator, - device, - untrained_model.lm_backbone, - reward_model, - iter_normalization_dataloader, - generation_config, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` - label = load_dataset( - "vwxyzjn/lm-human-preferences", - data_files=[args.label_dataset], - )["train"] - print("Num labels found in source:", len(label)) - print("training on", args.labels.num_train, "in batches of", args.local_batch_size) - - print("===training reward model===") - all_inds = np.random.permutation(args.labels.num_train) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() - global_step = 0 - for start in range(0, args.labels.num_train, args.batch_size): - # linear rate annealing - lr = (1 - start / args.labels.num_train) * args.lr - optimizer.param_groups[0]["lr"] = lr - - global_step += 1 - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - losses = torch.zeros((args.gradient_accumulation_steps,), device=device) - accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) - gradient_accumulation_step = 0 - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - with accelerator.accumulate(reward_model): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query"])).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) - ] - # hack: deal with openai's padding token - mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = args.pad_token_id - for item in mb_responses: - item[item == OPENAI_PAD_TOKEN_ID] = args.pad_token_id - - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id) - reward = get_reward(reward_model, query_responses, args)[1] - predicted_rewards.append(reward.view(-1)) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_rewards.argmax(1) == 
mb_best).float().mean() - loss = torch.nn.functional.cross_entropy(predicted_rewards, mb_best) - accelerator.backward(loss) - optimizer.step() # accelerate handles gradient accumulation automatically - optimizer.zero_grad() - losses[gradient_accumulation_step] = loss - accuracies[gradient_accumulation_step] = accuracy - gradient_accumulation_step += 1 - - writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) - writer.add_scalar("train/accuracy", accelerator.gather(accuracies).mean().item(), global_step) - writer.add_scalar("train/lr", lr, global_step) - - if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - with torch.no_grad(): - # eval on test_label, some duplicate code (I don't want to make the training loop into a function...) - test_accuracies = [] - new_all_inds = np.arange(len(label)) - for start in range(args.labels.num_train, len(label), args.batch_size): - end = start + args.batch_size - b_inds_all = new_all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query"])) - mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) - ] - # hack: deal with openai's padding token - mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = args.pad_token_id - for item in mb_responses: - item[item == OPENAI_PAD_TOKEN_ID] = args.pad_token_id - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - query_responses = right_padding_to_left_padding(query_responses, args.pad_token_id) - reward = get_reward(reward_model, query_responses, args)[1] - predicted_rewards.append(reward.view(-1)) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - writer.add_scalar("test/accuracy", test_accuracy, global_step) - if accelerator.is_main_process: - print("test/accuracy", test_accuracy, global_step) - - # the part below is testing out some generations and KLs, not presented in the original code - data = next(iter_normalization_dataloader) - queries = data["input_ids"].to(device) - context_length = queries.shape[1] - queries = right_padding_to_left_padding(data["input_ids"], args.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(reward_model).lm_backbone, - queries, - args, - generation_config, - ) - responses = query_responses[:, context_length:] - - output, reward = get_reward(reward_model, query_responses, args) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - 
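# the same scoring pass is repeated below with the untrained reference model; `logprobs - ref_logprobs` is the per-token KL estimate that gets printed and logged as the `train/kl` scalar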
output, _ = get_reward(untrained_model, query_responses, args) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - print(f"global_step {global_step}:") - kl = logprobs - ref_logprobs - console.print( - f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]" - f"\n[blue]{tokenizer.decode(responses[0], skip_special_tokens=True)}[/]" - f"\n[red]reward: {reward[0].item()}[/]" - f"\n[red]kl: {kl[0].sum().item()}[/]" - f"\n[red]average kl: {kl.sum(1).mean().item()}[/]" - ) - writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) - - torch.cuda.empty_cache() - if args.normalize_after: - print("===Normalize reward model *after* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - args, - accelerator, - device, - untrained_model.lm_backbone, - reward_model, - iter_normalization_dataloader, - generation_config, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) - - if accelerator.is_main_process and args.track: - wandb.finish() - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py b/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py deleted file mode 100644 index 91e7a56..0000000 --- a/lm_human_preference_details/summarization/train_reward_accelerate_debug copy.py +++ /dev/null @@ -1,542 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs, broadcast -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.datamod import DATASET - - -@dataclass -class LabelHParams: - type: str = None - num_train: int = 4992 - num_labels: int = 4 - source: str = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 64 - query_dataset: str = "books" - query_prefix: str = "" - query_suffix: str = "" - start_text: Optional[str] = None - end_text: Optional[str] = None - - # Response params - response_length: int = 24 - - # LM params - temperature: float = 0.7 - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - 
wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - label_dataset: str = "sentiment/offline_5k.json" - """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 4 - """per rank batch size""" - lr: float = 0.00005 - """the learning rate""" - eps: float = 1e-5 - """the epsilon for AdamW""" - local_rollout_batch_size: int = 512 - """per rank rollot batch size""" - world_size: tyro.conf.Suppress[int] = None - """the number of processes to use""" - batch_size: tyro.conf.Suppress[int] = None - """the batch size across all ranks""" - normalize_samples: int = 256 - """Samples used to estimate reward mean and std""" - debug_normalize: int = 0 - """Samples used to check that normalization worked""" - normalize_before: bool = True - """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = True - """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - save_path: str = "models/reward.pt" - """Where to save the model""" - task: TaskHParams = field(default_factory=TaskHParams) - labels: LabelHParams = field(default_factory=LabelHParams) - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -OPENAI_PAD_TOKEN_ID = 50259 - - -class ScalarHead(nn.Module): - def __init__(self, config, scale=None, **kwargs): - super().__init__() - if not hasattr(config, "summary_dropout_prob"): - summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) - else: - summary_dropout_prob = config.summary_dropout_prob - self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() - # some models such as OPT have a projection layer before the word embeddings - e.g. 
OPT-350m - if hasattr(config, "word_embed_proj_dim"): - hidden_size = config.word_embed_proj_dim - else: - hidden_size = config.hidden_size - if scale is None: - scale = 1 / np.sqrt(hidden_size + 1) - self.summary = layer_init(nn.Linear(hidden_size, 1), std=scale) - self.flatten = nn.Flatten() - - def forward(self, hidden_states): - output = self.dropout(hidden_states) - output = self.summary(output) - return output - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, pretrained_model): - super().__init__() - self.pretrained_model = pretrained_model - self.scalar_head = ScalarHead(self.pretrained_model.config, scale=0.0) - - def forward(self, **kwargs): - output = self.pretrained_model(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, pretrained_model): - super().__init__() - self.pretrained_model = pretrained_model - self.scalar_head = ScalarHead(self.pretrained_model.config) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.pretrained_model(**kwargs) - reward = self.scalar_head(output.hidden_states[-1]) - reward = self.reward_gain * reward + self.reward_bias - # but we only care about the reward of the last token - reward = reward[:, -1] - return output, reward - - -# a pytorch dataset -class MyDataset(IterableDataset): - def __init__( - self, generator, tokenizer, query_length, start_text=None, end_text=None, query_prefix="", query_suffix="", seed=None - ): - self.generator = generator - self.tokenizer = tokenizer - self.query_length = query_length - self.start_text = start_text - self.end_text = end_text - self.seed = seed - token_to_index = tokenizer.get_vocab() - self.start_token = token_to_index[start_text] if self.start_text else None - self.end_token = token_to_index[end_text] if self.end_text else None - self.query_prefix = query_prefix - self.query_suffix = query_suffix - self.query_prefix_tokens = torch.LongTensor(tokenizer.encode(query_prefix)) - self.query_suffix_tokens = torch.LongTensor(tokenizer.encode(query_suffix)) - - def __iter__(self): - for text in self.generator("train", self.seed, shuffle=True): - tokens = self.tokenizer.encode(text) - if self.start_token is not None: - try: - first_index = tokens.index(self.start_token) + 1 - if first_index < len(tokens): - tokens = tokens[first_index:] - except: - continue - tokens = tokens[: self.query_length] - if self.end_token is not None: - try: - last_index = len(tokens) - tokens[::-1].index(self.end_token) - tokens = tokens[:last_index] - except: - continue - output = self.tokenizer.pad( - {"input_ids": tokens}, - padding="max_length", - max_length=self.query_length, - return_tensors="pt", - ) - output["input_ids"] = torch.cat((self.query_prefix_tokens, output["input_ids"], self.query_suffix_tokens)) - yield output - - -def left_padding_to_right_padding(query, pad_id): - # got to convert to right padding, otherwise `transformers` has weird issues - # even with `position_ids` - return torch.tensor([[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query]) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def generate(pretrained_model, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != 
tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = pretrained_model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def normalize(args, accelerator, device, tokenizer, pretrained_model, reward_model, iter_dataloader, generation_config): - with torch.no_grad(): - # reset reward scales - reward_model.module.reward_gain.data.fill_(1.0) - reward_model.module.reward_bias.data.fill_(0.0) - - # sample queries and responses - n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) - sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) - queries = data["input_ids"].to(device) - queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate(pretrained_model, queries, tokenizer, generation_config) - sample_queries_responses.append(query_responses) - - # compute reward statistics - rewards = [] - for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - reward_model.module.reward_gain.data = gain - reward_model.module.reward_bias.data = bias - - # after normalization statistics - n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) - sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) - queries = data["input_ids"].to(device) - queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate(pretrained_model, queries, tokenizer, generation_config) - sample_queries_responses.append(query_responses) - rewards = [] - for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def train(args: Args): - args.task.query_prefix = args.task.query_prefix.replace("\\n", "\n") - args.task.query_suffix = args.task.query_suffix.replace("\\n", "\n") - accelerator = Accelerator( - kwargs_handlers=[ - DistributedDataParallelKwargs(broadcast_buffers=False) - ] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 - ) - 
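# args.batch_size below is the effective global batch size: each of the world_size ranks contributes args.local_batch_size examples per optimizer step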
args.world_size = accelerator.num_processes - args.batch_size = int(args.local_batch_size * args.world_size) - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - args.seed += accelerator.process_index - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - untrained_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model)).to(device) - reward_model = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model)).to(device) - reward_model.pretrained_model.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - reward_model.pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding - optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - dataset = MyDataset( - DATASET[args.task.query_dataset], - tokenizer, - args.task.query_length, - start_text=args.task.start_text, - end_text=args.task.end_text, - query_prefix=args.task.query_prefix, - query_suffix=args.task.query_suffix, - ) - dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) - reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) - print(reward_model) - iter_dataloader = iter(dataloader) - - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` - label = load_dataset( - "vwxyzjn/lm-human-preferences", - data_files=[args.label_dataset], - )["train"] - print("Num labels found in source:", len(label)) - print("training on", args.labels.num_train, "in batches of", args.local_batch_size) - - print("before====", reward_model.module.reward_gain.data) - if args.normalize_before: - normalize( - args, - accelerator, - device, - tokenizer, - accelerator.unwrap_model(reward_model).pretrained_model, - reward_model, - iter_dataloader, - generation_config, - ) - print("after====", reward_model.module.reward_gain.data) - - print("===training reward model===") - all_inds = np.arange(args.labels.num_train) - np.random.shuffle(all_inds) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() - global_step = 0 - for start in range(0, args.labels.num_train, args.batch_size): - global_step += 1 - end = start + args.batch_size - 
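# the broadcast, shuffled index array is sliced below with a stride of num_processes so each rank trains on a disjoint shard of the current batch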
b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - lr = (1 - start / args.labels.num_train) * args.lr - optimizer.param_groups[0]["lr"] = lr - mb_data = label[b_inds] - # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) - mb_query = torch.from_numpy(np.stack(mb_data["query"])) - print("mb_query.shape", mb_query.shape) - mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels)] - # hack: deal with openai's padding token - # assert (mb_query == tokenizer.pad_token_id).sum() == 0 - mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - for item in mb_responses: - # assert (item == tokenizer.pad_token_id).sum() == 0 - item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append(reward.squeeze()) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - loss = torch.nn.functional.cross_entropy(predicted_rewards, mb_best) - optimizer.zero_grad() - accelerator.backward(loss) - optimizer.step() - writer.add_scalar("train/loss", accelerator.gather(loss).mean().item(), global_step) - writer.add_scalar("train/accuracy", accelerator.gather(accuracy).mean().item(), global_step) - writer.add_scalar("train/lr", lr, global_step) - - if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - with torch.no_grad(): - data = next(iter_dataloader) - queries = data["input_ids"].to(device) - context_length = queries.shape[1] - queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config - ) - responses = query_responses[:, context_length:] - - output, reward = get_reward(reward_model, query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - - output, _ = get_reward(untrained_model, query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - - print(f"global_step {global_step}:") - kl = logprobs - ref_logprobs - console.print( - f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]" - f"\n[blue]{tokenizer.decode(responses[0], skip_special_tokens=True)}[/]" - f"\n[red]reward: {reward[0].item()}[/]" - f"\n[red]kl: {kl[0].sum().item()}[/]" - f"\n[red]average kl: {kl.sum(1).mean().item()}[/]" - ) - writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) - - # eval on test_label - test_accuracies = [] - all_inds = np.arange(len(label)) - for start in range(args.labels.num_train, len(label), args.batch_size): - end = 
start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - mb_data = label[b_inds] - # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) - mb_query = torch.from_numpy(np.stack(mb_data["query"])) - mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) - ] - # hack: deal with openai's padding token - # assert (mb_query == tokenizer.pad_token_id).sum() == 0 - mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - for item in mb_responses: - # assert (item == tokenizer.pad_token_id).sum() == 0 - item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - if i == 0: - print(tokenizer.decode(query_responses[0], skip_special_tokens=True)) - print(tokenizer.decode(mb_responses[i], skip_special_tokens=True)) - breakpoint() - reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append(reward.squeeze()) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - writer.add_scalar("test/accuracy", test_accuracy, global_step) - if accelerator.is_main_process: - print("test/accuracy", test_accuracy, global_step) - - torch.cuda.empty_cache() - if args.normalize_after: - normalize( - args, - accelerator, - device, - tokenizer, - accelerator.unwrap_model(reward_model).pretrained_model, - reward_model, - iter_dataloader, - generation_config, - ) - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) - - if accelerator.is_main_process and args.track: - wandb.finish() - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_debug.py b/lm_human_preference_details/summarization/train_reward_accelerate_debug.py deleted file mode 100644 index 9a8a4ec..0000000 --- a/lm_human_preference_details/summarization/train_reward_accelerate_debug.py +++ /dev/null @@ -1,561 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs, broadcast -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.datamod import DATASET - - -@dataclass -class LabelHParams: - type: str = None - num_train: int = 4992 - num_labels: int = 4 - 
source: str = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 64 - query_dataset: str = "books" - query_prefix: str = "" - query_suffix: str = "" - start_text: Optional[str] = None - end_text: Optional[str] = None - - # Response params - response_length: int = 24 - - # LM params - temperature: float = 0.7 - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - label_dataset: str = "sentiment/offline_5k.json" - """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 4 - """per rank batch size""" - lr: float = 0.00005 - """the learning rate""" - eps: float = 1e-5 - """the epsilon for AdamW""" - local_rollout_batch_size: int = 512 - """per rank rollout batch size""" - world_size: tyro.conf.Suppress[int] = None - """the number of processes to use""" - batch_size: tyro.conf.Suppress[int] = None - """the batch size across all ranks""" - normalize_samples: int = 256 - """Samples used to estimate reward mean and std""" - debug_normalize: int = 0 - """Samples used to check that normalization worked""" - normalize_before: bool = True - """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = True - """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - save_path: str = "models/reward.pt" - """Where to save the model""" - task: TaskHParams = field(default_factory=TaskHParams) - labels: LabelHParams = field(default_factory=LabelHParams) - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -OPENAI_PAD_TOKEN_ID = 50259 - - -class ScalarHead(nn.Module): - def __init__(self, config, scale=None, **kwargs): - super().__init__() - if not hasattr(config, "summary_dropout_prob"): - summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) - else: - summary_dropout_prob = config.summary_dropout_prob - self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() - # some models such as OPT have a projection layer before the word embeddings - e.g.
OPT-350m - if hasattr(config, "word_embed_proj_dim"): - hidden_size = config.word_embed_proj_dim - else: - hidden_size = config.hidden_size - if scale is None: - scale = 1 / np.sqrt(hidden_size + 1) - self.summary = layer_init(nn.Linear(hidden_size, 1), std=scale) - self.flatten = nn.Flatten() - - def forward(self, hidden_states): - output = self.dropout(hidden_states) - output = self.summary(output) - return output - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, pretrained_model): - super().__init__() - self.pretrained_model = pretrained_model - self.scalar_head = ScalarHead(self.pretrained_model.config, scale=0.0) - - def forward(self, **kwargs): - output = self.pretrained_model(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, pretrained_model): - super().__init__() - self.pretrained_model = pretrained_model - self.scalar_head = ScalarHead(self.pretrained_model.config) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.pretrained_model(**kwargs) - reward = self.scalar_head(output.hidden_states[-1]) - reward = self.reward_gain * reward + self.reward_bias - # but we only care about the reward of the last token - reward = reward[:, -1] - return output, reward - - -# a pytorch dataset -class MyDataset(IterableDataset): - def __init__(self, generator, tokenizer, query_length, start_text=None, end_text=None, seed=None): - self.generator = generator - self.tokenizer = tokenizer - self.query_length = query_length - self.start_text = start_text - self.end_text = end_text - self.seed = seed - token_to_index = tokenizer.get_vocab() - self.start_token = token_to_index[start_text] if self.start_text else None - self.end_token = token_to_index[end_text] if self.end_text else None - - def __iter__(self): - for text in self.generator("train", self.seed, shuffle=True): - tokens = self.tokenizer.encode(text) - if self.start_token is not None: - try: - first_index = tokens.index(self.start_token) + 1 - if first_index < len(tokens): - tokens = tokens[first_index:] - except: - continue - tokens = tokens[: self.query_length] - if self.end_token is not None: - try: - last_index = len(tokens) - tokens[::-1].index(self.end_token) - tokens = tokens[:last_index] - except: - continue - output = self.tokenizer.pad( - {"input_ids": tokens}, - padding="max_length", - max_length=self.query_length, - return_tensors="pt", - return_attention_mask=True, - ) - yield output - - -def left_padding_to_right_padding(query, pad_id): - # got to convert to right padding, otherwise `transformers` has weird issues - # even with `position_ids` - return torch.tensor([[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in query]) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def generate(pretrained_model, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = pretrained_model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. 
TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def normalize( - args, - accelerator, - device, - tokenizer, - pretrained_model, - reward_model, - iter_dataloader, - generation_config, - query_prefix_tokens, - query_suffix_tokens, -): - with torch.no_grad(): - # reset reward scales - reward_model.module.reward_gain.data.fill_(1.0) - reward_model.module.reward_bias.data.fill_(0.0) - - # sample queries and responses - n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) - sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) - queries = data["input_ids"].to(device) - queries = format_query(query_prefix_tokens, queries, query_suffix_tokens) - queries = left_padding_to_right_padding(data["input_ids"], tokenizer.pad_token_id).to(device) - query_responses = generate(pretrained_model, queries, tokenizer, generation_config) - sample_queries_responses.append(query_responses) - - # compute reward statistics - rewards = [] - for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - reward_model.module.reward_gain.data = gain - reward_model.module.reward_bias.data = bias - - # after normalization statistics - n_batches = ceil_div(args.normalize_samples, args.local_rollout_batch_size) - sample_queries_responses = [] - for _ in range(n_batches): - data = next(iter_dataloader) - queries = data["input_ids"].to(device) - queries = format_query(query_prefix_tokens, queries, query_suffix_tokens) - queries = left_padding_to_right_padding(queries, tokenizer.pad_token_id).to(device) - query_responses = generate(pretrained_model, queries, tokenizer, generation_config) - sample_queries_responses.append(query_responses) - rewards = [] - for query_responses in sample_queries_responses: - rewards.append(get_reward(reward_model, query_responses, tokenizer)[1]) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def format_query(query_prefix_tokens, query, query_suffix_tokens): - query_prefix_tokens_tiled = query_prefix_tokens.unsqueeze(0).repeat(query.shape[0], 1).to(query.device) - query_suffix_tokens_tiled = query_suffix_tokens.unsqueeze(0).repeat(query.shape[0], 1).to(query.device) - return torch.cat((query_prefix_tokens_tiled, query, query_suffix_tokens_tiled), dim=1) - - -def train(args: Args): - args.task.query_prefix = args.task.query_prefix.replace("\\n", "\n") - args.task.query_suffix = 
args.task.query_suffix.replace("\\n", "\n") - accelerator = Accelerator( - kwargs_handlers=[ - DistributedDataParallelKwargs(broadcast_buffers=False) - ] # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 - ) - args.world_size = accelerator.num_processes - args.batch_size = int(args.local_batch_size * args.world_size) - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - args.seed += accelerator.process_index - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - use_auth_token=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - query_prefix_tokens = torch.LongTensor(tokenizer.encode(args.task.query_prefix)) - query_suffix_tokens = torch.LongTensor(tokenizer.encode(args.task.query_suffix)) - untrained_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True) - ).to(device) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, use_auth_token=True) - ).to(device) - reward_model.pretrained_model.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - reward_model.pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding - optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - dataset = MyDataset( - DATASET[args.task.query_dataset], - tokenizer, - args.task.query_length, - start_text=args.task.start_text, - end_text=args.task.end_text, - ) - dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) - reward_model, optimizer, dataloader = accelerator.prepare(reward_model, optimizer, dataloader) - iter_dataloader = iter(dataloader) - - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` - label = load_dataset( - "vwxyzjn/lm-human-preferences", - data_files=[args.label_dataset], - )["train"] - print("Num labels found in source:", len(label)) - print("training on", args.labels.num_train, "in batches of", args.local_batch_size) - - print("before====", reward_model.module.reward_gain.data) - if args.normalize_before: - normalize( - args, - accelerator, - device, - tokenizer, - accelerator.unwrap_model(reward_model).pretrained_model, - reward_model, - iter_dataloader, - generation_config, - query_prefix_tokens, - query_suffix_tokens, - ) - print("after====", 
reward_model.module.reward_gain.data) - - print("===training reward model===") - all_inds = np.arange(args.labels.num_train) - np.random.shuffle(all_inds) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() - global_step = 0 - for start in range(0, args.labels.num_train, args.batch_size): - global_step += 1 - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - lr = (1 - start / args.labels.num_train) * args.lr - optimizer.param_groups[0]["lr"] = lr - mb_data = label[b_inds] - # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) - mb_query = torch.from_numpy(np.stack(mb_data["query"])) - mb_query = format_query(query_prefix_tokens, mb_query, query_suffix_tokens) - mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels)] - # hack: deal with openai's padding token - # assert (mb_query == tokenizer.pad_token_id).sum() == 0 - mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - for item in mb_responses: - # assert (item == tokenizer.pad_token_id).sum() == 0 - item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append(reward.squeeze()) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - loss = torch.nn.functional.cross_entropy(predicted_rewards, mb_best) - optimizer.zero_grad() - accelerator.backward(loss) - optimizer.step() - writer.add_scalar("train/loss", accelerator.gather(loss).mean().item(), global_step) - writer.add_scalar("train/accuracy", accelerator.gather(accuracy).mean().item(), global_step) - writer.add_scalar("train/lr", lr, global_step) - - if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - with torch.no_grad(): - data = next(iter_dataloader) - queries = data["input_ids"].to(device) - queries = format_query(query_prefix_tokens, queries, query_suffix_tokens) - context_length = queries.shape[1] - queries = left_padding_to_right_padding(queries, tokenizer.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(reward_model).pretrained_model, queries, tokenizer, generation_config - ) - responses = query_responses[:, context_length:] - - output, reward = get_reward(reward_model, query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - - output, _ = get_reward(untrained_model, query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature - all_logprobs = F.log_softmax(logits, dim=-1) - ref_logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - - print(f"global_step {global_step}:") - kl = logprobs - 
ref_logprobs - console.print( - f"[green]{tokenizer.decode(queries[0], skip_special_tokens=True)}[/]" - f"\n[blue]{tokenizer.decode(responses[0], skip_special_tokens=True)}[/]" - f"\n[red]reward: {reward[0].item()}[/]" - f"\n[red]kl: {kl[0].sum().item()}[/]" - f"\n[red]average kl: {kl.sum(1).mean().item()}[/]" - ) - writer.add_scalar("train/kl", kl.sum(1).mean().item(), global_step) - - # eval on test_label - test_accuracies = [] - all_inds = np.arange(len(label)) - for start in range(args.labels.num_train, len(label), args.batch_size): - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - mb_data = label[b_inds] - # print("accelerator.process_index", accelerator.process_index, b_inds, b_inds_all) - mb_query = torch.from_numpy(np.stack(mb_data["query"])) - mb_query = format_query(query_prefix_tokens, mb_query, query_suffix_tokens) - mb_query = left_padding_to_right_padding(mb_query, tokenizer.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["best"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"sample{i}"])).to(device) for i in range(args.labels.num_labels) - ] - # hack: deal with openai's padding token - # assert (mb_query == tokenizer.pad_token_id).sum() == 0 - mb_query[mb_query == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - for item in mb_responses: - # assert (item == tokenizer.pad_token_id).sum() == 0 - item[item == OPENAI_PAD_TOKEN_ID] = tokenizer.pad_token_id - - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - reward = get_reward(reward_model, query_responses, tokenizer)[1] - predicted_rewards.append(reward.squeeze()) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - writer.add_scalar("test/accuracy", test_accuracy, global_step) - if accelerator.is_main_process: - print("test/accuracy", test_accuracy, global_step) - - torch.cuda.empty_cache() - if args.normalize_after: - normalize( - args, - accelerator, - device, - tokenizer, - accelerator.unwrap_model(reward_model).pretrained_model, - reward_model, - iter_dataloader, - generation_config, - query_prefix_tokens, - query_suffix_tokens, - ) - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) - - if accelerator.is_main_process and args.track: - wandb.finish() - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py deleted file mode 100644 index c04927b..0000000 --- a/lm_human_preference_details/summarization/train_reward_accelerate_summarize_debug.py +++ /dev/null @@ -1,981 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import 
tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import process_query - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - use_adaptive_kl: bool = True - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # Truncate response after the first occurrence of this token at or after index after when sampling. 
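# --- Editorial sketch (not part of the original patch) -------------------------
# The comment above describes truncating a sampled response at the first
# `truncate_token` that occurs at or after index `truncate_after`. A minimal
# standalone illustration of that rule, mirroring the cumsum-mask trick used in
# the response-processing step further down; the pad id of 50257 is an
# illustrative assumption, not the script's tokenizer setting.
import torch


def truncate_response(responses, truncate_token=50256, truncate_after=16, pad_id=50257):
    token_mask = responses == truncate_token
    token_mask[:, :truncate_after] = False  # ignore occurrences before `truncate_after`
    # True strictly after the first qualifying truncate token (the token itself is kept)
    after_mask = (token_mask.cumsum(1) - token_mask.long()).bool()
    return torch.where(after_mask, torch.full_like(responses, pad_id), responses)


if __name__ == "__main__":
    resp = torch.tensor([[11] * 20 + [50256, 22, 33]])
    print(truncate_response(resp))  # the EOS at index 20 is kept; the last two tokens become pad_id
# -------------------------------------------------------------------------------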
- truncate_token: int = 50256 # EOS token - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - upload_model: bool = False - "whether to upload the saved model to huggingface" - hf_entity: str = "" - "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - sft_model_path: str = "models/sft_policy.pt" - """Where to load the SFT model""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
- """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - 
exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - last_reward_latents = reward_latents - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - # shape: [batch_size, 1] - reward = self.reward_gain * reward + self.reward_bias - return output, reward - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
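        # Editor's note (not part of the original patch): `get_reward` below recomputes
        # position ids as an exclusive cumsum of the attention mask, e.g. a left-padded
        # mask [0, 0, 1, 1, 1] gives position_ids [0, 0, 0, 1, 2], so real tokens are
        # numbered from 0 regardless of padding. Why passing explicit position_ids to
        # `generate` makes sampling collapse is the open TODO above.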
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, args): - attention_mask = query_responses != args.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def get_reward_complete(reward_model, query_responses, args): - reward = get_reward(reward_model, query_responses, args)[1] - last_response_indices = first_true_indices(query_responses == args.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) - - -def normalize( - tokenizer, - accelerator, - device, - lm_backbone, - reward_model, - dataloader, - validation_dataloader, -): - idx = 0 - with torch.no_grad(): - # reset reward scales - # accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) - # accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) - # number of minibatches for computing the normalization statistics - rewards = [] - for data in dataloader: - idx += len(data["query_token"]) - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - accelerator.print(score.shape, accelerator.gather(score).mean()) - rewards.append(score) - accelerator.print(f"====number of samples per device: {idx}") - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - accelerator.unwrap_model(reward_model).reward_gain.data = gain - accelerator.unwrap_model(reward_model).reward_bias.data = bias - - # validate normalization - rewards = [] - for data in validation_dataloader: - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -if 
__name__ == "__main__": - args = tyro.cli(Args) - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # # each class should have a separate pretrained model that do not share weights - # ref_policy = AutoModelForCausalLMWithScalarHead( - # AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - # ) - # policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - # if args.sft_model_path: - # policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - # ref_policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - # print(f"loaded pretrained policy from {args.sft_model_path}") - # policy.lm_backbone.generation_config.eos_token_id = ( - # None # disable `pad_token_id` and `eos_token_id` because we 
just want to - # ) - # policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # # see https://github.com/pytorch/pytorch/issues/104857 for more details - # if args.use_tensorflow_adam: - # optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - # else: - # optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = load_dataset(args.task.query_dataset, split="train") - validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - validation_dataset = validation_dataset.map(process_query_data) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) - validation_dataset = validation_dataset.shuffle(seed=local_seed) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) - dataloader = accelerator.prepare(dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - # "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - else: - reward_model = reward_model.to(device) - - def repeat_generator(): # TODO: ideally we shuffle the dataloader as well - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=(args.task.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - print("===Normalize reward model *before* training===") - print( - "before normalization. 
" - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - # # save model - # if args.save_path: - # os.makedirs(os.path.dirname("models/correct_reward.pt"), exist_ok=True) - # torch.save(accelerator.unwrap_model(reward_model).state_dict(), "models/correct_reward.pt") - raise - - print("===training policy===") - global_step = 0 - stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) - approxkls_stats = torch.zeros(stats_shape, device=device) - clipfracs_stats = torch.zeros(stats_shape, device=device) - pg_losses_stats = torch.zeros(stats_shape, device=device) - vf_losses_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropies_stats = torch.zeros(stats_shape, device=device) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) - query_reference_responses = torch.cat((queries, reference_responses), dim=1) - queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) - query_reference_responses = right_padding_to_left_padding(query_reference_responses, tokenizer.pad_token_id).to( - device - ) - query_responses = generate( - accelerator.unwrap_model(policy).lm_backbone, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - output, full_values = forward(policy, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output, _ = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature + 1e-7 - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. 
truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - # 2. run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding( - postprocessed_query_responses, tokenizer.pad_token_id - ) - scores = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - scores = scores[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) - - reference_scores = get_reward(reward_model, query_reference_responses, tokenizer)[1] - last_reference_response_indices = first_true_indices(query_reference_responses == tokenizer.pad_token_id) - 1 - last_reference_response_indices = torch.max( - last_reference_response_indices, - torch.zeros([1], dtype=last_reference_response_indices.dtype, device=query_reference_responses.device), - ) - reference_scores = reference_scores[:, :, 0].gather(1, last_reference_response_indices.unsqueeze(1)).view(-1) - - print(reference_scores.mean()) - # normalization again - scores = scores - reference_scores - - # 3. filter response. Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. 
whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - - if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: - try: - all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) - all_postprocessed_query_responses = tokenizer.batch_decode( - postprocessed_query_responses, skip_special_tokens=True - ) - all_postprocessed_responses = [ - x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) - ] - all_reference_responses = tokenizer.batch_decode(reference_responses, skip_special_tokens=True) - - kl_sum = kl.sum(axis=1) - all_df = pd.DataFrame( - { - "query": all_decode_queries, - "response": all_postprocessed_responses, - "reference_responses": all_reference_responses, - "score": scores.float().cpu().numpy(), - "reference_scores": reference_scores.float().cpu().numpy(), - "kl": kl_sum.float().cpu().numpy(), - "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table("stuff", all_df[:4], console) - except Exception as e: - print(e) - del ( - all_decode_queries, - all_postprocessed_query_responses, - all_postprocessed_responses, - kl_sum, - all_df, - ) - del postprocessed_query_responses - torch.cuda.empty_cache() - -# # 6. compute advantages and returns -# lastgaelam = 0 -# advantages_reversed = [] -# gen_length = args.task.response_length -# for t in reversed(range(gen_length)): -# nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 -# delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] -# lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam -# advantages_reversed.append(lastgaelam) -# advantages = torch.stack(advantages_reversed[::-1], axis=1) -# returns = advantages + values -# advantages = whiten(advantages) -# return_mean, return_var = returns.mean(), returns.var() -# value_mean, value_var = values.mean(), values.var() - -# # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch -# for ppo_epoch_idx in range(args.ppo.noptepochs): -# b_inds = np.random.permutation(args.ppo.local_batch_size) -# minibatch_idx = 0 -# for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): -# mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size -# mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] -# gradient_accumulation_idx = 0 -# for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): -# with accelerator.accumulate(policy): -# micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size -# micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] -# mb_return = returns[micro_batch_inds] -# mb_advantage = advantages[micro_batch_inds] -# mb_values = values[micro_batch_inds] -# mb_responses = responses[micro_batch_inds] -# mb_query_responses = query_responses[micro_batch_inds] -# mb_logprobs = logprobs[micro_batch_inds] - -# output, vpred_temp = forward(policy, mb_query_responses, tokenizer) -# logits = output.logits[:, context_length - 1 : -1] -# logits /= (args.task.temperature + 1e-7) -# new_all_logprobs = F.log_softmax(logits, dim=-1) -# new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) -# vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) -# vpredclipped = torch.clamp( -# vpred, -# mb_values - 
args.ppo.cliprange_value, -# mb_values + args.ppo.cliprange_value, -# ) -# vf_losses1 = torch.square(vpred - mb_return) -# vf_losses2 = torch.square(vpredclipped - mb_return) -# vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() -# vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() -# logprobs_diff = new_logprobs - mb_logprobs -# ratio = torch.exp(logprobs_diff) -# pg_losses = -mb_advantage * ratio -# pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) -# pg_loss = torch.max(pg_losses, pg_losses2).mean() -# pg_clipfrac = (pg_losses2 > pg_losses).float().mean() -# loss = pg_loss + args.ppo.vf_coef * vf_loss -# accelerator.backward(loss) -# optimizer.step() -# optimizer.zero_grad() -# prob_dist = torch.nn.functional.softmax(logits, dim=-1) -# entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) -# approxkl = 0.5 * (logprobs_diff**2).mean() -# with torch.no_grad(): -# approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl -# clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac -# pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss -# vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss -# vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac -# entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() -# gradient_accumulation_idx += 1 -# minibatch_idx += 1 -# if accelerator.is_main_process: -# console.print( -# f"ppo_epoch_idx", -# ppo_epoch_idx, -# "approxkl", -# approxkl.item(), -# "pg_loss", -# pg_loss.item(), -# "pg_clipfrac", -# pg_clipfrac.item(), -# "ratio", -# ratio.mean().item(), -# ) - -# with torch.no_grad(): -# if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` -# writer.add_histogram("ppo/val/ratio_hist", ratio, update) -# kl = logprobs - ref_logprobs -# mean_kl = kl.sum(1).mean() -# mean_entropy = (-logprobs).sum(1).mean() -# mean_non_score_reward = non_score_reward.sum(1).mean() -# writer.add_scalar("objective/kl_coef", kl_ctl.value, update) -# writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) -# writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) -# writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) -# writer.add_scalar( -# "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update -# ) -# writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) -# writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) -# writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) -# writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) -# writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) -# writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) -# writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) -# writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) -# writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) -# 
writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) -# writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) -# writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) -# writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) -# writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) -# writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) -# writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) -# writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) -# writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) -# writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) -# writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) -# writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) -# writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) -# writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) -# writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) -# writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) -# writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) -# writer.add_scalar("ppo/lr", lrnow, update) -# writer.add_scalar("ppo/episode", global_step, update) -# if args.rewards.use_adaptive_kl: -# kl_ctl.update(mean_kl.item(), args.ppo.batch_size) -# del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - -# # save model -# if accelerator.is_main_process and args.save_path: -# os.makedirs(os.path.dirname(args.save_path), exist_ok=True) -# torch.save(policy.state_dict(), args.save_path) - -# if args.upload_model: -# repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" -# repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name -# policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) -# tokenizer.save_pretrained(repo_id, push_to_hub=True) - - -# if __name__ == "__main__": -# args = tyro.cli(Args) -# train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py deleted file mode 100644 index 4623a57..0000000 --- a/lm_human_preference_details/summarization/train_reward_accelerate_summarized.py +++ /dev/null @@ -1,785 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import transformers -import tyro -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs, broadcast -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - 
_dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import process_query - - -@dataclass -class LabelHParams: - type: str = None - num_train: int = 64832 - num_labels: int = 2 - source: str = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # LM params - temperature: float = 0.7 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - load_from_cache_file: bool = True - """Whether to load data from the local cache file in `dataset.map`""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - label_dataset: str = "openai/summarize_from_feedback" - """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 4 - """per rank batch size""" - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - lr: float = 0.00005 - """the learning rate""" - eps: float = 1e-5 - """the epsilon for AdamW""" - local_rollout_batch_size: int = 512 - """per rank rollout batch size""" - rollout_batch_size: tyro.conf.Suppress[int] = None - """rollout batch size""" - world_size: tyro.conf.Suppress[int] = None - """the number of processes to use""" - batch_size: tyro.conf.Suppress[int] = None - """the batch size across all ranks""" - local_normalize_samples: int = 256 - """Samples used to estimate reward mean and std""" - normalize_samples: tyro.conf.Suppress[int] = None - """Samples used to estimate reward mean and std across all ranks""" - debug_normalize: int = 0 - """Samples used to check that normalization worked""" - normalize_before: bool = True - """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. 
(For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = True - """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 506 - """How often to print sample output""" - sft_model_path: str = "models/sft_policy.pt" - """Where to load the SFT model""" - logsigmoid: bool = True - """Whether to use log-sigmoid loss instead of cross-entropy loss""" - trainable_param_percentage: float = 1.0 - """Percentage of parameters to train""" - save_path: str = "models/reward.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - labels: LabelHParams = field(default_factory=LabelHParams) - - -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. - """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: 
Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - return output, reward - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
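# --- Editorial sketch (not part of the original patch) -------------------------
# `get_reward_complete`, defined a little further below, turns the per-token
# reward head output into one scalar per sequence by indexing the last token
# before padding. A standalone illustration with made-up shapes; pad id 0 and
# the toy reward values are assumptions for the example only.
import torch


def first_true_indices(bools, dtype=torch.long):
    row_len = bools.size(-1)
    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
    return torch.min(zero_or_index, dim=-1).values


pad_id = 0
query_responses = torch.tensor([[5, 6, 7, pad_id, pad_id], [8, 9, 10, 11, 12]])
reward = torch.arange(10, dtype=torch.float).view(2, 5, 1)  # stand-in (batch, seq, 1) rewards
last_idx = first_true_indices(query_responses == pad_id) - 1  # 2 for row 0, 4 for row 1
last_idx = torch.clamp(last_idx, min=0)  # the script uses torch.max against a zeros tensor
score = reward[:, :, 0].gather(1, last_idx.unsqueeze(1)).view(-1)
print(score)  # tensor([2., 9.])
# -------------------------------------------------------------------------------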
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def get_reward_complete(reward_model, query_responses, tokenizer): - reward = get_reward(reward_model, query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward - - -def normalize( - tokenizer, - accelerator, - device, - lm_backbone, - reward_model, - dataloader, - validation_dataloader, -): - idx = 0 - with torch.no_grad(): - # reset reward scales - accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) - accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) - # number of minibatches for computing the normalization statistics - rewards = [] - for data in dataloader: - idx += len(data["query_token"]) - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - accelerator.print(f"====number of samples per device: {idx}") - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - accelerator.unwrap_model(reward_model).reward_gain.data = gain - accelerator.unwrap_model(reward_model).reward_bias.data = bias - - # validate normalization - rewards = [] - for data in validation_dataloader: - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def evaluate(args, accelerator, device, reward_model, validation_label): - reward_model.eval() - with torch.no_grad(): - # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
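# --- Editorial sketch (not part of the original patch) -------------------------
# The `normalize` helper above rescales rewards with a gain and bias so that
# scores on the sampled data hit a target mean/std (0 and 1 by default). A
# standalone illustration of just that arithmetic; the sample rewards are made
# up for the example.
import torch

rewards = torch.tensor([2.0, 3.0, 5.0, 6.0])            # pretend gathered rewards
mean, std = rewards.mean(), rewards.std()
target_mean, target_std = torch.tensor(0.0), torch.tensor(1.0)
gain = target_std / std
bias = target_mean - gain * mean
rescaled = gain * rewards + bias                        # what reward_gain / reward_bias apply
print(rescaled.mean().item(), rescaled.std().item())    # ~0.0, ~1.0
# -------------------------------------------------------------------------------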
- test_accuracies = [] - eval_len = len(validation_label) - len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full - new_all_inds = np.arange(len_labels) - for start in range(0, len_labels, args.batch_size): - end = start + args.batch_size - b_inds_all = new_all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = validation_label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) - ] - predicted_rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - score, _ = get_reward_complete(reward_model, query_responses, args) - predicted_rewards.append(score) - predicted_rewards = torch.stack( - predicted_rewards, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - reward_model.train() - return test_accuracy - - -def train(args: Args): - accelerator = Accelerator( - kwargs_handlers=[ - DistributedDataParallelKwargs( - broadcast_buffers=False, - ) - ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) - args.world_size = accelerator.num_processes - args.batch_size = int(args.local_batch_size * args.world_size) - args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) - args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) - num_updates = args.labels.num_train // args.batch_size - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - pprint(patch_h) - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - 
tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - args.pad_token_id = tokenizer.pad_token_id - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - - # freeze the first 70% of layers - if args.trainable_param_percentage < 1.0: - layers = reward_model.lm_backbone.transformer.h - num_layers = len(layers) - num_unfrozen = int(args.trainable_param_percentage * num_layers) - for layer in layers[:-num_unfrozen]: - layer.requires_grad_(False) - - if args.sft_model_path: - reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded SFT model from {args.sft_model_path}") - reward_model.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # make sure the `lm_head` or `embed_out` does not require gradients, otherwise - # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 - reward_model.lm_backbone.gradient_checkpointing_enable() - - if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): - reward_model.lm_backbone.embed_out.requires_grad_(False) - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) - else: - optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_updates) - dataset = load_dataset(args.task.query_dataset, split="train") - validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - # pprint(process_query_data(dataset[0])) - dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) - validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) - validation_dataset = validation_dataset.shuffle(seed=local_seed) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) - reward_model, optimizer, dataloader, scheduler = accelerator.prepare(reward_model, optimizer, dataloader, scheduler) - - iter(dataloader) - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - if args.normalize_before: - print("===Normalize reward model *before* training===") - print( - "before normalization. 
" - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` - label = load_dataset(args.label_dataset, "comparisons", split="train") - validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") - dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") - eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") - accelerator.print("Num labels found in source:", len(label)) - accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) - - def process_response_data(x): - return { - **process_query(x["info"], encoder=tokenizer, hparams=patch_h), - "response0_token": tokenizer.encode( - f" {x['summaries'][0]['text']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - ), - "response1_token": tokenizer.encode( - f" {x['summaries'][1]['text']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - ), - } - - label = label.map(process_response_data) - dev_validation_label = dev_validation_label.map(process_response_data) - eval_validation_label = eval_validation_label.map(process_response_data) - # tokenizer.encode(label[0]["summaries"][0]["text"]) - - accelerator.print("===training reward model===") - all_inds = np.random.permutation(args.labels.num_train) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() - - for (global_step, start) in enumerate(range(0, args.labels.num_train, args.batch_size)): - # # linear rate annealing - # lr = (1 - start / args.labels.num_train) * args.lr - # optimizer.param_groups[0]["lr"] = lr - - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") - if accelerator.is_main_process: - pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) - losses = torch.zeros((args.gradient_accumulation_steps,), device=device) - accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) - gradient_accumulation_step = 0 - # reward_model.train() - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - with accelerator.accumulate(reward_model): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] - # pprint({ - # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", - # "micro_batch_inds": micro_batch_inds, - # }) - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - 
torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) - ] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten(0, 1) - predicted_rewards, score_all = get_reward_complete(reward_model, query_responses, tokenizer) - breakpoint() - - predicted_rewards = predicted_rewards.view(len(mb_responses), -1) - reward_preferred = predicted_rewards.gather(1, mb_best.view(-1, 1)).view(-1) - reward_rejected = predicted_rewards.gather(1, (1 - mb_best).view(-1, 1)).view(-1) - accuracy = (predicted_rewards.argmax(1) == mb_best).float().mean() - if args.logsigmoid: - loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() - else: - loss = F.cross_entropy(predicted_rewards, mb_best) - accelerator.backward(loss) - optimizer.step() # accelerate handles gradient accumulation automatically - optimizer.zero_grad() - scheduler.step() - losses[gradient_accumulation_step] = loss - accuracies[gradient_accumulation_step] = accuracy - gradient_accumulation_step += 1 - - train_accuracy = accelerator.gather(accuracies).mean().item() - writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) - writer.add_scalar("train/accuracy", train_accuracy, global_step) - lr = scheduler.get_last_lr() - writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) - accelerator.print("train/accuracy", train_accuracy) - - # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == num_updates - 1: # first and last update - dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) - writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) - accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) - writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) - accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) - - torch.cuda.empty_cache() - if args.normalize_after: - print("===Normalize reward model *after* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. 
" - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) - accelerator.save_model(reward_model, args.save_path) - - if accelerator.is_main_process and args.track: - wandb.finish() - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py b/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py deleted file mode 100644 index a199c85..0000000 --- a/lm_human_preference_details/summarization/train_reward_accelerate_summarizew.py +++ /dev/null @@ -1,836 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Literal, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.optim as optim -import transformers -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from accelerate.utils import DistributedDataParallelKwargs, broadcast -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler - -from lm_human_preference_details.data import process_query - - -@dataclass -class LabelHParams: - type: str = None - num_train: int = 92832 - num_labels: int = 2 - source: str = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # LM params - temperature: float = 0.7 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - load_from_cache_file: bool = False - """Whether to load data from the local cache file 
in `dataset.map`""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - label_dataset: str = "openai/summarize_from_feedback" - """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 4 - """per rank batch size""" - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - lr: float = 0.00005 - """the learning rate""" - eps: float = 1e-5 - """the epsilon for AdamW""" - local_rollout_batch_size: int = 512 - """per rank rollout batch size""" - rollout_batch_size: tyro.conf.Suppress[int] = None - """rollout batch size""" - world_size: tyro.conf.Suppress[int] = None - """the number of processes to use""" - batch_size: tyro.conf.Suppress[int] = None - """the batch size across all ranks""" - local_normalize_samples: int = 256 - """Samples used to estimate reward mean and std""" - normalize_samples: tyro.conf.Suppress[int] = None - """Samples used to estimate reward mean and std across all ranks""" - debug_normalize: int = 0 - """Samples used to check that normalization worked""" - normalize_before: bool = True - """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = True - """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 300 - """How often to print sample output""" - sft_model_path: str = "models/sft_policy" - """Where to load the SFT model""" - logsigmoid: bool = True - """Whether to use log-sigmoid loss instead of cross-entropy loss""" - trainable_param_percentage: float = 1.0 - """Percentage of parameters to train""" - num_epochs: int = 1 - """Number of epochs to train""" - num_updates: tyro.conf.Suppress[int] = None - """Number of updates to train""" - save_path: str = "models/reward" - """Where to save the model""" - optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" - """Which optimizer to use""" - scheduler: str = "constant_with_warmup" - """Which scheduler to use""" - warm_up_steps: int = 100 - """Number of warm up steps for the scheduler""" - task: TaskHParams = field(default_factory=TaskHParams) - labels: LabelHParams = field(default_factory=LabelHParams) - - -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
- """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - 
exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - return output, reward - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
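# Illustrative sketch (not part of the patch) of `right_padding_to_left_padding`
# above: each row's pad tokens are moved from the right side to the left so
# that newly generated tokens can be appended directly after the real tokens.
# The pad id below is an assumption for the example.
import torch

pad_id = 50257
tokens = torch.tensor([[11, 12, 13, pad_id, pad_id],
                       [14, pad_id, pad_id, pad_id, pad_id]])
left_padded = torch.tensor(
    [[pad_id] * int((row == pad_id).sum()) + [int(x) for x in row if x != pad_id] for row in tokens]
)
print(left_padded)
# tensor([[50257, 50257,    11,    12,    13],
#         [50257, 50257, 50257, 50257,    14]])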
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def get_reward_complete(reward_model, query_responses, tokenizer): - reward = get_reward(reward_model, query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward - - -def normalize( - tokenizer, - accelerator, - device, - lm_backbone, - reward_model, - dataloader, - validation_dataloader, -): - idx = 0 - with torch.no_grad(): - # reset reward scales - accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) - accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) - # number of minibatches for computing the normalization statistics - rewards = [] - for data in dataloader: - idx += len(data["query_token"]) - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - accelerator.print(f"====number of samples per device: {idx}") - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - accelerator.unwrap_model(reward_model).reward_gain.data = gain - accelerator.unwrap_model(reward_model).reward_bias.data = bias - - # validate normalization - rewards = [] - for data in validation_dataloader: - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def evaluate(args, accelerator, device, reward_model, validation_label): - # reward_model.eval() - with torch.no_grad(): - # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
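# Illustrative sketch (not part of the patch) of the affine normalization in
# `normalize` above: gain and bias are chosen so that rewards with the
# empirical (mean, std) map onto the target mean 0 and std 1. The sample
# rewards are made up.
import torch

rewards = torch.tensor([2.0, 4.0, 6.0])
mean, std = rewards.mean(), rewards.std()
target_mean, target_std = torch.tensor(0.0), torch.tensor(1.0)
gain = target_std / std
bias = target_mean - gain * mean
normalized = gain * rewards + bias
print(gain.item(), bias.item())                           # 0.5, -2.0
print(normalized.mean().item(), normalized.std().item())  # 0.0, 1.0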
- test_accuracies = [] - eval_len = len(validation_label) - len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full - new_all_inds = np.arange(len_labels) - for start in range(0, len_labels, args.batch_size): - end = start + args.batch_size - b_inds_all = new_all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = validation_label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) - ] - predicted_reward = [] - rewards = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - score, reward = get_reward_complete(reward_model, query_responses, args) - rewards.append(reward) - predicted_reward.append(score) - predicted_reward = torch.stack( - predicted_reward, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - # reward_model.train() - return test_accuracy - - -def train(args: Args): - accelerator = Accelerator( - kwargs_handlers=[ - DistributedDataParallelKwargs( - broadcast_buffers=False, - # find_unused_parameters=True, - ) - ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) - args.world_size = accelerator.num_processes - args.batch_size = int(args.local_batch_size * args.world_size) - args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) - args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) - args.num_updates = args.labels.num_train // args.batch_size - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - 
torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - args.pad_token_id = tokenizer.pad_token_id - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - - # freeze the first 70% of layers - if args.trainable_param_percentage < 1.0: - layers = reward_model.lm_backbone.transformer.h - num_layers = len(layers) - num_unfrozen = int(args.trainable_param_percentage * num_layers) - for layer in layers[:-num_unfrozen]: - layer.requires_grad_(False) - - if args.sft_model_path: - reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded SFT model from {args.sft_model_path}") - reward_model.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # make sure the `lm_head` or `embed_out` does not require gradients, otherwise - # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 - reward_model.load_state_dict(torch.load("models/gpt2-medium-rm/pytorch_model.bin", map_location=device)) - print("loaded reward model") - if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): - reward_model.lm_backbone.embed_out.requires_grad_(False) - if args.optimizer == "tf_adam": - optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) - elif args.optimizer == "adam": - optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - elif args.optimizer == "adamw": - optimizer = optim.AdamW(reward_model.parameters(), lr=args.lr, eps=args.eps) - # TODO: use AdamW - scheduler = get_scheduler( - args.scheduler, - optimizer=optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_updates * args.num_epochs, - ) - - if args.deepspeed: - pass - - deepspeed_states = AcceleratorState().deepspeed_plugin - deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - - reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) - if args.normalize_before: - dataset = load_dataset(args.task.query_dataset, split="train") - validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) - validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", 
"reference_response"]) - validation_dataset = validation_dataset.shuffle(seed=local_seed) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_rollout_batch_size) - dataloader = accelerator.prepare(dataloader) - iter(dataloader) - print("===Normalize reward model *before* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` - label = load_dataset(args.label_dataset, "comparisons", split="train") - validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") - dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") - eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") - accelerator.print("Num labels found in source:", len(label)) - accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) - - def process_response_data(x): - return { - **process_query(x["info"], encoder=tokenizer, hparams=patch_h), - "response0_token": tokenizer.encode( - f" {x['summaries'][0]['text']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - ), - "response1_token": tokenizer.encode( - f" {x['summaries'][1]['text']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - ), - } - - label = label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - dev_validation_label = dev_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - eval_validation_label = eval_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - # TODO: check if all labels have eos token - accelerator.print("===training reward model===") - num_train = (args.labels.num_train // args.batch_size) * args.batch_size - for epoch in range(args.num_epochs): - all_inds = np.random.permutation(args.labels.num_train) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() - accelerator.print(f"epoch: {epoch}") - for (epoch_global_step, start) in enumerate(range(0, num_train, args.batch_size)): - global_step = epoch * args.num_updates + epoch_global_step - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") - if accelerator.is_main_process: - pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) - losses = torch.zeros((args.gradient_accumulation_steps,), device=device) - accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) - reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) - reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), 
device=device) - gradient_accumulation_step = 0 - # reward_model.train() - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - with accelerator.accumulate(reward_model): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] - # pprint({ - # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", - # "micro_batch_inds": micro_batch_inds, - # }) - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten( - 0, 1 - ) - predicted_reward, reward = get_reward_complete(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view( - -1, len(mb_responses) - ) # TODO check shape for no gradienta ccumulation steps - - # print(tokenizer.decode(mb_query[0])) - # print(tokenizer.decode(mb_responses[0][0])) - # print(tokenizer.decode(mb_responses[1][0])) - # predicted_reward = [] - # rewards = [] - # for i in range(args.labels.num_labels): - # query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - # score, reward = get_reward_complete(reward_model, query_responses, tokenizer) - # rewards.append(reward.squeeze(-1)) - # predicted_reward.append(score) - # # shape (batch_size, num_labels), basically a reward prediction for each label - # predicted_reward = torch.stack(predicted_reward, dim=1) - # breakpoint() - accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() - reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) - reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) - # if args.logsigmoid: - # reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) - # reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) - # loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() - # else: - # loss = F.cross_entropy(predicted_reward, mb_best) - # accelerator.backward(loss) - - # # for k, v in reward_model.named_parameters(): - # # if v.requires_grad: - # # if v.grad is None: - # # print(f"found unused param: {k}") - - # optimizer.step() # accelerate handles gradient accumulation automatically - # optimizer.zero_grad() - # scheduler.step() - # losses[gradient_accumulation_step] = loss - accuracies[gradient_accumulation_step] = accuracy - reward_preferreds[gradient_accumulation_step] = reward_preferred.mean() - reward_rejecteds[gradient_accumulation_step] = reward_rejected.mean() - gradient_accumulation_step += 1 - - train_accuracy = accelerator.gather(accuracies).mean().item() - print("train/accuracy", train_accuracy) - print("train/reward_preferred", accelerator.gather(reward_preferreds)) - print("train/reward_rejected", accelerator.gather(reward_rejecteds)) - breakpoint() - writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) - writer.add_scalar("train/accuracy", train_accuracy, global_step) - writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) - writer.add_scalar("train/reward_rejected", 
accelerator.gather(reward_rejecteds).mean().item(), global_step) - lr = scheduler.get_last_lr() - writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) - accelerator.print("train/accuracy", train_accuracy) - - # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == args.num_updates - 1: # first and last update - dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) - writer.add_scalar("dev_validation/accuracy", dev_validation_accuracy, global_step) - accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) - writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) - accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, label) - writer.add_scalar("train_full/accuracy", eval_validation_accuracy, global_step) - accelerator.print("train_full/accuracy", eval_validation_accuracy, global_step) - - torch.cuda.empty_cache() - if args.normalize_after: - print("===Normalize reward model *after* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) - accelerator.save_model(reward_model, args.save_path) - - if accelerator.is_main_process and args.track: - wandb.finish() - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py b/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py deleted file mode 100644 index 0ba4cb8..0000000 --- a/lm_human_preference_details/summarization/train_sft_accelerate_summarize copy.py +++ /dev/null @@ -1,529 +0,0 @@ -import collections -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import evaluate -import numpy as np -import pandas as pd -import torch -import torch.optim as optim -import tyro -from accelerate import Accelerator -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import process_query - - -@dataclass -class SFTHParams: - gradient_accumulation_steps: int = 16 - local_micro_batch_size: int = 1 - noptepochs: int = 1 - lr: float = 6.35e-5 - eps: float = 1e-5 - total_episodes: 
tyro.conf.Suppress[int] = None - local_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.01 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - upload_model: bool = False - "whether to upload the saved model to huggingface" - hf_entity: str = "" - "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 220 - """How often to print sample output""" - save_path: str = "models/sft_policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - sft: SFTHParams = field(default_factory=SFTHParams) - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in 
enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate 
in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return policy( - labels=input_ids, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) - args.sft.world_size = accelerator.num_processes - args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps - args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - dataset = load_dataset(args.task.query_dataset, split="train") - test_dataset = load_dataset(args.task.query_dataset, split="test") - accelerator.print("The number of samples in dataset", len(dataset)) - accelerator.print("The number of samples in test_dataset", len(test_dataset)) - args.sft.total_episodes = len(dataset) - args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to - 
policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - - def process_query_data(x): - pad_summary_w_leading_space = " " + x["summary"] - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - pad_summary_w_leading_space, - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - test_dataset = test_dataset.map(process_query_data) - test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) - test_dataset = test_dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) - test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) - iter_dataloader = iter(dataloader) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - rouge = evaluate.load("rouge") - - print("===training policy===") - global_step = 0 - test_data = test_dataset[0:10] - test_data = {k: v.to(device) for k, v in test_data.items()} - loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) - gradient_accumulation_idx = 0 - - # Given parameters - eta_min = 0 - eta_max = 6.35e-5 - T_max = args.sft.num_updates - - for update in range(1, args.sft.num_updates + 1): - global_step += 1 * args.sft.batch_size - accelerator.print(f"update {update}, global_step {global_step}") - # frac = 1.0 - (update - 1.0) / args.sft.num_updates - # lrnow = frac * args.sft.lr - lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_responses), dim=1) - query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) - with accelerator.accumulate(policy): - output = forward(policy, query_responses, tokenizer) - loss = output.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - loss_stats[gradient_accumulation_idx] = loss - gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps - if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: - writer.add_scalar("loss", 
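# Illustrative sketch (not part of the patch) of the hand-rolled cosine
# learning-rate schedule above: the lr decays from eta_max to eta_min over
# T_max updates. T_max below is arbitrary; in the script it equals
# args.sft.num_updates.
import numpy as np

eta_min, eta_max, T_max = 0.0, 6.35e-5, 1000
for update in (1, 501, 1000):
    lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max))
    print(update, lrnow)
# update 1    -> 6.35e-05  (starts at eta_max)
# update 501  -> 3.175e-05 (exactly halfway at the midpoint)
# update 1000 -> ~1.6e-10  (essentially eta_min by the final update)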
accelerator.gather(loss_stats).mean().item(), update) - writer.add_scalar("lr", lrnow, update) - if (update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: - rouge_scores = collections.defaultdict(list) - for test_idx, test_data in enumerate(test_dataloader): - with torch.no_grad(): - test_queries = test_data["query_token"].to(device) - test_reference_responses = test_data["reference_response"].to(device) - test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) - generated_responses = generate( - accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config - ) - accelerator.print(update, test_idx) - - all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) - all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) - all_decode_test_reference_responses = tokenizer.batch_decode( - test_reference_responses, skip_special_tokens=True - ) - all_decode_test_responses = [ - x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries) - ] - rouge_score = rouge.compute( - predictions=all_decode_test_responses, references=all_decode_test_reference_responses - ) - rouge_scores["rouge1"].append(rouge_score["rouge1"]) - rouge_scores["rouge2"].append(rouge_score["rouge2"]) - rouge_scores["rougeL"].append(rouge_score["rougeL"]) - - if test_idx == 0: - try: - all_df = pd.DataFrame( - { - "query": all_decode_test_queries, - "response": all_decode_test_responses, - "reference": all_decode_test_reference_responses, - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) - except Exception as e: - print(e) - - for k, v in rouge_scores.items(): - rouge_metric = torch.tensor(v, device=device) - rouge_metric = accelerator.gather(rouge_metric) - writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) - accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - accelerator.save_model(policy, args.save_path) - - if args.upload_model: - repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) - tokenizer.save_pretrained(repo_id, push_to_hub=True) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py b/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py deleted file mode 100644 index 5f9b4a2..0000000 --- a/lm_human_preference_details/summarization/train_sft_accelerate_summarize_executor.py +++ /dev/null @@ -1,539 +0,0 @@ -import collections -import os -import random -import time -from concurrent.futures import ProcessPoolExecutor -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import evaluate -import numpy as np -import pandas as pd -import torch -import torch.optim as optim -import tyro -from accelerate import Accelerator -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint 
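# Illustrative sketch (not part of the patch) of how generated summaries are
# separated from their prompts before ROUGE scoring above: the decoded prompt
# is stripped from the decoded prompt+response string by character length.
# The strings below are made up.
decoded_query = "SUBREDDIT: r/cats\n\nTITLE: lost cat\n\nPOST: my cat ran away\n\nTL;DR:"
decoded_query_response = decoded_query + " my cat ran away and I am sad"
decoded_response = decoded_query_response[len(decoded_query):]
print(repr(decoded_response))  # ' my cat ran away and I am sad'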
-from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import process_query - - -@dataclass -class SFTHParams: - gradient_accumulation_steps: int = 16 - local_micro_batch_size: int = 1 - noptepochs: int = 1 - lr: float = 6.35e-5 - eps: float = 1e-5 - total_episodes: tyro.conf.Suppress[int] = None - local_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.01 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - upload_model: bool = False - "whether to upload the saved model to huggingface" - hf_entity: str = "" - "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 180 - """How often to print sample output""" - save_path: str = "models/sft_policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - sft: SFTHParams = field(default_factory=SFTHParams) - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - 
table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def calculate_rouge( - base_model: str, - test_queries: List[List[str]], - generated_responses: List[List[str]], - test_reference_responses: List[List[str]], -): - tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) - all_decode_test_queries = tokenizer.batch_decode(test_queries, skip_special_tokens=True) - all_decode_test_query_responses = tokenizer.batch_decode(generated_responses, skip_special_tokens=True) - all_decode_test_reference_responses = tokenizer.batch_decode(test_reference_responses, skip_special_tokens=True) - all_decode_test_responses = [x[len(y) :] for x, y in zip(all_decode_test_query_responses, all_decode_test_queries)] - rouge = evaluate.load("rouge") - return rouge.compute(predictions=predictions, references=references) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with 
torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = torch.masked_fill(queries, ~attention_mask, 0) - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return policy( - labels=input_ids, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) - args.sft.world_size = accelerator.num_processes - args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps - args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - dataset = load_dataset(args.task.query_dataset, split="train") - test_dataset = load_dataset(args.task.query_dataset, split="test") - accelerator.print("The number of samples in dataset", len(dataset)) - accelerator.print("The number of samples in test_dataset", len(test_dataset)) - args.sft.total_episodes = len(dataset) - args.sft.num_updates = args.sft.total_episodes // args.sft.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to - policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) 
- - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - test_dataset = test_dataset.map(process_query_data) - test_dataset = test_dataset.with_format("torch", columns=["query_token", "reference_response"]) - test_dataset = test_dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) - test_dataloader = DataLoader(test_dataset, batch_size=args.sft.local_micro_batch_size) - policy, optimizer, dataloader, test_dataloader = accelerator.prepare(policy, optimizer, dataloader, test_dataloader) - iter_dataloader = iter(dataloader) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=args.task.temperature, - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - ProcessPoolExecutor() - # rouge = evaluate.load("rouge") - - print("===training policy===") - global_step = 0 - test_data = test_dataset[0:10] - test_data = {k: v.to(device) for k, v in test_data.items()} - loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) - gradient_accumulation_idx = 0 - - # Given parameters - eta_min = 0 - eta_max = 6.35e-5 - T_max = args.sft.num_updates - - for update in range(1, args.sft.num_updates + 1): - global_step += 1 * args.sft.batch_size - accelerator.print(f"update {update}, global_step {global_step}") - # frac = 1.0 - (update - 1.0) / args.sft.num_updates - # lrnow = frac * args.sft.lr - lrnow = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (update - 1) / T_max)) - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_responses), dim=1) - query_responses = right_padding_to_left_padding(query_responses, tokenizer.pad_token_id).to(device) - with accelerator.accumulate(policy): - output = forward(policy, query_responses, tokenizer) - loss = output.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - loss_stats[gradient_accumulation_idx] = loss - gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps - if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: - writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) - writer.add_scalar("lr", lrnow, update) - if (update - 1) % args.print_sample_output_freq * args.sft.gradient_accumulation_steps == 0: - rouge_scores = collections.defaultdict(list) - futures = [] - for test_idx, test_data in enumerate(test_dataloader): - with torch.no_grad(): - test_queries = test_data["query_token"].to(device) - test_data["reference_response"] - # test_queries = right_padding_to_left_padding(test_queries, tokenizer.pad_token_id) - 
generate(accelerator.unwrap_model(policy), test_queries, tokenizer, generation_config) - accelerator.print(update, test_idx) - - # futures.append( - # executor.submit( - # calculate_rouge, - # args.base_model, - # test_queries.cpu(), - # generated_responses.cpu(), - # test_reference_responses.cpu(), - # ) - # ) - # if test_idx == 0: - # try: - # all_df = pd.DataFrame( - # { - # "query": all_decode_test_queries, - # "response": all_decode_test_responses, - # "reference": all_decode_test_reference_responses, - # } - # ) - # if accelerator.is_main_process and args.track: - # wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) - # print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) - # except Exception as e: - # print(e) - - rouge_scores = [f.result() for f in futures] # list of dicts - rouge_scores = {k: np.mean([x[k] for x in rouge_scores]) for k in rouge_scores[0].keys()} - for k, v in rouge_scores.items(): - rouge_metric = torch.tensor(v, device=device) - rouge_metric = accelerator.gather(rouge_metric) - writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) - accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") - - # save model - if accelerator.is_main_process and args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(accelerator.unwrap_model(policy).state_dict(), args.save_path) - - if args.upload_model: - repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) - tokenizer.save_pretrained(repo_id, push_to_hub=True) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py deleted file mode 100644 index ce44ca9..0000000 --- a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize.py +++ /dev/null @@ -1,870 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import process_query - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - use_adaptive_kl: bool = True - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward.pt" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - 
local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - upload_model: bool = False - "whether to upload the saved model to huggingface" - hf_entity: str = "" - "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - sft_model_path: str = "models/sft_policy.pt" - """Where to load the SFT model""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". 
- - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. - """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, 
- max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithScalarHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - return output, self.scalar_head(output.hidden_states[-1]) - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - last_reward_latents = reward_latents - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - # shape: [batch_size, 1] - reward = self.reward_gain * reward + self.reward_bias - return output, reward - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. 
TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, args): - attention_mask = query_responses != args.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def train(args: Args): - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = 
AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLMWithScalarHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - policy = AutoModelForCausalLMWithScalarHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - if args.sft_model_path: - policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - ref_policy.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded pretrained policy from {args.sft_model_path}") - policy.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - policy.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # IMPORTANT: Layer norm produces weird gradients, which affects Adam optimizer to impact all the parameters systematically - # see https://github.com/pytorch/pytorch/issues/104857 for more details - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: - optimizer = optim.Adam(policy.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = load_dataset(args.task.query_dataset, split="train") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - } - - dataset = dataset.map(process_query_data) - dataset = dataset.with_format("torch", columns=["query_token"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - # "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) - ref_policy.eval() - else: - ref_policy = ref_policy.to(device) - reward_model = reward_model.to(device) - - def repeat_generator(): # TODO: ideally we shuffle the dataloader as well - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=(args.task.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - print("===training policy===") - global_step = 0 - stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) - approxkls_stats = torch.zeros(stats_shape, device=device) - clipfracs_stats = torch.zeros(stats_shape, device=device) - pg_losses_stats = torch.zeros(stats_shape, device=device) - vf_losses_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropies_stats = torch.zeros(stats_shape, device=device) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) - query_responses = generate( - accelerator.unwrap_model(policy).lm_backbone, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - output, full_values = forward(policy, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output, _ = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature + 1e-7 - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - truncate_token_mask = responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_responses = torch.where( - truncate_mask, - torch.full_like(responses, tokenizer.pad_token_id), - responses, - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - # 2. 
run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - postprocessed_query_responses = right_padding_to_left_padding( - postprocessed_query_responses, tokenizer.pad_token_id - ) - scores = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - scores = scores[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) - - # 3. filter response. Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - matches_token = postprocessed_responses[:, args.task.truncate_after :] == args.task.truncate_token - filter_mask = torch.any(matches_token, dim=-1) - scores = torch.where( - filter_mask, - scores, - torch.full_like(scores, args.task.penalty_reward_value), - ) - del matches_token, filter_mask - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - - if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: - try: - all_decode_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) - all_postprocessed_query_responses = tokenizer.batch_decode( - postprocessed_query_responses, skip_special_tokens=True - ) - all_postprocessed_responses = [ - x[len(y) :] for x, y in zip(all_postprocessed_query_responses, all_decode_queries) - ] - - kl_sum = kl.sum(axis=1) - all_df = pd.DataFrame( - { - "query": all_decode_queries, - "response": all_postprocessed_responses, - "score": scores.float().cpu().numpy(), - "kl": kl_sum.float().cpu().numpy(), - "reward": (scores - kl_ctl.value * kl_sum).float().cpu().numpy(), - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table("stuff", all_df[:4], console) - except Exception as e: - print(e) - del ( - all_decode_queries, - all_postprocessed_query_responses, - all_postprocessed_responses, - kl_sum, - all_df, - ) - del postprocessed_query_responses - torch.cuda.empty_cache() - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = args.task.response_length - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = whiten(advantages) - return_mean, return_var = returns.mean(), returns.var() - value_mean, value_var = values.mean(), values.var() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): - with accelerator.accumulate(policy): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] - mb_advantage = advantages[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - - output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpredclipped = torch.clamp( - vpred, - mb_values - args.ppo.cliprange_value, - mb_values + args.ppo.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - loss = pg_loss + args.ppo.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - 
gradient_accumulation_idx += 1 - minibatch_idx += 1 - if accelerator.is_main_process: - console.print( - f"ppo_epoch_idx", - ppo_epoch_idx, - "approxkl", - approxkl.item(), - "pg_loss", - pg_loss.item(), - "pg_clipfrac", - pg_clipfrac.item(), - "ratio", - ratio.mean().item(), - ) - - with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` - writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - writer.add_scalar("objective/kl_coef", kl_ctl.value, update) - writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) - writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) - writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) - writer.add_scalar( - "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update - ) - writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) - writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) - writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) - writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) - writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) - writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) - writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) - writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) - writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) - writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) - writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) - writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) - writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) - writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) - writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) - writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) - writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) - writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) - writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) - 
writer.add_scalar("ppo/lr", lrnow, update) - writer.add_scalar("ppo/episode", global_step, update) - if args.rewards.use_adaptive_kl: - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - - # save model - if accelerator.is_main_process and args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(policy.state_dict(), args.save_path) - - if args.upload_model: - repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.lm_backbone.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) - tokenizer.save_pretrained(repo_id, push_to_hub=True) - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) diff --git a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py deleted file mode 100644 index e2293d3..0000000 --- a/lm_human_preference_details/summarize_old/train_policy_accelerate_summarize_separate.py +++ /dev/null @@ -1,1029 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from typing import List, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from lm_human_preference_details.data import process_query - -INVALID_LOGPROB = 1.0 - - -@dataclass -class AdaptiveKLParams: - target: float = 6.0 - horizon: int = 10000 # in episodes - - -@dataclass -class RewardHParams: - kl_coef: float = 0.15 - use_adaptive_kl: bool = True - adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "models/reward" - label_dataset: tyro.conf.Suppress[Optional[str]] = None - - -@dataclass -class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - nminibatches: int = 1 - noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 - vf_coef: float = 0.1 - cliprange: float = 0.2 - cliprange_value: float = 0.2 - gamma: float = 1 - lam: float = 0.95 - whiten_rewards: bool = True - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = 
"SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token - truncate_after: int = 16 - penalty_reward_value: int = -1 - - # LM params - temperature: float = 0.7 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - load_from_cache_file: bool = False - """Whether to load data from the local cache file in `dataset.map`""" - upload_model: bool = False - "whether to upload the saved model to huggingface" - hf_entity: str = "" - "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 1 - """How often to print sample output""" - sft_model_path: str = "models/sft_policy" - """Where to load the SFT model""" - save_path: str = "models/policy.pt" - """Where to save the model""" - use_tensorflow_adam: bool = True - """Whether to use tensorflow-style Adam optimizer instead of PyTorch's""" - task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) - ppo: PpoHParams = field(default_factory=PpoHParams) - - -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
- """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - 
exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -class AdaptiveKLController: - def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): - self.value = init_kl_coef - self.hparams = hparams - - def update(self, current, n_steps): - target = self.hparams.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.hparams.horizon - self.value *= mult - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -def whiten(values, shift_mean=True): - # `unbiased=False` matches TF `tf.nn.moments`'s setting - mean, var = torch.mean(values), torch.var(values, unbiased=False) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] - # shape: [batch_size, length, hidden_size] - # last_reward_latents = reward_latents - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - # # shape: [batch_size, 1] - # reward = self.reward_gain * reward + self.reward_bias - return output, reward - - -# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 -# we did this we can do a single `model = accelerator.prepare(model)` -class PolicyAndValueWrapper(nn.Module): - def __init__(self, policy, critic) -> None: - super().__init__() - self.policy = policy - self.critic = critic - - def forward(self, **kwargs): - return self.policy(**kwargs), self.critic(**kwargs) - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. 
TODO: why does generation collapse with this? - generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def get_reward_complete(reward_model, query_responses, tokenizer): - reward = get_reward(reward_model, query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1) - - -def forward(policy, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = query_responses.clone() - input_ids[~attention_mask] = 0 - return policy( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -# def train(args: Args): -if __name__ == "__main__": - args = tyro.cli(Args) - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - if args.ppo.whiten_rewards: - assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - writer.add_histogram = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - 
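# A minimal sketch of how the batch-size bookkeeping above decomposes, assuming
# example values (world_size=8, local_batch_size=64, nminibatches=4,
# gradient_accumulation_steps=2); `exact_div_example` mirrors the `exact_div`
# helper so that any size that does not divide evenly fails loudly.
def exact_div_example(a, b):
    q = a // b
    if a != q * b:
        raise ValueError(f"Inexact division: {a} / {b}")
    return q

world_size, local_batch_size = 8, 64
nminibatches, gradient_accumulation_steps = 4, 2
batch_size = local_batch_size * world_size                                  # 512 rollouts per PPO update
local_mini_batch_size = exact_div_example(local_batch_size, nminibatches)   # 16 sequences per rank per minibatch
local_micro_batch_size = exact_div_example(local_mini_batch_size,
                                           gradient_accumulation_steps)     # 8 sequences per forward/backward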
device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - critic = AutoModelForCausalLMWithRewardHead(AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)) - if args.rewards.trained_model: - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device)) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") - # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - if args.sft_model_path: - policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded pretrained policy from {args.sft_model_path}") - policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to - policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - model = PolicyAndValueWrapper(policy, critic) - if args.use_tensorflow_adam: - optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - else: - optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - dataset = load_dataset(args.task.query_dataset, split="train") - validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) - validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) - model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - validation_dataloader = accelerator.prepare(validation_dataloader) - if args.deepspeed: - import deepspeed - - deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size - # deepspeed_states.deepspeed_config["checkpoint"] = 
{"use_node_local_storage": True} - eval_ds_config = { - "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], - # "steps_per_print": 10, - # "zero_optimization": { - # "stage": stage, - # "stage3_param_persistence_threshold": 1e4, - # "offload_param": { - # "device": off_load_device - # } - # }, - "bf16": {"enabled": True}, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) - reward_model.eval() - ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) - ref_policy.eval() - else: - ref_policy = ref_policy.to(device) - reward_model = reward_model.to(device) - - def repeat_generator(): # TODO: ideally we shuffle the dataloader as well - while True: - yield from dataloader - - sample_validation_inds = np.arange(args.ppo.batch_size) - local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] - sample_validation = validation_dataset[local_sample_validation_inds] - sample_validation = {k: v.to(device) for k, v in sample_validation.items()} - sample_validation_queries = sample_validation["query_token"] - with torch.no_grad(): - print(sample_validation_queries.shape) - sample_validation_queries = right_padding_to_left_padding(sample_validation_queries, tokenizer.pad_token_id) - sample_validation_reference_response = sample_validation["reference_response"] - sample_validation_query_reference_responses = torch.cat( - (sample_validation_queries, sample_validation_reference_response), dim=1 - ) - sample_validation_reference_scores = get_reward_complete( - reward_model, sample_validation_query_reference_responses, tokenizer - ) - # breakpoint() - - iter_dataloader = iter(repeat_generator()) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) - # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated - # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens - generation_config = GenerationConfig( - max_new_tokens=args.task.response_length, - min_new_tokens=args.task.response_length, - temperature=(args.task.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - # print("===Normalize reward model *before* training===") - # print( - # "before normalization. " - # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - # ) - - # normalize( - # tokenizer, - # accelerator, - # device, - # reward_model, - # reward_model, - # dataloader, - # validation_dataloader, - # ) - # print( - # "after normalization. 
" - # + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - # + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - # ) - # # # save model - # # if args.save_path: - # # os.makedirs(os.path.dirname("models/correct_reward.pt"), exist_ok=True) - # # torch.save(accelerator.unwrap_model(reward_model).state_dict(), "models/correct_reward.pt") - # raise - - print("===training policy===") - global_step = 0 - stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) - approxkls_stats = torch.zeros(stats_shape, device=device) - clipfracs_stats = torch.zeros(stats_shape, device=device) - pg_losses_stats = torch.zeros(stats_shape, device=device) - vf_losses_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropies_stats = torch.zeros(stats_shape, device=device) - for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size - frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr - optimizer.param_groups[0]["lr"] = lrnow - data = next(iter_dataloader) - with torch.no_grad(): - """ - let's use `P` to denote the padding token, `T` to denote the truncate token, and `X` to denote the - actual tokens. - queries: `PPXXX` - query_responses: `PPXXX,XXXXTXX` # the space separates the query and response - response: `XXXXTXX` - postprocessed_responses: `XXXXTXX` -> `XXXXTPP` - postprocessed_query_responses: `PPXXX,XXXXTPP` - scores: ↑ # corresponding to this `X` token - - """ - queries = data["query_token"].to(device) - reference_responses = data["reference_response"].to(device) - queries = right_padding_to_left_padding(data["query_token"], tokenizer.pad_token_id).to(device) - query_reference_responses = torch.cat((queries, reference_responses), dim=1) - query_responses = generate( - accelerator.unwrap_model(model).policy, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - - # validation - sample_validation_query_responses = generate( - accelerator.unwrap_model(model).policy, - sample_validation_queries, - tokenizer, - generation_config, - ) - sample_validation_responses = sample_validation_query_responses[:, context_length:] - truncate_token_mask = sample_validation_responses == args.task.truncate_token - truncate_after_or_token_mask = torch.cat( - [ - torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - truncate_token_mask[:, args.task.truncate_after :], - ], - dim=1, - ) - truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - postprocessed_sample_validation_responses = torch.where( - truncate_mask, - torch.full_like(sample_validation_responses, tokenizer.pad_token_id), - sample_validation_responses, - ) - postprocessed_sample_validation_query_responses = torch.cat( - (sample_validation_queries, postprocessed_sample_validation_responses), 1 - ) - del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - torch.cuda.empty_cache() - - output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) - full_values = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer)[1] - values = full_values[:, context_length - 1 : -1].squeeze(-1) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = 
torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature + 1e-7 - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - # 1. truncate at the first occurrence of `truncate_token` that appears at or after - # position truncate_after in the responses - # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378 - # truncate_token_mask = responses == args.task.truncate_token - # truncate_after_or_token_mask = torch.cat( - # [ - # torch.zeros_like(truncate_token_mask)[:, : args.task.truncate_after], - # truncate_token_mask[:, args.task.truncate_after :], - # ], - # dim=1, - # ) - # truncate_mask = (torch.cumsum(truncate_after_or_token_mask, dim=1) - truncate_after_or_token_mask.long()).bool() - # postprocessed_responses = torch.where( - # truncate_mask, - # torch.full_like(responses, tokenizer.pad_token_id), - # responses, - # ) - # del truncate_token_mask, truncate_after_or_token_mask, truncate_mask - - trunc_idxs = first_true_indices(responses == args.task.truncate_token).unsqueeze(-1) - new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] - idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) - postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) - torch.cuda.empty_cache() - - # 2. 
run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - padding_mask = postprocessed_responses == tokenizer.pad_token_id - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - values = torch.masked_fill(values, padding_mask, 0) - - scores = get_reward_complete(reward_model, postprocessed_query_responses, tokenizer) - rew = get_reward(reward_model, postprocessed_query_responses, tokenizer)[1] - - qr = postprocessed_query_responses - attention_mask = qr != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(qr, ~attention_mask, 0) - output = reward_model.lm_backbone( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - last_reward_latents = output.hidden_states[ - -1 - ] # TODO: investigate whether it should be output.hidden_states[0] or output.hidden_states[-1] - reward = reward_model.scalar_head(last_reward_latents) - - print(postprocessed_query_responses[0:5, 537:]) - print(rew.squeeze(-1)[0:5, 537:]) - print(scores) - breakpoint() - - reference_scores = get_reward_complete(reward_model, query_reference_responses, tokenizer) - # note that we do not truncate the validation responses - validation_score = get_reward_complete(reward_model, postprocessed_sample_validation_query_responses, tokenizer) - - # carperAI-style score normaliation - accelerator.print("before score", scores, scores.mean()) - accelerator.print("reference_scores", reference_scores, reference_scores.mean()) - scores = scores - reference_scores - accelerator.print("after score", scores, scores.mean()) - - # 3. filter response. Ensure that the sample contains truncate_token - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) - scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - torch.cuda.empty_cache() - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -kl_ctl.value * kl - rewards = non_score_reward.clone() - rewards[:, -1] += scores - - # 5. 
whiten rewards - if args.ppo.whiten_rewards: - rewards = whiten(rewards, shift_mean=False) - - if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: - try: - all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries) - all_sample_validation_query_responses = tokenizer.batch_decode(sample_validation_query_responses) - all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( - postprocessed_sample_validation_query_responses - ) - all_sample_validation_responses = [ - x[len(y) :] for x, y in zip(all_sample_validation_query_responses, all_decode_validation_queries) - ] - all_sample_validation_postprocessed_responses = [ - x[len(y) :] - for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) - ] - all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) - all_sample_validation_df = pd.DataFrame( - { - "query": all_decode_validation_queries, - "response": all_sample_validation_responses, - "postprocessed_response": all_sample_validation_postprocessed_responses, - "reference_responses": all_sample_validation_reference_responses, - "scores": validation_score.float().cpu().numpy(), - "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), - } - ) - if accelerator.is_main_process and args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) - print_rich_table("stuff", all_sample_validation_df[:4], console) - - except Exception as e: - print(e) - del ( - all_decode_validation_queries, - all_sample_validation_query_responses, - all_sample_validation_responses, - all_sample_validation_reference_responses, - all_sample_validation_df, - ) - del postprocessed_query_responses - torch.cuda.empty_cache() - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = args.task.response_length - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = whiten(advantages) - return_mean, return_var = returns.mean(), returns.var() - value_mean, value_var = values.mean(), values.var() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): - with accelerator.accumulate(policy): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] - mb_advantage = advantages[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - - # output, vpred_temp = forward(policy, mb_query_responses, tokenizer) - output, (_, vpred_temp) = forward(model, mb_query_responses, tokenizer) - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! - # TODO: value also use the EOS token index!!!!!!!!!!!!!!!!!!! 
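# A self-contained sketch of the GAE recursion from step 6 above, written for 1-D
# reward/value tensors so it can be checked in isolation; `gamma` and `lam` stand in
# for args.ppo.gamma and args.ppo.lam, and the value after the final generated token
# is taken to be 0, as in the reversed loop above.
import torch

def gae_1d(rewards: torch.Tensor, values: torch.Tensor, gamma: float, lam: float) -> torch.Tensor:
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nextvalue = values[t + 1] if t < T - 1 else 0.0
        delta = rewards[t] + gamma * nextvalue - values[t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages[t] = lastgaelam
    return advantages  # the critic's regression targets are advantages + values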
- - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) - vpredclipped = torch.clamp( - vpred, - mb_values - args.ppo.cliprange_value, - mb_values + args.ppo.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() - vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) - pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - loss = pg_loss + args.ppo.vf_coef * vf_loss - accelerator.backward(loss) - breakpoint() - optimizer.step() - optimizer.zero_grad() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - with torch.no_grad(): - approxkls_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - clipfracs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_losses_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropies_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - if accelerator.is_main_process: - console.print( - f"ppo_epoch_idx", - ppo_epoch_idx, - "approxkl", - approxkl.item(), - "pg_loss", - pg_loss.item(), - "pg_clipfrac", - pg_clipfrac.item(), - "ratio", - ratio.mean().item(), - ) - - with torch.no_grad(): - if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` - writer.add_histogram("ppo/val/ratio_hist", ratio, update) - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - writer.add_scalar("objective/kl_coef", kl_ctl.value, update) - writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) - writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) - writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) - writer.add_scalar( - "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update - ) - writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) - writer.add_scalar("objective/reference_scores", accelerator.gather(reference_scores.mean()).mean().item(), update) - writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) - writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) - 
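# A minimal sketch of the clipped PPO objective computed in the inner loop above,
# written for flat tensors; `cliprange`, `cliprange_value`, and `vf_coef` correspond
# to args.ppo.cliprange, args.ppo.cliprange_value, and args.ppo.vf_coef, and the
# default numbers below are placeholders rather than this run's configuration.
import torch

def ppo_loss(new_logprobs, old_logprobs, advantages, vpred, old_values, returns,
             cliprange=0.2, cliprange_value=0.2, vf_coef=0.1):
    ratio = torch.exp(new_logprobs - old_logprobs)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.max(pg_losses1, pg_losses2).mean()
    vpredclipped = torch.clamp(vpred, old_values - cliprange_value, old_values + cliprange_value)
    vf_loss = 0.5 * torch.max((vpred - returns) ** 2, (vpredclipped - returns) ** 2).mean()
    return pg_loss + vf_coef * vf_loss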
writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) - writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) - writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) - writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkls_stats).mean().item(), update) - writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(clipfracs_stats).mean().item(), update) - writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_losses_stats).mean().item(), update) - writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_losses_stats).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) - writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropies_stats).mean().item(), update) - writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) - writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) - writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) - writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) - writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) - writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) - writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) - writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio.mean()).mean().item(), update) - writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio.mean()).var().item(), update) - writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) - writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) - writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) - writer.add_scalar("ppo/lr", lrnow, update) - writer.add_scalar("ppo/episode", global_step, update) - if args.rewards.use_adaptive_kl: - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores - - # save model - if accelerator.is_main_process and args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - torch.save(policy.state_dict(), args.save_path) - - if args.upload_model: - repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) - tokenizer.save_pretrained(repo_id, push_to_hub=True) - -# if __name__ == "__main__": -# args = tyro.cli(Args) -# train(args) diff --git a/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py b/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py deleted file mode 100644 index 9ff6abe..0000000 --- a/lm_human_preference_details/summarize_old/train_reward_accelerate_summarize.py +++ /dev/null @@ -1,827 +0,0 @@ -import os -import random -import time -from dataclasses import asdict, dataclass, field -from types import SimpleNamespace -from 
typing import List, Literal, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import transformers -import tyro -from accelerate import Accelerator -from accelerate.state import AcceleratorState -from accelerate.utils import DistributedDataParallelKwargs, broadcast -from datasets import load_dataset -from rich.console import Console -from rich.pretty import pprint -from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler - -from lm_human_preference_details.data import process_query - - -@dataclass -class LabelHParams: - type: str = None - num_train: int = 92832 - num_labels: int = 2 - source: str = None - - -@dataclass -class TaskHParams: - # Query params - query_length: int = 512 - query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered" - - query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" - query_truncate_field: Optional[str] = "post" - query_truncate_text: Optional[str] = "\n" - query_padding: Optional[str] = None # defaults to repeated spaces - query_pad_side: Optional[str] = "left" - - # Response params - response_length: int = 48 - - # LM params - temperature: float = 0.7 - - -# a patch -@dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None - - -@dataclass -class Args: - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - seed: int = 1 - """seed of the experiment""" - track: bool = False - """if toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "cleanrl" - """the wandb's project name""" - wandb_entity: Optional[str] = None - """the entity (team) of wandb's project""" - cuda: bool = True - """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" - load_from_cache_file: bool = False - """Whether to load data from the local cache file in `dataset.map`""" - - base_model: str = "gpt2" - """the name of the pretrained model to use""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - label_dataset: str = "openai/summarize_from_feedback" - """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 4 - """per rank batch size""" - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - lr: float = 0.00005 - """the learning rate""" - eps: float = 1e-5 - """the epsilon for AdamW""" - local_rollout_batch_size: int = 512 - """per rank rollout batch size""" - rollout_batch_size: tyro.conf.Suppress[int] = None - """rollout batch size""" - world_size: tyro.conf.Suppress[int] = None - """the number of processes to use""" - batch_size: tyro.conf.Suppress[int] = None - """the batch 
size across all ranks""" - local_normalize_samples: int = 256 - """Samples used to estimate reward mean and std""" - normalize_samples: tyro.conf.Suppress[int] = None - """Samples used to estimate reward mean and std across all ranks""" - debug_normalize: int = 0 - """Samples used to check that normalization worked""" - normalize_before: bool = True - """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. (For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = True - """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 300 - """How often to print sample output""" - sft_model_path: str = "models/sft_policy" - """Where to load the SFT model""" - logsigmoid: bool = True - """Whether to use log-sigmoid loss instead of cross-entropy loss""" - trainable_param_percentage: float = 1.0 - """Percentage of parameters to train""" - num_epochs: int = 1 - """Number of epochs to train""" - num_updates: tyro.conf.Suppress[int] = None - """Number of updates to train""" - save_path: str = "models/reward" - """Where to save the model""" - optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" - """Which optimizer to use""" - scheduler: str = "constant_with_warmup" - """Which scheduler to use""" - warm_up_steps: int = 100 - """Number of warm up steps for the scheduler""" - task: TaskHParams = field(default_factory=TaskHParams) - labels: LabelHParams = field(default_factory=LabelHParams) - - -def first_true_indices(bools, dtype=torch.long): - """ - Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving - the position of the first True in each "row". - - Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
- """ - row_len = bools.size(-1) - zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) - return torch.min(zero_or_index, dim=-1).values - - -def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: - table = Table(show_lines=True) - for column in df.columns: - table.add_column(column) - for _, row in df.iterrows(): - table.add_row(*row.astype(str).tolist()) - console.rule(f"[bold red]{title}") - console.print(table) - - -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - 
exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - ) - # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - last_reward_latents = output.hidden_states[-1] - # shape: [batch_size, hidden_size] - reward = self.scalar_head(last_reward_latents) - return output, reward - - -def right_padding_to_left_padding(tokens, pad_id): - """Convert from right padding to left padding.""" - assert tokens.ndim == 2 - return torch.tensor( - [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens], - device=tokens.device, - ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def generate(lm_backbone, queries, tokenizer, generation_config): - """generate in a way that does not affect padding tokens""" - context_length = queries.shape[1] - attention_mask = queries != tokenizer.pad_token_id - input_ids = queries.clone() - input_ids[~attention_mask] = 0 # set padding tokens to 0 - output = lm_backbone.generate( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? 
- generation_config=generation_config, - return_dict_in_generate=True, - ) - # restore padding tokens - return torch.cat((queries, output.sequences[:, context_length:]), dim=1) - - -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - - -def get_reward_complete(reward_model, query_responses, tokenizer): - reward = get_reward(reward_model, query_responses, tokenizer)[1] - last_response_indices = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - last_response_indices = torch.max( - last_response_indices, - torch.zeros([1], dtype=last_response_indices.dtype, device=query_responses.device), - ) - return reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1), reward - - -def normalize( - tokenizer, - accelerator, - device, - lm_backbone, - reward_model, - dataloader, - validation_dataloader, -): - idx = 0 - with torch.no_grad(): - # reset reward scales - accelerator.unwrap_model(reward_model).reward_gain.data.fill_(1.0) - accelerator.unwrap_model(reward_model).reward_bias.data.fill_(0.0) - # number of minibatches for computing the normalization statistics - rewards = [] - for data in dataloader: - idx += len(data["query_token"]) - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - accelerator.print(f"====number of samples per device: {idx}") - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"mean: {mean}, std: {std}") - - # reward normalization - target_mean, target_std = torch.tensor(0.0, device=device), torch.tensor(1.0, device=device) - gain = target_std / std - bias = target_mean - gain * mean - print(f"gain: {gain}, bias: {bias}") - accelerator.unwrap_model(reward_model).reward_gain.data = gain - accelerator.unwrap_model(reward_model).reward_bias.data = bias - - # validate normalization - rewards = [] - for data in validation_dataloader: - queries = data["query_token"].to(device) - queries = right_padding_to_left_padding(queries, tokenizer.pad_token_id).to(device) - reference_response = data["reference_response"].to(device) - query_responses = torch.cat((queries, reference_response), dim=1) - score = get_reward_complete(reward_model, query_responses, tokenizer) - rewards.append(score) - rewards = torch.cat(rewards) - rewards = accelerator.gather(rewards) - mean, std = rewards.mean(), rewards.std() - print(f"after mean: {mean}, after std: {std}") - - -def evaluate(args, accelerator, device, reward_model, validation_label): - # reward_model.eval() - with torch.no_grad(): - # eval on validation_label, some duplicate code (I don't want to make the training loop into a function...) 
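# A small sketch of how get_reward_complete above picks one scalar per sequence: the
# per-token reward has shape (batch, seq_len, 1), and the score used for training is
# the reward at the last token before padding starts. The tensors here are made-up
# examples with an assumed pad_token_id of 0.
import torch

pad_token_id = 0
query_responses = torch.tensor([[5, 7, 9, 0, 0],    # padding starts at index 3
                                [4, 6, 8, 2, 3]])   # no padding at all
row_len = query_responses.size(-1)
is_pad = query_responses == pad_token_id
first_pad = torch.min(row_len * (~is_pad).long() + torch.arange(row_len), dim=-1).values
last_response_indices = (first_pad - 1).clamp(min=0)   # -> tensor([2, 4])
reward = torch.randn(2, row_len, 1)                    # stand-in per-token reward
scores = reward[:, :, 0].gather(1, last_response_indices.unsqueeze(1)).view(-1)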
- test_accuracies = [] - eval_len = len(validation_label) - len_labels = (eval_len // args.batch_size) * args.batch_size # in case the last batch is not full - new_all_inds = np.arange(len_labels) - for start in range(0, len_labels, args.batch_size): - end = start + args.batch_size - b_inds_all = new_all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - for micro_batch_start in range(0, args.local_batch_size, args.local_micro_batch_size): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = validation_label[micro_batch_inds] - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_query = right_padding_to_left_padding(mb_query, args.pad_token_id).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) for i in range(args.labels.num_labels) - ] - predicted_reward = [] - for i in range(args.labels.num_labels): - query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - score, _ = get_reward_complete(reward_model, query_responses, args) - predicted_reward.append(score) - predicted_reward = torch.stack( - predicted_reward, dim=1 - ) # shape (batch_size, num_labels), basically a reward prediction for each label - accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() - test_accuracies.append(accuracy) - test_accuracy = accelerator.gather(torch.stack(test_accuracies).mean()).mean().item() - # reward_model.train() - return test_accuracy - - -def train(args: Args): - accelerator = Accelerator( - kwargs_handlers=[ - DistributedDataParallelKwargs( - broadcast_buffers=False, - # find_unused_parameters=True, - ) - ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) - args.world_size = accelerator.num_processes - args.batch_size = int(args.local_batch_size * args.world_size) - args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) - args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) - args.num_updates = args.labels.num_train // args.batch_size - patch_h = TaskQueryHParams( - length=args.task.query_length, - dataset=args.task.query_dataset, - format_str=args.task.query_format_str, - truncate_field=args.task.query_truncate_field, - truncate_text=args.task.query_truncate_text, - padding=args.task.query_padding, - pad_side=args.task.query_pad_side, - ) - - console = Console(force_terminal=True) - run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - writer = SimpleNamespace() # dummy writer - writer.add_scalar = lambda x, y, z: None - if accelerator.is_main_process: - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=asdict(args), - name=run_name, - save_code=True, - ) - wandb.run.log_code(".") - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - pprint(args) - device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime - random.seed(local_seed) - np.random.seed(local_seed) - torch.manual_seed(local_seed) - 
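# A minimal sketch of the pairwise preference objective the reward model is trained
# and evaluated with below, assuming two candidate summaries per query and made-up
# example rewards: the loss is -log sigmoid(r_chosen - r_rejected) and accuracy is
# how often the human-chosen response receives the higher predicted reward.
import torch
import torch.nn.functional as F

predicted_reward = torch.tensor([[1.2, 0.3], [0.1, 0.9]])  # shape (batch, num_labels=2)
mb_best = torch.tensor([0, 1])                             # index of the human-preferred response
reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1)
reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1)
loss = -F.logsigmoid(reward_preferred - reward_rejected).mean()
accuracy = (predicted_reward.argmax(1) == mb_best).float().mean()  # 1.0 for this example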
torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - args.pad_token_id = tokenizer.pad_token_id - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True) - ) - - # freeze the first 70% of layers - if args.trainable_param_percentage < 1.0: - layers = reward_model.lm_backbone.transformer.h - num_layers = len(layers) - num_unfrozen = int(args.trainable_param_percentage * num_layers) - for layer in layers[:-num_unfrozen]: - layer.requires_grad_(False) - - if args.sft_model_path: - reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded SFT model from {args.sft_model_path}") - reward_model.lm_backbone.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - reward_model.lm_backbone.generation_config.pad_token_id = None # generate tokens without truncation / padding - # make sure the `lm_head` or `embed_out` does not require gradients, otherwise - # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 - if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): - reward_model.lm_backbone.embed_out.requires_grad_(False) - if args.optimizer == "tf_adam": - optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) - elif args.optimizer == "adam": - optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) - elif args.optimizer == "adamw": - optimizer = optim.AdamW(reward_model.parameters(), lr=args.lr, eps=args.eps) - # TODO: use AdamW - scheduler = get_scheduler( - args.scheduler, - optimizer=optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_updates * args.num_epochs, - ) - - if args.deepspeed: - pass - - deepspeed_states = AcceleratorState().deepspeed_plugin - deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - - reward_model, optimizer, scheduler = accelerator.prepare(reward_model, optimizer, scheduler) - if args.normalize_before: - dataset = load_dataset(args.task.query_dataset, split="train") - validation_dataset = load_dataset(args.task.query_dataset, split="validation") - - def process_query_data(x): - return { - **process_query(x, encoder=tokenizer, hparams=patch_h), - "reference_response": tokenizer.encode( - f" {x['summary']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - # with an extra leading space to account for the space between the query and response - ), - } - - dataset = dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.local_rollout_batch_size) - validation_dataset = validation_dataset.map(process_query_data, load_from_cache_file=args.load_from_cache_file) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response"]) - validation_dataset = validation_dataset.shuffle(seed=local_seed) - validation_dataloader = DataLoader(validation_dataset, 
batch_size=args.local_rollout_batch_size) - dataloader = accelerator.prepare(dataloader) - iter(dataloader) - print("===Normalize reward model *before* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # `label` has keys `['sample0', 'query', 'best', 'sample3', 'sample1', 'sample2']` - label = load_dataset(args.label_dataset, "comparisons", split="train") - validation_label = load_dataset(args.label_dataset, "comparisons", split="validation") - dev_validation_label = validation_label.filter(lambda x: x["split"] == "valid1") - eval_validation_label = validation_label.filter(lambda x: x["split"] == "valid2") - accelerator.print("Num labels found in source:", len(label)) - accelerator.print("training on", args.labels.num_train, "in batches of", args.local_batch_size) - - def process_response_data(x): - return { - **process_query(x["info"], encoder=tokenizer, hparams=patch_h), - "response0_token": tokenizer.encode( - f" {x['summaries'][0]['text']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - ), - "response1_token": tokenizer.encode( - f" {x['summaries'][1]['text']}<|endoftext|>", - padding="max_length", - max_length=args.task.response_length, - truncation=True, - ), - } - - label = label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - dev_validation_label = dev_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - eval_validation_label = eval_validation_label.map(process_response_data, load_from_cache_file=args.load_from_cache_file) - # TODO: check if all labels have eos token - accelerator.print("===training reward model===") - num_train = (args.labels.num_train // args.batch_size) * args.batch_size - for epoch in range(args.num_epochs): - all_inds = np.random.permutation(args.labels.num_train) - # ensure that all processes have the same shuffled indices - all_inds = broadcast(torch.tensor(all_inds, device=device), 0) - all_inds = all_inds.cpu().numpy() - accelerator.print(f"epoch: {epoch}") - for (epoch_global_step, start) in enumerate(range(0, num_train, args.batch_size)): - global_step = epoch * args.num_updates + epoch_global_step - end = start + args.batch_size - b_inds_all = all_inds[start:end] - b_inds = b_inds_all[accelerator.process_index :: accelerator.num_processes] # multi-GPU slicing - # accelerator.print(f"global_step: {global_step}, start: {start}, end: {end}, b_inds: {b_inds}") - if accelerator.is_main_process: - pprint( - { - "global_step": global_step, - "start:end": f"{start}:{end}", - "b_inds_all": b_inds_all, - "b_inds": b_inds, - } - ) - losses = torch.zeros((args.gradient_accumulation_steps,), device=device) - accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) - reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) - reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) - gradient_accumulation_step = 0 - # reward_model.train() - for micro_batch_start in range(0, args.local_batch_size, 
args.local_micro_batch_size): - with accelerator.accumulate(reward_model): - micro_batch_end = micro_batch_start + args.local_micro_batch_size - micro_batch_inds = b_inds[micro_batch_start:micro_batch_end] - mb_data = label[micro_batch_inds] - # pprint({ - # "micro_batch_start:micro_batch_end": f"{micro_batch_start}:{micro_batch_end}", - # "micro_batch_inds": micro_batch_inds, - # }) - mb_query = torch.from_numpy(np.stack(mb_data["query_token"])).to(device) - mb_best = torch.from_numpy(np.stack(mb_data["choice"])).to(device) - mb_responses = [ - torch.from_numpy(np.stack(mb_data[f"response{i}_token"])).to(device) - for i in range(args.labels.num_labels) - ] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, len(mb_responses), 1) - query_responses = torch.cat([mb_query_tiled, torch.stack(mb_responses).transpose(0, 1)], dim=2).flatten( - 0, 1 - ) - predicted_reward, reward = get_reward_complete(reward_model, query_responses, tokenizer) - predicted_reward = predicted_reward.view( - -1, len(mb_responses) - ) # TODO check shape for no gradienta ccumulation steps - - # print(tokenizer.decode(mb_query[0])) - # print(tokenizer.decode(mb_responses[0][0])) - # print(tokenizer.decode(mb_responses[1][0])) - # predicted_reward = [] - # rewards = [] - # for i in range(args.labels.num_labels): - # query_responses = torch.cat([mb_query, mb_responses[i]], dim=1) - # score, reward = get_reward_complete(reward_model, query_responses, tokenizer) - # rewards.append(reward.squeeze(-1)) - # predicted_reward.append(score) - # # shape (batch_size, num_labels), basically a reward prediction for each label - # predicted_reward = torch.stack(predicted_reward, dim=1) - # breakpoint() - accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() - reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) - reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) - if args.logsigmoid: - loss = -F.logsigmoid(reward_preferred - reward_rejected).mean() - else: - loss = F.cross_entropy(predicted_reward, mb_best) - accelerator.backward(loss) - - # for k, v in reward_model.named_parameters(): - # if v.requires_grad: - # if v.grad is None: - # print(f"found unused param: {k}") - - optimizer.step() # accelerate handles gradient accumulation automatically - optimizer.zero_grad() - scheduler.step() - losses[gradient_accumulation_step] = loss - accuracies[gradient_accumulation_step] = accuracy - reward_preferreds[gradient_accumulation_step] = reward_preferred.mean() - reward_rejecteds[gradient_accumulation_step] = reward_rejected.mean() - gradient_accumulation_step += 1 - - train_accuracy = accelerator.gather(accuracies).mean().item() - writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) - writer.add_scalar("train/accuracy", train_accuracy, global_step) - writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) - writer.add_scalar("train/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) - lr = scheduler.get_last_lr() - writer.add_scalar("train/lr", np.array(lr).mean().item(), global_step) - accelerator.print("train/accuracy", train_accuracy) - - # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - if global_step == args.num_updates - 1: # first and last update - dev_validation_accuracy = evaluate(args, accelerator, device, reward_model, dev_validation_label) - writer.add_scalar("dev_validation/accuracy", 
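# A minimal standalone check (toy numbers, not from the patch) that the two loss
# branches above agree when there are exactly two candidate summaries: cross-entropy
# over the pair of predicted rewards with the preferred index as target reduces to
# -logsigmoid(reward_preferred - reward_rejected).
import torch
import torch.nn.functional as F

predicted_reward = torch.tensor([[1.3, -0.2], [0.1, 0.9]])   # (batch, num_labels=2)
mb_best = torch.tensor([0, 1])                               # index of the preferred summary

reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1)
reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1)

loss_logsigmoid = -F.logsigmoid(reward_preferred - reward_rejected).mean()
loss_cross_entropy = F.cross_entropy(predicted_reward, mb_best)
assert torch.allclose(loss_logsigmoid, loss_cross_entropy)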
dev_validation_accuracy, global_step) - accelerator.print("dev_validation/accuracy", dev_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, eval_validation_label) - writer.add_scalar("eval_validation/accuracy", eval_validation_accuracy, global_step) - accelerator.print("eval_validation/accuracy", eval_validation_accuracy, global_step) - eval_validation_accuracy = evaluate(args, accelerator, device, reward_model, label) - writer.add_scalar("train_full/accuracy", eval_validation_accuracy, global_step) - accelerator.print("train_full/accuracy", eval_validation_accuracy, global_step) - - torch.cuda.empty_cache() - if args.normalize_after: - print("===Normalize reward model *after* training===") - print( - "before normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - normalize( - tokenizer, - accelerator, - device, - reward_model, - reward_model, - dataloader, - validation_dataloader, - ) - print( - "after normalization. " - + f"Gain: {accelerator.unwrap_model(reward_model).reward_gain.data}" - + f" Bias: {accelerator.unwrap_model(reward_model).reward_bias.data}" - ) - - # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - # torch.save(accelerator.unwrap_model(reward_model).state_dict(), args.save_path) - accelerator.save_model(reward_model, args.save_path) - - if accelerator.is_main_process and args.track: - wandb.finish() - - -if __name__ == "__main__": - args = tyro.cli(Args) - train(args) From f5dc7cdc8543335d12a400ed5ed2cc55715c1c6f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 8 Dec 2023 14:53:33 +0000 Subject: [PATCH 35/62] seems successful --- ...n_policy_accelerate_summarize_separate3.py | 968 ++++++++++++++++++ 1 file changed, 968 insertions(+) create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate3.py diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py new file mode 100644 index 0000000..b5b6eef --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py @@ -0,0 +1,968 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, +) + + +INVALID_LOGPROB = 1.0 + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "" + label_dataset: 
tyro.conf.Suppress[Optional[str]] = None + dataset_mean: float = 0. + dataset_std: float = 1. + kl_coef: float = 0.15 + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 64 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate response after the first occurrence of this token at or after index after when sampling. + truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/ppo_policy" + """Where to save the model""" + optimizer: 
Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class 
AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + # self.scalar_head = layer_init( + # nn.Linear(lm_backbone.config.hidden_size, 1), + # std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + # ) + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) - self.reward_bias + return reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + 
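# A standalone restatement of the AdaptiveKLController above with toy numbers (the
# hparams object is flattened here for brevity): the per-update multiplier is bounded
# to 1 +/- 0.2 * n_steps / horizon, so the KL coefficient drifts up when the measured
# KL exceeds the target and back down when it falls below it.
import numpy as np

class ToyAdaptiveKLController:
    def __init__(self, init_kl_coef, target, horizon):
        self.value, self.target, self.horizon = init_kl_coef, target, horizon

    def update(self, current, n_steps):
        proportional_error = np.clip(current / self.target - 1, -0.2, 0.2)
        self.value *= 1 + proportional_error * n_steps / self.horizon

ctl = ToyAdaptiveKLController(init_kl_coef=0.15, target=6.0, horizon=10000)
ctl.update(current=12.0, n_steps=512)   # measured KL too high -> coefficient grows
print(ctl.value)                        # ~0.1515
ctl.update(current=3.0, n_steps=512)    # measured KL too low -> coefficient shrinks
print(ctl.value)                        # ~0.1500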
attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # sequence_lengths1 = ( + # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + # query_responses.device + # ) + # print(f"======={sequence_lengths1=} {sequence_lengths=}") + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our 
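# A toy example (made-up token ids, not from the patch) of what first_true_indices
# and truncate_response above do: everything strictly after the first occurrence of
# truncate_token_id is overwritten with the pad token, and rows that never emit the
# token are left unchanged.
import torch

truncate_token_id, pad_token_id, response_length = 0, 50257, 6
responses = torch.tensor([
    [11, 12, 0, 14, 15, 16],    # truncate token at index 2
    [21, 22, 23, 24, 25, 26],   # no truncate token
])

bools = responses == truncate_token_id
row_len = bools.size(-1)
zero_or_index = row_len * (~bools).long() + torch.arange(row_len)
trunc_idxs = torch.min(zero_or_index, dim=-1).values.unsqueeze(-1)   # tensor([[2], [6]])
idxs = torch.arange(response_length).view(1, response_length)
print(torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id))
# tensor([[   11,    12,     0, 50257, 50257, 50257],
#         [   21,    22,    23,    24,    25,    26]])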
`args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + if args.rewards.trained_model: + # TODO: i did not load the critic + # critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + # critic.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + reward_model.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + # policy.gradient_checkpointing_enable() + # accelerator.print(policy) + # critic.lm_backbone.gradient_checkpointing_enable() + # accelerator.print(critic) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "tf_adam": + optimizer = 
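# A worked example (hypothetical 8-GPU run, not from the patch) of the batch-size
# bookkeeping done earlier in this file with the PpoHParams defaults
# (local_batch_size=64, nminibatches=1, gradient_accumulation_steps=64); exact_div
# raises if any of these quantities do not divide evenly.
def exact_div(a, b):
    q = a // b
    if a != q * b:
        raise ValueError(f"Inexact division: {a} / {b} = {a / b}")
    return q

world_size = 8
local_batch_size, nminibatches, gradient_accumulation_steps = 64, 1, 64

batch_size = local_batch_size * world_size                              # 512 episodes per PPO update
minibatch_size = exact_div(batch_size, nminibatches)                    # 512
local_mini_batch_size = exact_div(local_batch_size, nminibatches)       # 64 per rank
local_micro_batch_size = exact_div(local_mini_batch_size,
                                   gradient_accumulation_steps)         # 1 sequence per backward pass
print(batch_size, minibatch_size, local_mini_batch_size, local_micro_batch_size)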
AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + + offload = False + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if offload: + eval_ds_config["zero_optimization"] = { + "stage": 3, + "stage3_param_persistence_threshold": 1e4, + "offload_param": { + "device": "cpu" + } + } + accelerator.print(f"{eval_ds_config=}") + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) + with torch.no_grad(): + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + _, sample_validation_reference_scores, _ = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.eval() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + torch.cuda.empty_cache() + + # TODO: do I do this with query response or post-processed query response? + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + # sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + # actual_start = torch.arange(postprocessed_responses.size(0), device=postprocessed_responses.device) + # actual_end = sequence_lengths + # padding_mask = postprocessed_responses == tokenizer.pad_token_id + + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + # values_mask = postprocessed_responses != args.task.truncate_token_id + # values = torch.masked_fill(values, values_mask, 0) + # values = torch.masked_fill(values, padding_mask, 0) + + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + # TODO: do we need to deal with penalty values? + # penalty_values = torch.full_like(values, 0) + # penalty_values[:,-1] += args.task.penalty_reward_value + # values = torch.where(contain_pad_token, values, penalty_values) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") + # torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + # kl = torch.masked_fill(kl, padding_mask, 0) + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. 
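# A toy illustration (made-up numbers, not from the patch) of the reward tensor
# assembled above: a dense per-token KL penalty -kl_coef * (logprob - ref_logprob),
# plus the scalar reward-model score added only at the final response position, with
# the score replaced by penalty_reward_value for responses that never emitted the
# truncate token.
import torch

kl_coef, penalty_reward_value = 0.15, -1.0
logprobs     = torch.tensor([[-1.2, -0.8, -2.0],
                             [-0.5, -1.5, -0.9]])   # policy log-probs per response token
ref_logprobs = torch.tensor([[-1.0, -1.0, -1.9],
                             [-0.6, -1.4, -1.1]])   # SFT reference log-probs
scores = torch.tensor([0.7, 1.3])                   # reward-model scores per sequence
contain_pad_token = torch.tensor([True, False])     # did the truncated response end properly?

scores = torch.where(contain_pad_token, scores, torch.full_like(scores, penalty_reward_value))
kl = logprobs - ref_logprobs
non_score_reward = -kl_coef * kl
rewards = non_score_reward.clone()
rewards[:, -1] += scores                            # scalar score only at the last token
print(rewards)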
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses, skip_special_tokens=True + ) + all_sample_validation_postprocessed_responses = [ + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process: + all_sample_validation_df.to_json(f"runs/{run_name}/table.json") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + # del postprocessed_query_responses + # torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + # raise + # pprint({ + # "rewards": rewards, + # "returns": returns, + # "advantages": advantages, + # }) + # breakpoint() + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + # vpred = torch.masked_fill(vpred, values_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, 
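# A small numeric check (toy numbers, not from the patch) of the advantage loop
# above, which is standard GAE computed right-to-left over the response tokens:
# delta_t = r_t + gamma * V_{t+1} - V_t and A_t = delta_t + gamma * lam * A_{t+1},
# with returns = advantages + values. Uses the PpoHParams defaults gamma=1, lam=0.95.
import torch

gamma, lam = 1.0, 0.95
rewards = torch.tensor([[0.0, 0.0, 1.5]])   # KL penalties ~0, score on the last token
values  = torch.tensor([[0.2, 0.4, 0.9]])   # critic predictions per token

lastgaelam, advantages_reversed = 0.0, []
gen_length = rewards.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
print(advantages)   # ~tensor([[1.2165, 1.0700, 0.6000]])
print(returns)      # ~tensor([[1.4165, 1.4700, 1.5000]])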
dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl_stats[:ppo_epoch_idx+1].mean().item(), + "pg_loss", + pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + "pg_clipfrac", + pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + "ratio", + ratio_stats[:ppo_epoch_idx+1].mean().item(), + ) + # raise + # breakpoint() + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + 
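# A self-contained sketch (toy numbers, not from the patch) of the clipped PPO
# objectives used in the inner loop above: the policy loss takes the max of the
# unclipped and clipped surrogates (a pessimistic bound), and the value loss is
# clipped around the rollout-time value prediction in the same spirit.
import torch

cliprange, cliprange_value, vf_coef = 0.2, 0.2, 0.1
mb_advantage = torch.tensor([0.8, -0.3])
mb_logprobs  = torch.tensor([-1.1, -0.7])    # log-probs at rollout time
new_logprobs = torch.tensor([-0.9, -0.9])    # log-probs under the updated policy
mb_values    = torch.tensor([0.5, 0.2])
vpred        = torch.tensor([0.9, 0.1])
mb_return    = torch.tensor([1.0, 0.0])

logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses  = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()

vpredclipped = torch.clamp(vpred, mb_values - cliprange_value, mb_values + cliprange_value)
vf_losses1 = torch.square(vpred - mb_return)
vf_losses2 = torch.square(vpredclipped - mb_return)
vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean()

loss = pg_loss + vf_coef * vf_loss
approxkl = 0.5 * (logprobs_diff**2).mean()   # the estimator logged above
print(loss, approxkl)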
writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio_stats).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio_stats).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") + + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) From a46b8a6717af10ae161915c19a9f4bcfb9f0f369 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 12 Dec 2023 15:15:44 +0000 Subject: [PATCH 36/62] push changes --- ...n_policy_accelerate_summarize_separate4.py | 977 ++++++++++++++++ ...elerate_summarize_separate5_load_critic.py | 976 ++++++++++++++++ ...ummarize_separate6_correct_reward_index.py | 979 ++++++++++++++++ ...te7_correct_reward_index_no_load_critic.py | 979 ++++++++++++++++ ...parate8_correct_reward_index_deepspeed3.py | 1018 +++++++++++++++++ 5 files changed, 4929 insertions(+) create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate4.py create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py create mode 100644 lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py diff --git 
a/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py new file mode 100644 index 0000000..49d023a --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py @@ -0,0 +1,977 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, +) + + +INVALID_LOGPROB = 1.0 + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + dataset_mean: float = 0. + dataset_std: float = 1. + kl_coef: float = 0.15 + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 64 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/ppo_policy" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize 
else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, 
shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + # self.scalar_head = layer_init( + # nn.Linear(lm_backbone.config.hidden_size, 1), + # std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + # ) + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) - self.reward_bias + return reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # sequence_lengths1 = ( + # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + # query_responses.device + # ) + # print(f"======={sequence_lengths1=} {sequence_lengths=}") + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + 
writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + if args.rewards.trained_model: + # TODO: i did not load the critic + # critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + # critic.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + reward_model.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy.gradient_checkpointing_enable() + accelerator.print(policy) + critic.lm_backbone.gradient_checkpointing_enable() + accelerator.print(critic) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = 
validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + + offload = False + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if offload: + eval_ds_config["zero_optimization"] = { + "stage": 3, + "stage3_param_persistence_threshold": 1e4, + "offload_param": { + "device": "cpu" + } + } + accelerator.print(f"{eval_ds_config=}") + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) + with torch.no_grad(): + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + _, sample_validation_reference_scores, _ = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
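For intuition, the adaptive KL controller constructed a few lines above nudges `kl_ctl.value` toward the target KL on every update. A minimal standalone sketch of one update step; the measured KL and step count are invented, while the other numbers mirror this script's defaults:

import numpy as np

kl_coef, target, horizon = 0.15, 6.0, 10_000   # RewardHParams.kl_coef and AdaptiveKLParams defaults
measured_kl, n_steps = 12.0, 512               # hypothetical per-batch KL well above target

proportional_error = np.clip(measured_kl / target - 1, -0.2, 0.2)  # clipped to +0.2 here
kl_coef *= 1 + proportional_error * n_steps / horizon
print(round(kl_coef, 4))  # 0.1515, i.e. the KL penalty is strengthened slightly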
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 + validation_generation_config= GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + validation_generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + torch.cuda.empty_cache() + + # TODO: do I do this with query response or post-processed query response? + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. 
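The log-probability bookkeeping just above reduces to a log_softmax over the vocabulary followed by a gather at the sampled response tokens. A self-contained sketch with toy shapes (batch 2, response length 3, vocabulary 5); the tensors are random stand-ins, not values from the run:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, resp_len, vocab = 2, 3, 5
logits = torch.randn(batch, resp_len, vocab)          # stand-in for output.logits[:, context_length - 1 : -1]
responses = torch.randint(vocab, (batch, resp_len))   # stand-in for the sampled response tokens

all_logprobs = F.log_softmax(logits, dim=-1)                                   # (batch, resp_len, vocab)
logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)  # (batch, resp_len)

assert logprobs.shape == (batch, resp_len)
# each entry is log pi(response_token_t | prefix), the quantity the PPO ratio is built from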
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + # sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + # actual_start = torch.arange(postprocessed_responses.size(0), device=postprocessed_responses.device) + # actual_end = sequence_lengths + # padding_mask = postprocessed_responses == tokenizer.pad_token_id + + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + # values_mask = postprocessed_responses != args.task.truncate_token_id + # values = torch.masked_fill(values, values_mask, 0) + # values = torch.masked_fill(values, padding_mask, 0) + + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + # TODO: do we need to deal with penalty values? + # penalty_values = torch.full_like(values, 0) + # penalty_values[:,-1] += args.task.penalty_reward_value + # values = torch.where(contain_pad_token, values, penalty_values) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") + # torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + # kl = torch.masked_fill(kl, padding_mask, 0) + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. 
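Step 4 above turns the per-token KL estimate into a dense penalty and adds the scalar reward-model score only at the final response position. A toy version of that assembly; all numbers are invented, and kl_coef is pinned at the 0.15 default rather than taken from the adaptive controller:

import torch

kl_coef = 0.15
logprobs = torch.tensor([[-1.0, -2.0, -0.5]])        # policy log-probs for one 3-token response
ref_logprobs = torch.tensor([[-1.2, -1.5, -0.7]])    # reference-policy log-probs
score = torch.tensor([0.8])                          # scalar reward-model score for the sequence

kl = logprobs - ref_logprobs                         # per-token KL estimate
rewards = -kl_coef * kl                              # dense KL penalty at every position
rewards[:, -1] += score                              # sparse score only on the last token
print(rewards)  # tensor([[-0.0300,  0.0750,  0.7700]])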
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses, skip_special_tokens=True + ) + all_sample_validation_postprocessed_responses = [ + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process: + all_sample_validation_df.to_json(f"runs/{run_name}/table.json") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + # del postprocessed_query_responses + # torch.cuda.empty_cache() + + # 6. 
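The loop that follows computes GAE advantages and returns over the fixed-length response. Here is a minimal standalone version of that backward recursion; gamma=1.0 and lam=0.95 mirror the PpoHParams defaults, while the reward and value numbers are invented:

import torch

gamma, lam = 1.0, 0.95
rewards = torch.tensor([[0.0, 0.0, 0.5]])   # e.g. KL penalties plus a final score, length 3
values = torch.tensor([[0.1, 0.2, 0.3]])    # critic predictions for the same positions

lastgaelam = 0.0
advantages_reversed = []
gen_length = rewards.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
print(advantages)  # tensor([[0.3755, 0.2900, 0.2000]])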
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + # raise + # pprint({ + # "rewards": rewards, + # "returns": returns, + # "advantages": advantages, + # }) + # breakpoint() + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + # vpred = torch.masked_fill(vpred, values_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, 
dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl_stats[:ppo_epoch_idx+1].mean().item(), + "pg_loss", + pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + "pg_clipfrac", + pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + "ratio", + ratio_stats[:ppo_epoch_idx+1].mean().item(), + ) + # raise + # breakpoint() + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + 
writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio_stats).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio_stats).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") + + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py new file mode 100644 index 0000000..ea30982 --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py @@ -0,0 +1,976 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from 
transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+)
+
+
+INVALID_LOGPROB = 1.0
+
+
+@dataclass
+class AdaptiveKLParams:
+    target: float = 6.0
+    horizon: int = 10000  # in episodes
+
+
+@dataclass
+class RewardHParams:
+    use_adaptive_kl: bool = True
+    adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams)
+    trained_model: Optional[str] = ""
+    label_dataset: tyro.conf.Suppress[Optional[str]] = None
+    dataset_mean: float = 0.
+    dataset_std: float = 1.
+    kl_coef: float = 0.15
+
+
+@dataclass
+class PpoHParams:
+    total_episodes: int = 1000000
+    local_batch_size: int = 64
+    local_mini_batch_size: tyro.conf.Suppress[int] = None
+    batch_size: tyro.conf.Suppress[int] = None
+    mini_batch_size: tyro.conf.Suppress[int] = None
+    gradient_accumulation_steps: int = 64
+    """gradient accumulation steps"""
+    local_micro_batch_size: tyro.conf.Suppress[int] = None
+    """per rank micro batch size"""
+    world_size: tyro.conf.Suppress[int] = None
+    minibatch_size: tyro.conf.Suppress[int] = None
+    num_updates: tyro.conf.Suppress[int] = None
+    nminibatches: int = 1
+    noptepochs: int = 4
+    lr: float = 0.00001
+    eps: float = 1e-5
+    vf_coef: float = 0.1
+    cliprange: float = 0.2
+    cliprange_value: float = 0.2
+    gamma: float = 1
+    lam: float = 0.95
+    whiten_rewards: bool = True
+
+
+@dataclass
+class TaskHParams:
+    # Query params
+    query_length: int = 512
+    query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53"
+
+    query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
+    query_truncate_field: Optional[str] = "post"
+    query_truncate_text: Optional[str] = "\n"
+    query_padding: Optional[str] = None  # defaults to repeated spaces
+    query_pad_side: Optional[str] = "left"
+
+    # Response params
+    response_length: int = 53
+
+    # Truncate response after the first occurrence of this token at or after index `truncate_after` when sampling.
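+    # Illustrative sketch (token names are hypothetical): with truncate_token="eos", a sampled
+    # response like [tok, tok, <eos>, junk, junk] is post-processed by `truncate_response` below
+    # into [tok, tok, <eos>, <pad>, <pad>], i.e. everything after the first truncate token is
+    # replaced with the pad token before the reward model scores it.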
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/ppo_policy" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize 
else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, 
shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + # self.scalar_head = layer_init( + # nn.Linear(lm_backbone.config.hidden_size, 1), + # std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + # ) + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) - self.reward_bias + return reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # sequence_lengths1 = ( + # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + # query_responses.device + # ) + # print(f"======={sequence_lengths1=} {sequence_lengths=}") + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + 
writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + if args.rewards.trained_model: + critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + critic.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + reward_model.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy.gradient_checkpointing_enable() + accelerator.print(policy) + critic.lm_backbone.gradient_checkpointing_enable() + accelerator.print(critic) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = validation_dataset.with_format("torch", 
columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + + offload = False + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if offload: + eval_ds_config["zero_optimization"] = { + "stage": 3, + "stage3_param_persistence_threshold": 1e4, + "offload_param": { + "device": "cpu" + } + } + accelerator.print(f"{eval_ds_config=}") + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) + with torch.no_grad(): + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + _, sample_validation_reference_scores, _ = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 + validation_generation_config= GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + validation_generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + torch.cuda.empty_cache() + + # TODO: do I do this with query response or post-processed query response? + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. 
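The reward-model call that follows relies on `get_reward` picking out the reward logit at the last non-padding position of each sequence. A toy sketch of that indexing; the pad id, shapes and values are all invented for illustration:

import torch

pad_token_id = 0                                   # assumed pad id for this sketch only
query_responses = torch.tensor([[5, 7, 9, 0, 0],
                                [4, 6, 8, 3, 0]])  # two right-padded sequences
reward_logits = torch.arange(10.0).view(2, 5, 1)   # stand-in for the reward head output

# last non-pad position = (index of the first pad token) - 1, as in get_reward above
bools = query_responses == pad_token_id
row_len = bools.size(-1)
first_pad = torch.min(row_len * (~bools).long() + torch.arange(row_len), dim=-1).values
sequence_lengths = first_pad - 1                   # tensor([2, 3])

scores = reward_logits[torch.arange(2), sequence_lengths].squeeze(-1)
print(scores)                                      # tensor([2., 8.])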
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + # sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + # actual_start = torch.arange(postprocessed_responses.size(0), device=postprocessed_responses.device) + # actual_end = sequence_lengths + # padding_mask = postprocessed_responses == tokenizer.pad_token_id + + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + # values_mask = postprocessed_responses != args.task.truncate_token_id + # values = torch.masked_fill(values, values_mask, 0) + # values = torch.masked_fill(values, padding_mask, 0) + + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + # TODO: do we need to deal with penalty values? + # penalty_values = torch.full_like(values, 0) + # penalty_values[:,-1] += args.task.penalty_reward_value + # values = torch.where(contain_pad_token, values, penalty_values) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") + # torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + # kl = torch.masked_fill(kl, padding_mask, 0) + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + rewards[:, -1] += scores + + # 5. 
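The reward whitening applied right below rescales to unit variance while keeping the mean, because shift_mean=False. A small numeric check of that behaviour; the whiten helper is restated so the snippet runs on its own, and the reward values are invented:

import torch

def whiten(values, shift_mean=True):
    # same recipe as the helper defined earlier; unbiased=False matches tf.nn.moments
    mean, var = torch.mean(values), torch.var(values, unbiased=False)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened

rewards = torch.tensor([[0.0, 1.0, 2.0, 3.0]])
out = whiten(rewards, shift_mean=False)
print(out.mean().item())               # ~1.5, mean preserved
print(out.var(unbiased=False).item())  # ~1.0, variance normalised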
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses, skip_special_tokens=True + ) + all_sample_validation_postprocessed_responses = [ + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process: + all_sample_validation_df.to_json(f"runs/{run_name}/table.json") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + # del postprocessed_query_responses + # torch.cuda.empty_cache() + + # 6. 
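A few lines further down, the PPO epoch loop optimises the clipped surrogate objective on each micro-batch. A toy sketch of the clipped policy loss on three tokens; the advantages and log-probs are invented, and cliprange=0.2 matches the default:

import torch

cliprange = 0.2                                         # PpoHParams.cliprange default
mb_advantage = torch.tensor([1.0, -1.0, 0.5])
new_logprobs = torch.tensor([-1.0, -1.2, -0.3])
mb_logprobs = torch.tensor([-1.4, -1.0, -0.3])

ratio = torch.exp(new_logprobs - mb_logprobs)           # pi_new / pi_old per token
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()       # pessimistic (clipped) surrogate
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()   # fraction of tokens where clipping bit
print(pg_loss.item(), pg_clipfrac.item())               # ~ -0.2938 and ~ 0.3333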
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + # raise + # pprint({ + # "rewards": rewards, + # "returns": returns, + # "advantages": advantages, + # }) + # breakpoint() + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + # vpred = torch.masked_fill(vpred, values_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, 
dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl_stats[:ppo_epoch_idx+1].mean().item(), + "pg_loss", + pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + "pg_clipfrac", + pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + "ratio", + ratio_stats[:ppo_epoch_idx+1].mean().item(), + ) + # raise + # breakpoint() + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + 
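# the *_avg metrics summarize the per-microbatch statistics (indexed by PPO epoch, minibatch, and gradient-accumulation step) averaged over the whole update +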
writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio_stats).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio_stats).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") + + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py new file mode 100644 index 0000000..7331a66 --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py @@ -0,0 +1,979 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import DataLoader +from torch.utils.tensorboard 
import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, +) + + +INVALID_LOGPROB = 1.0 + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + dataset_mean: float = 0. + dataset_std: float = 1. + kl_coef: float = 0.15 + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 64 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate the response after the first occurrence of this token at or after index `truncate_after` when sampling.
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/ppo_policy" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize 
else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, 
shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + # self.scalar_head = layer_init( + # nn.Linear(lm_backbone.config.hidden_size, 1), + # std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + # ) + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) - self.reward_bias + return reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this so that we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row.
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # sequence_lengths1 = ( + # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + # query_responses.device + # ) + # print(f"======={sequence_lengths1=} {sequence_lengths=}") + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + 
writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + if args.rewards.trained_model: + critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + critic.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + reward_model.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy.gradient_checkpointing_enable() + accelerator.print(policy) + critic.lm_backbone.gradient_checkpointing_enable() + accelerator.print(critic) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = validation_dataset.with_format("torch", 
columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + + offload = False + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if offload: + eval_ds_config["zero_optimization"] = { + "stage": 3, + "stage3_param_persistence_threshold": 1e4, + "offload_param": { + "device": "cpu" + } + } + accelerator.print(f"{eval_ds_config=}") + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) + with torch.no_grad(): + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + _, sample_validation_reference_scores, _ = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 + validation_generation_config= GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + validation_generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + torch.cuda.empty_cache() + + # TODO: do I do this with query response or post-processed query response? + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + # sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + # actual_start = torch.arange(postprocessed_responses.size(0), device=postprocessed_responses.device) + # actual_end = sequence_lengths + # padding_mask = postprocessed_responses == tokenizer.pad_token_id + sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + # values_mask = postprocessed_responses != args.task.truncate_token_id + # values = torch.masked_fill(values, values_mask, 0) + # values = torch.masked_fill(values, padding_mask, 0) + + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + # TODO: do we need to deal with penalty values? + # penalty_values = torch.full_like(values, 0) + # penalty_values[:,-1] += args.task.penalty_reward_value + # values = torch.where(contain_pad_token, values, penalty_values) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") + # torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + # kl = torch.masked_fill(kl, padding_mask, 0) + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = sequence_lengths + rewards[[actual_start, actual_end]] += scores + + # 5. 
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses, skip_special_tokens=True + ) + all_sample_validation_postprocessed_responses = [ + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process: + all_sample_validation_df.to_json(f"runs/{run_name}/table.json") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + # del postprocessed_query_responses + # torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + # raise + # pprint({ + # "rewards": rewards, + # "returns": returns, + # "advantages": advantages, + # }) + # breakpoint() + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + # vpred = torch.masked_fill(vpred, values_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, 
dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl_stats[:ppo_epoch_idx+1].mean().item(), + "pg_loss", + pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + "pg_clipfrac", + pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + "ratio", + ratio_stats[:ppo_epoch_idx+1].mean().item(), + ) + # raise + # breakpoint() + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + 
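# the *_avg metrics summarize the per-microbatch statistics (indexed by PPO epoch, minibatch, and gradient-accumulation step) averaged over the whole update +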
writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio_stats).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio_stats).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") + + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py new file mode 100644 index 0000000..a316656 --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py @@ -0,0 +1,979 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data 
import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, +) + + +INVALID_LOGPROB = 1.0 + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + dataset_mean: float = 0. + dataset_std: float = 1. + kl_coef: float = 0.15 + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 64 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate the response after the first occurrence of this token at or after index `truncate_after` when sampling.
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/ppo_policy" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize 
else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, 
shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + # self.scalar_head = layer_init( + # nn.Linear(lm_backbone.config.hidden_size, 1), + # std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + # ) + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) - self.reward_bias + return reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this so that we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row.
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # sequence_lengths1 = ( + # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + # query_responses.device + # ) + # print(f"======={sequence_lengths1=} {sequence_lengths=}") + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + 
writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + if args.rewards.trained_model: + # critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + # critic.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + reward_model.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy.gradient_checkpointing_enable() + accelerator.print(policy) + critic.lm_backbone.gradient_checkpointing_enable() + accelerator.print(critic) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = 
validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + + offload = False + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if offload: + eval_ds_config["zero_optimization"] = { + "stage": 3, + "stage3_param_persistence_threshold": 1e4, + "offload_param": { + "device": "cpu" + } + } + accelerator.print(f"{eval_ds_config=}") + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) + with torch.no_grad(): + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + _, sample_validation_reference_scores, _ = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) + + iter_dataloader = iter(repeat_generator()) + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 + validation_generation_config= GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + validation_generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + torch.cuda.empty_cache() + + # TODO: do I do this with query response or post-processed query response? + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + # sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + # actual_start = torch.arange(postprocessed_responses.size(0), device=postprocessed_responses.device) + # actual_end = sequence_lengths + # padding_mask = postprocessed_responses == tokenizer.pad_token_id + sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + # values_mask = postprocessed_responses != args.task.truncate_token_id + # values = torch.masked_fill(values, values_mask, 0) + # values = torch.masked_fill(values, padding_mask, 0) + + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + # TODO: do we need to deal with penalty values? + # penalty_values = torch.full_like(values, 0) + # penalty_values[:,-1] += args.task.penalty_reward_value + # values = torch.where(contain_pad_token, values, penalty_values) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") + # torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + # kl = torch.masked_fill(kl, padding_mask, 0) + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = sequence_lengths + rewards[[actual_start, actual_end]] += scores + + # 5. 
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses, skip_special_tokens=True + ) + all_sample_validation_postprocessed_responses = [ + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process: + all_sample_validation_df.to_json(f"runs/{run_name}/table.json") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + # del postprocessed_query_responses + # torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + # raise + # pprint({ + # "rewards": rewards, + # "returns": returns, + # "advantages": advantages, + # }) + # breakpoint() + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + # vpred = torch.masked_fill(vpred, values_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, 
dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl_stats[:ppo_epoch_idx+1].mean().item(), + "pg_loss", + pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + "pg_clipfrac", + pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + "ratio", + ratio_stats[:ppo_epoch_idx+1].mean().item(), + ) + # raise + # breakpoint() + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + 
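For reference, the `entropy` recorded a few lines above relies on the identity H(softmax(z)) = logsumexp(z) - sum(softmax(z) * z), and `approxkl` is the usual 0.5 * E[(log pi_new - log pi_old)^2] estimator of the policy KL. A quick standalone check of the entropy identity (illustrative only, not part of the patch):

import torch

z = torch.randn(4, 7)                                    # stand-in logits: (batch, vocab)
p = torch.softmax(z, dim=-1)
entropy_direct = -(p * torch.log(p)).sum(dim=-1)
entropy_identity = torch.logsumexp(z, dim=-1) - (p * z).sum(dim=-1)
print(torch.allclose(entropy_direct, entropy_identity, atol=1e-5))   # True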
writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio_stats).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio_stats).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") + + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py new file mode 100644 index 0000000..a1e860a --- /dev/null +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py @@ -0,0 +1,1018 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import Tensor, optim +from torch.optim.optimizer import ( + _dispatch_sqrt, + _get_value, + _use_grad_for_differentiable, +) +from torch.utils.data import 
DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, +) + + +INVALID_LOGPROB = 1.0 + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + use_adaptive_kl: bool = True + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + trained_model: Optional[str] = "" + label_dataset: tyro.conf.Suppress[Optional[str]] = None + dataset_mean: float = 0. + dataset_std: float = 1. + kl_coef: float = 0.15 + + +@dataclass +class PpoHParams: + total_episodes: int = 1000000 + local_batch_size: int = 64 + local_mini_batch_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + mini_batch_size: tyro.conf.Suppress[int] = None + gradient_accumulation_steps: int = 64 + """gradient accumulation steps""" + local_micro_batch_size: tyro.conf.Suppress[int] = None + """per rank micro batch size""" + world_size: tyro.conf.Suppress[int] = None + batch_size: tyro.conf.Suppress[int] = None + minibatch_size: tyro.conf.Suppress[int] = None + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + lr: float = 0.00001 + eps: float = 1e-5 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: tyro.conf.Suppress[str] = None + """TO BE FILLED: a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + upload_model: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 10 + """How often to print sample output""" + save_path: str = "models/ppo_policy" + """Where to save the model""" + optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + """Which optimizer to use""" + sft_model_path: str = "" + """Where to load the SFT model""" + task: TaskHParams = field(default_factory=TaskHParams) + rewards: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize 
else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + # update step + step_t += 1 + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + step = _get_value(step_t) + + ### pytorch adam implementation: + # bias_correction1 = 1 - beta1 ** step + # bias_correction2 = 1 - beta2 ** step + # step_size = lr / bias_correction1 + # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + # param.addcdiv_(exp_avg, denom, value=-step_size) + + ### tensorflow adam implementation: + lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) + denom = exp_avg_sq.sqrt().add_(eps) + param.addcdiv_(exp_avg, denom, value=-lr_t) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +): + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +class AdamTensorFlowStyle(optim.Adam): + @_use_grad_for_differentiable + def step(self, closure=None): + self._cuda_graph_capture_health_check() + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +def whiten(values, 
shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class AutoModelForCausalLMWithRewardHead(nn.Module): + def __init__(self, lm_backbone): + super().__init__() + self.lm_backbone = lm_backbone + # self.scalar_head = layer_init( + # nn.Linear(lm_backbone.config.hidden_size, 1), + # std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + # ) + self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) + # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False) + self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) - self.reward_bias + return reward + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def ceil_div(a, b): + return (a - 1) // b + 1 + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def get_reward(reward_model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = reward_model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # sequence_lengths1 = ( + # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( + # query_responses.device + # ) + # print(f"======={sequence_lengths1=} {sequence_lengths=}") + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + + +def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) + args.ppo.world_size = accelerator.num_processes + args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) + args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) + args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) + args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + if args.ppo.whiten_rewards: + assert ( + args.ppo.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` + # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` + args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + 
writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + local_seed = args.seed + accelerator.process_index * 100003 # Prime + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + if accelerator.is_main_process: + pprint(model_config) + setattr(model_config, "use_cache", False) + critic = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + reward_model = AutoModelForCausalLMWithRewardHead( + AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) + ) + if args.rewards.trained_model: + critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + critic.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) + reward_model.reward_bias.data = torch.tensor(args.rewards.dataset_mean) + print(f"loaded pretrained reward model from {args.rewards.trained_model}") + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + # policy.gradient_checkpointing_enable() + # critic.lm_backbone.gradient_checkpointing_enable() + accelerator.print(policy) + accelerator.print(critic) + if args.sft_model_path: + policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) + print(f"loaded pretrained policy from {args.sft_model_path}") + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "tf_adam": + optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + elif args.optimizer == "adamw": + if args.deepspeed: + deepspeed_states = AcceleratorState().deepspeed_plugin + from deepspeed.ops.adam import DeepSpeedCPUAdam + # if deepspeed_states.deepspeed_config['zero_optimization']['offload_optimizer']['device'] in ('none', None): + # return optim.AdamW(params, eps=self.opt.eps, betas=(self.opt.beta1, self.opt.beta2)) + optimizer = DeepSpeedCPUAdam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + else: + optimizer = 
optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.shuffle(seed=local_seed) + dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + # deepspeed_states = AcceleratorState().deepspeed_plugin + # # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + + # offload = False + # eval_ds_config = { + # "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + # "bf16": {"enabled": True}, + # "prescale_gradients": False, + # "wall_clock_breakdown": False, + # } + # if offload: + # eval_ds_config["zero_optimization"] = { + # "stage": 3, + # "stage3_param_persistence_threshold": 1e4, + # "offload_param": { + # "device": "cpu" + # } + # } + # accelerator.print(f"{eval_ds_config=}") + # reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + # ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + # reward_model = accelerator.prepare(reward_model) + # ref_policy = accelerator.prepare(ref_policy) + + deepspeed_plugin = accelerator.state.deepspeed_plugin + batch_size_per_device = deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] + deepspeed_plugin.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + # See DeepSpeed docs for definition of these parameters: https://deepspeed.readthedocs.io/en/latest/zero3.html + config_kwargs = { + "train_micro_batch_size_per_gpu": batch_size_per_device, + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "zero_optimization": { + "stage": 3, + "offload_param": {"device": deepspeed_plugin.offload_param_device}, + "stage3_prefetch_bucket_size": 0, + "stage3_max_live_parameters": 0, + "stage3_max_reuse_distance": 0, + }, + } + accelerator.print(config_kwargs) + reward_model, *_ = deepspeed.initialize(model=reward_model, config=config_kwargs) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=config_kwargs) + ref_policy.eval() + + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + reward_model.eval() + ref_policy.eval() + + def repeat_generator(): # TODO: ideally we shuffle the dataloader as well + while True: + yield from dataloader + iter_dataloader = iter(repeat_generator()) + + sample_validation_inds = np.arange(args.ppo.batch_size) + local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] + sample_validation = validation_dataset[local_sample_validation_inds] + sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) + with torch.no_grad(): + # sample_validation_queries = shift_pad_id_left(sample_validation_queries, 
tokenizer.pad_token_id) + sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) + sample_validation_query_reference_responses = torch.cat( + (sample_validation_queries, sample_validation_reference_response), dim=1 + ) + # sample_validation_query_reference_responses = shift_pad_id_left( + # sample_validation_query_reference_responses, tokenizer.pad_token_id + # ) + data = next(iter_dataloader) + queries = data["query_token"].to(device) + accelerator.print(f"==={queries.shape=}, {queries.dtype}") + accelerator.print(f"==={sample_validation_query_reference_responses.shape=}, {sample_validation_query_reference_responses.dtype}") + _, sample_validation_reference_scores, _ = get_reward( + reward_model, sample_validation_query_reference_responses, tokenizer + ) + + + kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 + validation_generation_config= GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + accelerator.print(f"----------------{optimizer=}") + accelerator.print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.ppo.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.ppo.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + # validation + sample_validation_query_responses = generate( + accelerator.unwrap_model(model).policy, + sample_validation_queries, + tokenizer, + validation_generation_config, + ) + sample_validation_responses = sample_validation_query_responses[:, context_length:] + postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) + postprocessed_sample_validation_query_responses = torch.cat( + (sample_validation_queries, postprocessed_sample_validation_responses), 1 + ) + torch.cuda.empty_cache() + + # TODO: do I do this with query response or post-processed query response? 
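The block that follows slices `output.logits[:, context_length - 1 : -1]` because, for a causal LM, the logits at position t score token t+1; the slice therefore lines up one logit vector with each generated response token, and `torch.gather` then pulls out the log-probability of the token that was actually sampled. A shape-only sketch (illustrative, not part of the patch; the random tensors stand in for real model outputs):

import torch

batch, context_length, response_length, vocab = 2, 5, 3, 11
query_responses = torch.randint(vocab, (batch, context_length + response_length))
responses = query_responses[:, context_length:]
logits = torch.randn(batch, context_length + response_length, vocab)    # stand-in for model logits

response_logits = logits[:, context_length - 1 : -1]                    # (batch, response_length, vocab)
all_logprobs = torch.log_softmax(response_logits, dim=-1)
logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
print(logprobs.shape)                                                    # torch.Size([2, 3])

In the patch this same slice-and-gather is applied three times (policy rollout, reference policy, and each PPO minibatch), always after dividing the logits by the sampling temperature so the logprobs match the distribution that was actually sampled from.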
+ output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= (args.task.temperature + 1e-7) + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + # sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + # actual_start = torch.arange(postprocessed_responses.size(0), device=postprocessed_responses.device) + # actual_end = sequence_lengths + # padding_mask = postprocessed_responses == tokenizer.pad_token_id + sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + + full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + # values_mask = postprocessed_responses != args.task.truncate_token_id + # values = torch.masked_fill(values, values_mask, 0) + # values = torch.masked_fill(values, padding_mask, 0) + + # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + + # TODO: do we need to deal with penalty values? + # penalty_values = torch.full_like(values, 0) + # penalty_values[:,-1] += args.task.penalty_reward_value + # values = torch.where(contain_pad_token, values, penalty_values) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") + # torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + # kl = torch.masked_fill(kl, padding_mask, 0) + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = sequence_lengths + rewards[[actual_start, actual_end]] += scores + + # 5. 
whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + try: + all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) + all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) + all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( + postprocessed_sample_validation_query_responses, skip_special_tokens=True + ) + all_sample_validation_postprocessed_responses = [ + x[len(y) :] + for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) + ] + all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) + all_sample_validation_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_sample_validation_responses, + "postprocessed_response": all_sample_validation_postprocessed_responses, + "reference_responses": all_sample_validation_reference_responses, + "scores": validation_score.float().cpu().numpy(), + "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), + } + ) + if accelerator.is_main_process: + all_sample_validation_df.to_json(f"runs/{run_name}/table.json") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + # print_rich_table("stuff", all_sample_validation_df[:4], console) + + except Exception as e: + print(e) + del ( + all_decode_validation_queries, + all_sample_validation_responses, + all_sample_validation_reference_responses, + all_sample_validation_df, + ) + # del postprocessed_query_responses + # torch.cuda.empty_cache() + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + # raise + # pprint({ + # "rewards": rewards, + # "returns": returns, + # "advantages": advantages, + # }) + # breakpoint() + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.ppo.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= (args.task.temperature + 1e-7) + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) + # vpred = torch.masked_fill(vpred, values_mask[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, 
dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl_stats[:ppo_epoch_idx+1].mean().item(), + "pg_loss", + pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + "pg_clipfrac", + pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + "ratio", + ratio_stats[:ppo_epoch_idx+1].mean().item(), + ) + # raise + # breakpoint() + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) + writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + 
writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio_stats).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio_stats).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.rewards.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + # save model + if args.save_path: + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") + + if args.upload_model and accelerator.is_main_process: + repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) + tokenizer.save_pretrained(repo_id, push_to_hub=True) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) From 601c755c83e1d41ce39d5c88f3e02fd04fcc3fb9 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 18 Dec 2023 04:42:43 +0000 Subject: [PATCH 37/62] update sft stuff --- .../train_sft_accelerate_summarize.py | 448 ++++++------------ 1 file changed, 154 insertions(+), 294 deletions(-) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 1f4e252..58b9234 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -1,4 +1,5 @@ import collections +import functools import os import random import time @@ -18,13 +19,8 @@ from rich.console import Console from rich.pretty import pprint from rich.table import Table -from torch import Tensor, optim +from torch import optim from torch.nn import functional as F -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from transformers import ( @@ -32,25 +28,11 @@ 
AutoModelForCausalLM, AutoTokenizer, GenerationConfig, + PreTrainedModel, get_scheduler, ) -@dataclass -class SFTHParams: - gradient_accumulation_steps: int = 16 - local_micro_batch_size: int = 1 - noptepochs: int = 1 - lr: float = 6.35e-5 - eps: float = 1e-5 - total_episodes: tyro.conf.Suppress[int] = None - micro_batch_size: tyro.conf.Suppress[int] = None - local_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - world_size: tyro.conf.Suppress[int] = None - num_updates: tyro.conf.Suppress[int] = None - - @dataclass class TaskHParams: # Query params @@ -90,11 +72,11 @@ class Args: """the entity (team) of wandb's project""" cuda: bool = True """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None + run_name: Optional[str] = None """TO BE FILLED: a unique name of this run""" load_from_cache_file: bool = False """Whether to load data from the local cache file in `dataset.map`""" - upload_model: bool = False + push_to_hub: bool = False "whether to upload the saved model to huggingface" hf_entity: str = "" "the user or org name of the model repository from the Hugging Face Hub" @@ -107,16 +89,41 @@ class Args: """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 220 """How often to print sample output""" - save_path: str = "models/sft_policy" + output_dir: str = "models/sft_policy" """Where to save the model""" - optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + optimizer: Literal["adam", "adamw"] = "adamw" """Which optimizer to use""" scheduler: str = "cosine" """Which scheduler to use""" warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" + run_eval: bool = False + """Whether to run evaluation""" + + local_micro_batch_size: int = 1 + """The micro batch size per GPU (HF's `per_device_train_batch_size`)""" + gradient_accumulation_steps: int = 16 + """The number of gradient accumulation steps""" + noptepochs: int = 1 + """The number of epochs to train""" + lr: float = 6.35e-5 + """The learning rate""" + eps: float = 1e-5 + """The epsilon value for the optimizer""" + + total_episodes: Optional[int] = None + """The total number of episodes in the dataset""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_updates: Optional[int] = None + """The number of updates to train""" task: TaskHParams = field(default_factory=TaskHParams) - sft: SFTHParams = field(default_factory=SFTHParams) # taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 @@ -138,165 +145,6 @@ def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: console.print(table) -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: 
float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - def generate(lm_backbone, queries, tokenizer, generation_config): """generate in a way that does not affect padding tokens""" context_length = queries.shape[1] @@ -327,17 +175,17 @@ def 
forward(policy, query_responses, tokenizer): # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) - accelerator = Accelerator(gradient_accumulation_steps=args.sft.gradient_accumulation_steps) - args.sft.world_size = accelerator.num_processes - args.sft.local_batch_size = args.sft.local_micro_batch_size * args.sft.gradient_accumulation_steps - args.sft.micro_batch_size = int(args.sft.local_micro_batch_size * args.sft.world_size) - args.sft.batch_size = int(args.sft.local_batch_size * args.sft.world_size) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + args.world_size = accelerator.num_processes + args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) dataset = load_dataset(args.task.query_dataset, split="train") validation_dataset = load_dataset(args.task.query_dataset, split="validation") accelerator.print("The number of samples in dataset", len(dataset)) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) - args.sft.total_episodes = len(dataset) - args.sft.num_updates = args.sft.total_episodes // args.sft.local_batch_size + args.total_episodes = len(dataset) + args.num_updates = args.total_episodes // args.local_batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -379,27 +227,25 @@ def forward(policy, query_responses, tokenizer): configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout if accelerator.is_main_process: pprint(model_config) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - if args.optimizer == "tf_adam": - optimizer = AdamTensorFlowStyle(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) - elif args.optimizer == "adam": - optimizer = optim.Adam(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + if args.optimizer == "adam": + optimizer = optim.Adam(policy.parameters(), lr=args.lr, eps=args.eps) elif args.optimizer == "adamw": - optimizer = optim.AdamW(policy.parameters(), lr=args.sft.lr, eps=args.sft.eps) + optimizer = optim.AdamW(policy.parameters(), lr=args.lr, eps=args.eps) scheduler = get_scheduler( args.scheduler, optimizer=optimizer, num_warmup_steps=args.warm_up_steps, - num_training_steps=args.sft.num_updates, + num_training_steps=args.num_updates, ) dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.sft.local_micro_batch_size) + dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.sft.local_micro_batch_size) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_micro_batch_size) policy, optimizer, dataloader, scheduler = accelerator.prepare( policy, 
optimizer, dataloader, scheduler ) @@ -416,15 +262,15 @@ def forward(policy, query_responses, tokenizer): ) rouge = evaluate.load("rouge") - print("===training policy===") - loss_stats = torch.zeros(args.sft.gradient_accumulation_steps, device=device) + accelerator.print("===training policy===") + loss_stats = torch.zeros(args.gradient_accumulation_steps, device=device) policy.train() gradient_accumulation_idx = 0 global_step = 0 update = 0 for data in dataloader: update += 1 - global_step += args.sft.micro_batch_size + global_step += args.micro_batch_size reference_responses = data["reference_response_token"].to(device, non_blocking=True) queries = data["query_token"].to(device, non_blocking=True) query_responses = torch.cat((queries, reference_responses), dim=1) @@ -443,102 +289,116 @@ def forward(policy, query_responses, tokenizer): optimizer.zero_grad() scheduler.step() loss_stats[gradient_accumulation_idx] = loss - gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.sft.gradient_accumulation_steps - if update > 1 and (update - 1) % args.sft.gradient_accumulation_steps == 0: + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps + if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) writer.add_scalar("lr", scheduler.get_last_lr()[0], update) accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") - # if update == args.sft.num_updates - 1: # update == 1 or - policy.eval() - rouge_scores = collections.defaultdict(list) - all_decode_validation_queries = [] - all_decode_validation_query_responses = [] - all_decode_validation_responses = [] - all_decode_validation_reference_responses = [] - all_validation_losses = [] - for validation_idx, validation_data in tqdm(enumerate(validation_dataloader)): - with torch.no_grad(): - validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) - validation_queries = validation_data["query_token"].to(device, non_blocking=True) - validation_query_reference_responses = torch.cat( - (validation_queries, validation_reference_responses), dim=1 + break + if args.run_eval: + policy.eval() + rouge_scores = collections.defaultdict(list) + all_decode_validation_queries = [] + all_decode_validation_query_responses = [] + all_decode_validation_responses = [] + all_decode_validation_reference_responses = [] + all_validation_losses = [] + for validation_idx, validation_data in tqdm(enumerate(validation_dataloader)): + with torch.no_grad(): + validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) + validation_queries = validation_data["query_token"].to(device, non_blocking=True) + validation_query_reference_responses = torch.cat( + (validation_queries, validation_reference_responses), dim=1 + ) + + validation_output = forward(policy, validation_query_reference_responses, tokenizer) + validation_labels = validation_query_reference_responses.masked_fill( + validation_query_reference_responses == tokenizer.pad_token_id, -1 + ) + validation_lm_logits = validation_output.logits + # hand-rolled transformer loss: Shift so that tokens < n predict n + # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` + validation_shift_logits = validation_lm_logits[..., :-1, :].contiguous() + validation_shift_labels = validation_labels[..., 1:].contiguous() + validation_loss = F.cross_entropy( + 
validation_shift_logits.view(-1, validation_shift_logits.size(-1)), + validation_shift_labels.view(-1), + ignore_index=-1, + ) + validation_loss = accelerator.gather(validation_loss) + all_validation_losses.append(validation_loss) + + generated_responses = generate( + accelerator.unwrap_model(policy), + validation_queries, + tokenizer, + generation_config, + ) + decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries)) + decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) + decode_validation_reference_responses = tokenizer.batch_decode( + accelerator.gather(validation_reference_responses) + ) + decode_validation_responses = tokenizer.batch_decode(accelerator.gather(generated_responses[:, -args.task.response_length:])) + rouge_score = rouge.compute( + predictions=decode_validation_responses, references=decode_validation_reference_responses + ) + rouge_scores["rouge1"].append(rouge_score["rouge1"]) + rouge_scores["rouge2"].append(rouge_score["rouge2"]) + rouge_scores["rougeL"].append(rouge_score["rougeL"]) + + all_decode_validation_queries.extend(decode_validation_queries) + all_decode_validation_query_responses.extend(decode_validation_query_responses) + all_decode_validation_responses.extend(decode_validation_responses) + all_decode_validation_reference_responses.extend(decode_validation_reference_responses) + try: + all_df = pd.DataFrame( + { + "query": all_decode_validation_queries, + "response": all_decode_validation_responses, + "reference": all_decode_validation_reference_responses, + } ) - - validation_output = forward(policy, validation_query_reference_responses, tokenizer) - validation_labels = validation_query_reference_responses.masked_fill( - validation_query_reference_responses == tokenizer.pad_token_id, -1 - ) - validation_lm_logits = validation_output.logits - # hand-rolled transformer loss: Shift so that tokens < n predict n - # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` - validation_shift_logits = validation_lm_logits[..., :-1, :].contiguous() - validation_shift_labels = validation_labels[..., 1:].contiguous() - validation_loss = F.cross_entropy( - validation_shift_logits.view(-1, validation_shift_logits.size(-1)), - validation_shift_labels.view(-1), - ignore_index=-1, - ) - validation_loss = accelerator.gather(validation_loss) - all_validation_losses.append(validation_loss) - - generated_responses = generate( - accelerator.unwrap_model(policy), - validation_queries, - tokenizer, - generation_config, - ) - decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries)) - decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) - decode_validation_reference_responses = tokenizer.batch_decode( - accelerator.gather(validation_reference_responses) - ) - decode_validation_responses = tokenizer.batch_decode(accelerator.gather(generated_responses[:, -args.task.response_length:])) - rouge_score = rouge.compute( - predictions=decode_validation_responses, references=decode_validation_reference_responses - ) - rouge_scores["rouge1"].append(rouge_score["rouge1"]) - rouge_scores["rouge2"].append(rouge_score["rouge2"]) - rouge_scores["rougeL"].append(rouge_score["rougeL"]) - - all_decode_validation_queries.extend(decode_validation_queries) - all_decode_validation_query_responses.extend(decode_validation_query_responses) - all_decode_validation_responses.extend(decode_validation_responses) - 
all_decode_validation_reference_responses.extend(decode_validation_reference_responses) - # if validation_idx == 10: - # break - try: - all_df = pd.DataFrame( - { - "query": all_decode_validation_queries, - "response": all_decode_validation_responses, - "reference": all_decode_validation_reference_responses, - } - ) - accelerator.print(all_df) - if accelerator.is_main_process and args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) - except Exception as e: - print(e) - - for k, v in rouge_scores.items(): - rouge_metric = torch.tensor(v, device=device) - rouge_metric = accelerator.gather(rouge_metric) - writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) - accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") - writer.add_scalar("validation_loss", torch.stack(all_validation_losses).mean().item(), update) - policy.train() + accelerator.print(all_df) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) + print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) + except Exception as e: + print(e) + + for k, v in rouge_scores.items(): + rouge_metric = torch.tensor(v, device=device) + rouge_metric = accelerator.gather(rouge_metric) + writer.add_scalar(f"rouge/{k}", rouge_metric.mean().item(), update) + accelerator.print(f"rouge/{k}: {rouge_metric.mean().item()} {rouge_metric.shape} {rouge_metric}") + writer.add_scalar("validation_loss", torch.stack(all_validation_losses).mean().item(), update) # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") - - if args.upload_model and accelerator.is_main_process: - repo_name = f"{args.exp_name}__tldr__seed{args.seed}__{int(time.time())}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) - tokenizer.save_pretrained(repo_id, push_to_hub=True) + if args.output_dir: + os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) + time_tensor = torch.tensor(int(time.time()), device=device) + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr__seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir, repo_id=repo_id) + if args.push_to_hub: + tokenizer.push_to_hub(repo_id, revision=str(time_int)) + + unwrapped: PreTrainedModel = accelerator.unwrap_model(policy) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(policy), + safe_serialization=False, + repo_id=repo_id, + ) + if args.push_to_hub: + unwrapped.push_to_hub(repo_id, revision=str(time_int), safe_serialization=False) # if __name__ == "__main__": # args = tyro.cli(Args) From 0f4dc10efc656ec28c0a9c8e5ad083b92d1120d8 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 18 Dec 2023 04:43:21 +0000 Subject: [PATCH 38/62] update dependencies --- benchmark/trl.slurm_template | 4 +- poetry.lock | 954 
+++++++++++++++++++++++------------ pyproject.toml | 14 +- 3 files changed, 642 insertions(+), 330 deletions(-) diff --git a/benchmark/trl.slurm_template b/benchmark/trl.slurm_template index 8f4a3d4..a7816a1 100644 --- a/benchmark/trl.slurm_template +++ b/benchmark/trl.slurm_template @@ -1,11 +1,11 @@ #!/bin/bash -#SBATCH --partition=production-cluster +#SBATCH --partition=hopper-prod #SBATCH --gpus-per-task={{gpus_per_task}} #SBATCH --cpus-per-gpu={{cpus_per_gpu}} #SBATCH --ntasks={{ntasks}} #SBATCH --output=slurm/logs/%x_%j.out #SBATCH --array={{array}} -#SBATCH --exclude=ip-26-0-149-199 +##SBATCH --exclude=ip-26-0-149-199 {{nodes}} diff --git a/poetry.lock b/poetry.lock index 2b044e6..503b091 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -13,31 +13,33 @@ files = [ [[package]] name = "accelerate" -version = "0.22.0" +version = "0.25.0" description = "Accelerate" optional = false python-versions = ">=3.8.0" files = [ - {file = "accelerate-0.22.0-py3-none-any.whl", hash = "sha256:d132e57bfc4b0417464997b14aa141fd88696cbb4472eb03116c2bd97542befc"}, - {file = "accelerate-0.22.0.tar.gz", hash = "sha256:2b0a83e3cd07c89448c5d5a94f72bc1db98d5e0c498ca17984871f01dbf83247"}, + {file = "accelerate-0.25.0-py3-none-any.whl", hash = "sha256:c7bb817eb974bba0ff3ea1ba0f24d55afb86d50e3d4fe98d6922dc69cf2ccff1"}, + {file = "accelerate-0.25.0.tar.gz", hash = "sha256:ecf55b0ab278a1dac8539dde0d276977aff04683f07ede73eaf02478538576a1"}, ] [package.dependencies] +huggingface-hub = "*" numpy = ">=1.17" packaging = ">=20.0" psutil = "*" pyyaml = "*" +safetensors = ">=0.3.1" torch = ">=1.10.0" [package.extras] -dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.0.241)", "scikit-learn", "scipy", "tqdm", "transformers", "urllib3 (<2.0.0)"] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.0.241)", "scikit-learn", "scipy", "timm", "tqdm", "transformers", "urllib3 (<2.0.0)"] quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.0.241)", "urllib3 (<2.0.0)"] rich = ["rich"] sagemaker = ["sagemaker"] -test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"] +test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "tqdm", "transformers"] test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] -test-trackers = ["comet-ml", "tensorboard", "wandb"] -testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "tqdm", "transformers"] +test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] +testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "tqdm", "transformers"] [[package]] name = "aiohttp" @@ -161,6 +163,20 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with 
typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + [[package]] name = "appdirs" version = "1.4.4" @@ -550,35 +566,6 @@ wrapt = "*" pytorch = ["torch (>=1.13.0)"] test = ["pytest", "tensorflow", "tensorflow-datasets", "torch (>=1.13.0,!=2.0.0)"] -[[package]] -name = "cmake" -version = "3.26.4" -description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software" -optional = false -python-versions = "*" -files = [ - {file = "cmake-3.26.4-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:230227bf99f36614de84cdc92ffce3a50eb2803020e946f8da945a08fcf766bf"}, - {file = "cmake-3.26.4-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:248a90816abfc10ff6e1109b54b8235c3e62f0ac92da16541753deb3b5ae063d"}, - {file = "cmake-3.26.4-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:1b92f9f59f48c803106dbdd6750b0f571a0500e25d3a62c42ba84bb7a9240d10"}, - {file = "cmake-3.26.4-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3175442985558d5415b97f264a6a1bb0af5ecfe10e3f7510257b1ea66bd33848"}, - {file = "cmake-3.26.4-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1d887be5f1a3f17559a78707a6bc0560f4f8cb93cebb9d823d90a63e68bae09b"}, - {file = "cmake-3.26.4-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:235d8eac93a28dcce5a1cd7130412885a2aa53d5735cb2230e0f26f589347b65"}, - {file = "cmake-3.26.4-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:05cfd76c637eb22058c95e2dc383cadd4e0615e2643e637bb498a6cc24825790"}, - {file = "cmake-3.26.4-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:93015da6f1c0e1e5f2debf752f1803ea52d742d915ad674043d36e471f937507"}, - {file = "cmake-3.26.4-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d726671ae7ae4aa6989e73d26b9f8f8e6af45163a26ea243949d72246566fdd8"}, - {file = "cmake-3.26.4-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:432837364aa6cab2826a72e8a4cdd3586f5ac9ce495217ccd59aa70f2bba8120"}, - {file = "cmake-3.26.4-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:24110035aff586a04a6a6fcf4609270642e4f503c0620c962dff75b653f81414"}, - {file = "cmake-3.26.4-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:3e280e81713408987b7053f5b922c9f94e45668ca6efff1f02846309ca0b5b0f"}, - {file = "cmake-3.26.4-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:c3b0e72750c0f6c0373242c1299bc4ffdbebdd5004966ae6df0b2e9845aa6990"}, - {file = "cmake-3.26.4-py2.py3-none-win32.whl", hash = "sha256:e058e59154a1e490fb9425b420f87e28144292397607638d73e323509f7efae6"}, - {file = "cmake-3.26.4-py2.py3-none-win_amd64.whl", hash = "sha256:b7a6946c345497c14064e0c9585b30f5aaebbefdfc0b245b6bb5a978eb4fc85f"}, - {file = "cmake-3.26.4-py2.py3-none-win_arm64.whl", hash = "sha256:93a03bad17b9741acaff4a8651f8596496506602fa123e70fe67142f1b21ee2e"}, - {file = "cmake-3.26.4.tar.gz", hash = "sha256:d45b30b9ce7280829888c78650177ab525df2b6785e1a5b3d82b4c147d828c0e"}, -] - -[package.extras] -test = ["coverage (>=4.2)", 
"flake8 (>=3.0.4)", "path.py (>=11.5.0)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)", "pytest-runner (>=2.9)", "pytest-virtualenv (>=1.7.0)", "scikit-build (>=0.10.0)", "setuptools (>=28.0.0)", "virtualenv (>=15.0.3)", "wheel"] - [[package]] name = "colorama" version = "0.4.6" @@ -701,6 +688,41 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "deepspeed" +version = "0.12.5" +description = "DeepSpeed library" +optional = false +python-versions = "*" +files = [ + {file = "deepspeed-0.12.5.tar.gz", hash = "sha256:7aca1e761f21792b49cbbb6b6ce6ef1cd5fb17d5738835aee3680b0a1c5a8234"}, +] + +[package.dependencies] +hjson = "*" +ninja = "*" +numpy = "*" +packaging = ">=20.0" +psutil = "*" +py-cpuinfo = "*" +pydantic = "*" +pynvml = "*" +torch = "*" +tqdm = "*" + +[package.extras] +1bit-mpi = ["mpi4py"] +all = ["accelerate", "autodoc_pydantic", "clang-format (==16.0.2)", "coverage", "deepspeed-kernels", "diffusers", "docutils (<0.18)", "future", "google", "hjson", "importlib-metadata (>=4)", "lm-eval (==0.3.0)", "mpi4py", "mup", "neural-compressor (==2.1.0)", "packaging", "pre-commit (>=2.20.0)", "protobuf", "psutil", "py-cpuinfo", "pydantic (<2.0.0)", "pytest", "pytest-forked", "pytest-randomly", "pytest-xdist", "recommonmark", "sphinx", "sphinx-rtd-theme", "sphinx_rtd_theme", "tabulate", "tensorboard", "torch", "torchvision", "tqdm", "transformers (>=4.32.1)", "transformers[sentencepiece]", "triton (==1.0.0)", "triton (>=2.1.0)", "wandb", "xgboost"] +autotuning = ["tabulate"] +autotuning-ml = ["hjson", "tabulate", "xgboost"] +dev = ["accelerate", "clang-format (==16.0.2)", "coverage", "deepspeed-kernels", "docutils (<0.18)", "future", "importlib-metadata (>=4)", "mup", "pre-commit (>=2.20.0)", "pytest", "pytest-forked", "pytest-randomly", "pytest-xdist", "recommonmark", "sphinx", "sphinx-rtd-theme", "tensorboard", "torchvision", "transformers (>=4.32.1)", "wandb"] +inf = ["google", "lm-eval (==0.3.0)", "protobuf", "transformers (>=4.32.1)", "transformers[sentencepiece]"] +readthedocs = ["autodoc_pydantic", "docutils (<0.18)", "hjson", "packaging", "psutil", "py-cpuinfo", "pydantic (<2.0.0)", "recommonmark", "sphinx_rtd_theme", "torch", "tqdm"] +sd = ["diffusers", "triton (>=2.1.0)"] +sparse = ["neural-compressor (==2.1.0)"] +sparse-attn = ["triton (==1.0.0)"] +triton = ["triton (>=2.1.0)"] + [[package]] name = "dill" version = "0.3.6" @@ -843,6 +865,42 @@ etree-jax = ["etils[etree]", "jax[cpu]"] etree-tf = ["etils[etree]", "tensorflow"] lazy-imports = ["etils[ecolab]"] +[[package]] +name = "evaluate" +version = "0.4.1" +description = "HuggingFace community-driven open-source library of evaluation" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "evaluate-0.4.1-py3-none-any.whl", hash = "sha256:3ff079ab09572c0a2c1e6d749887c19f6783ab993320412cd39f6fe501d28510"}, + {file = "evaluate-0.4.1.tar.gz", hash = "sha256:d721d9f2059ced79770d8a0509e954fbd1bbac96a8f9160e29888d8073cda3d9"}, +] + +[package.dependencies] +datasets = ">=2.0.0" +dill = "*" +fsspec = {version = ">=2021.05.0", extras = ["http"]} +huggingface-hub = ">=0.7.0" +multiprocess = "*" +numpy = ">=1.17" +packaging = "*" +pandas = "*" +requests = ">=2.19.0" +responses = "<0.19" +tqdm = ">=4.62.1" +xxhash = "*" + +[package.extras] +dev = ["Werkzeug (>=1.0.1)", "absl-py", "accelerate", "bert-score (>=0.3.6)", "black (>=22.0,<23.0)", "cer (>=1.2.0)", "charcut (>=1.1.1)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "jiwer", 
"mauve-text", "nltk", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "requests-file (>=1.5.1)", "rouge-score (>=0.1.2)", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1,<=2.10)", "texttable (>=1.6.3)", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "torch", "transformers", "trectools", "unidecode (>=1.3.4)"] +docs = ["s3fs"] +evaluator = ["scipy (>=1.7.1)", "transformers"] +quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"] +template = ["cookiecutter", "gradio (>=3.0.0)"] +tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"] +tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] +tests = ["Werkzeug (>=1.0.1)", "absl-py", "accelerate", "bert-score (>=0.3.6)", "cer (>=1.2.0)", "charcut (>=1.1.1)", "jiwer", "mauve-text", "nltk", "pytest", "pytest-datadir", "pytest-xdist", "requests-file (>=1.5.1)", "rouge-score (>=0.1.2)", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1,<=2.10)", "texttable (>=1.6.3)", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "torch", "transformers", "trectools", "unidecode (>=1.3.4)"] +torch = ["torch"] + [[package]] name = "exceptiongroup" version = "1.1.3" @@ -912,52 +970,6 @@ typing-extensions = ">=4.1.1" all = ["matplotlib"] testing = ["atari-py (==0.2.5)", "clu", "einops", "gym (==0.18.3)", "jaxlib", "jraph (>=0.0.6dev0)", "ml-collections", "mypy", "nbstripout", "opencv-python", "pytest", "pytest-cov", "pytest-custom-exit-code", "pytest-xdist (==1.34.0)", "pytype", "sentencepiece", "tensorflow", "tensorflow-datasets", "tensorflow-text (>=2.11.0)", "torch"] -[[package]] -name = "frozendict" -version = "2.3.8" -description = "A simple immutable dictionary" -optional = false -python-versions = ">=3.6" -files = [ - {file = "frozendict-2.3.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d188d062084fba0e4bf32719ff7380b26c050b932ff164043ce82ab90587c52b"}, - {file = "frozendict-2.3.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f2a4e818ac457f6354401dcb631527af25e5a20fcfc81e6b5054b45fc245caca"}, - {file = "frozendict-2.3.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a506d807858fa961aaa5b48dab6154fdc6bd045bbe9310788bbff141bb42d13"}, - {file = "frozendict-2.3.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:750632cc890d8ee9484fe6d31b261159144b6efacc08e1317fe46accd1410373"}, - {file = "frozendict-2.3.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7ee5fe2658a8ac9a57f748acaf563f6a47f80b8308cbf0a04fac0ba057d41f75"}, - {file = "frozendict-2.3.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23c4bb46e6b8246e1e7e49b5593c2bc09221db0d8f31f7c092be8dfb42b9e620"}, - {file = "frozendict-2.3.8-cp310-cp310-win_amd64.whl", hash = "sha256:c31abc8acea309b132dde441856829f6003a3d242da8b54bce4c0f2a3c8c63f0"}, - {file = "frozendict-2.3.8-cp310-cp310-win_arm64.whl", hash = "sha256:9ea5520e85447ff8d4681e181941e482662817ccba921b7cb3f87922056d892a"}, - {file = "frozendict-2.3.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f83fed36497af9562ead5e9fb8443224ba2781786bd3b92b1087cb7d0ff20135"}, - {file = "frozendict-2.3.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e27c5c1d29d0eda7979253ec88abc239da1313b38f39f4b16984db3b3e482300"}, - {file = "frozendict-2.3.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:e4c785de7f1a13f15963945f400656b18f057c2fc76c089dacf127a2bb188c03"}, - {file = "frozendict-2.3.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:8cf35ddd25513428ec152614def9696afb93ae5ec0eb54fa6aa6206eda77ac4c"}, - {file = "frozendict-2.3.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:ffc684773de7c88724788fa9787d0016fd75830412d58acbd9ed1a04762c675b"}, - {file = "frozendict-2.3.8-cp36-cp36m-win_amd64.whl", hash = "sha256:4c258aab9c8488338634f2ec670ef049dbf0ab0e7a2fa9bc2c7b5009cb614801"}, - {file = "frozendict-2.3.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:47fc26468407fdeb428cfc89495b7921419e670355c21b383765482fdf6c5c14"}, - {file = "frozendict-2.3.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ea638228692db2bf94bce40ea4b25f4077588497b516bd16576575560094bd9"}, - {file = "frozendict-2.3.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a75bf87e76c4386caecdbdd02a99e53ad43a6b5c38fb3d5a634a9fc9ce41462"}, - {file = "frozendict-2.3.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ed5a6c5c7a0f57269577c2a338a6002949aea21a23b7b7d06da7e7dced8b605b"}, - {file = "frozendict-2.3.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d086440328a465dea9bef2dbad7548d75d1a0a0d21f43a08c03e1ec79ac5240e"}, - {file = "frozendict-2.3.8-cp37-cp37m-win_amd64.whl", hash = "sha256:0bc4767e2f83db5b701c787e22380296977368b0c57e485ca71b2eedfa11c4a3"}, - {file = "frozendict-2.3.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:638cf363d3cbca31a341503cf2219eac52a5f5140449676fae3d9644cd3c5487"}, - {file = "frozendict-2.3.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2b2fd8ce36277919b36e3c834d2389f3cd7ac068ae730c312671dd4439a5dd65"}, - {file = "frozendict-2.3.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3957d52f1906b0c85f641a1911d214255873f6408ab4e5ad657cc27a247fb145"}, - {file = "frozendict-2.3.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72cfe08ab8ae524e54848fa90b22d02c1b1ecfb3064438696bcaa4b953f18772"}, - {file = "frozendict-2.3.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4742e76c4111bd09198d3ab66cef94be8506212311338f9182d6ef5f5cb60493"}, - {file = "frozendict-2.3.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:313ed8d9ba6bac35d7635cd9580ee5721a0fb016f4d2d20f0efa05dbecbdb1be"}, - {file = "frozendict-2.3.8-cp38-cp38-win_amd64.whl", hash = "sha256:d3c6ce943946c2a61501c8cf116fff0892d11dd579877eb36e2aea2c27fddfef"}, - {file = "frozendict-2.3.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f0f573dc4861dd7ec9e055c8cceaf45355e894e749f621f199aab7b311ac4bdb"}, - {file = "frozendict-2.3.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b3435e5f1ca5ae68a5e95e64b09d6d5c645cadd6b87569a0b3019dd248c8d00"}, - {file = "frozendict-2.3.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:145afd033ebfade28416093335261b8ec1af5cccc593482309e7add062ec8668"}, - {file = "frozendict-2.3.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da98427de26b5a2865727947480cbb53860089c4d195baa29c539da811cea617"}, - {file = "frozendict-2.3.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5e82befa7c385a668d569cebbebbdf49cee6fea4083f08e869a1b08cfb640a9f"}, - {file = "frozendict-2.3.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:80abe81d36e889ceec665e06ec764a7638000fa3e7be09786ac4d3ddc64b76db"}, - {file = "frozendict-2.3.8-cp39-cp39-win_amd64.whl", hash = 
"sha256:8ccc94ac781710db44e142e1a11ff9b31d02c032c01c6868d51fcbef73086225"}, - {file = "frozendict-2.3.8-cp39-cp39-win_arm64.whl", hash = "sha256:e72dbc1bcc2203cef38d205f692396f5505921a5680f66aa9a7e8bb71fd38f28"}, - {file = "frozendict-2.3.8-py311-none-any.whl", hash = "sha256:ba41a7ed019bd03b62d63ed3f8dea35b8243d1936f7c9ed4b5298ca45a01928e"}, - {file = "frozendict-2.3.8.tar.gz", hash = "sha256:5526559eca8f1780a4ee5146896f59afc31435313560208dd394a3a5e537d3ff"}, -] - [[package]] name = "frozenlist" version = "1.3.3" @@ -1221,20 +1233,31 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.54.2)"] +[[package]] +name = "hjson" +version = "3.1.0" +description = "Hjson, a user interface for JSON." +optional = false +python-versions = "*" +files = [ + {file = "hjson-3.1.0-py3-none-any.whl", hash = "sha256:65713cdcf13214fb554eb8b4ef803419733f4f5e551047c9b711098ab7186b89"}, + {file = "hjson-3.1.0.tar.gz", hash = "sha256:55af475a27cf83a7969c808399d7bccdec8fb836a07ddbd574587593b9cdcf75"}, +] + [[package]] name = "huggingface-hub" -version = "0.15.1" +version = "0.19.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.15.1-py3-none-any.whl", hash = "sha256:05b0fb0abbf1f625dfee864648ac3049fe225ac4371c7bafaca0c2d3a2f83445"}, - {file = "huggingface_hub-0.15.1.tar.gz", hash = "sha256:a61b7d1a7769fe10119e730277c72ab99d95c48d86a3d6da3e9f3d0f632a4081"}, + {file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"}, + {file = "huggingface_hub-0.19.4.tar.gz", hash = "sha256:176a4fc355a851c17550e7619488f383189727eab209534d7cef2114dae77b22"}, ] [package.dependencies] filelock = "*" -fsspec = "*" +fsspec = ">=2023.5.0" packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -1242,15 +1265,17 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pytest", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pytest", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions 
(>=4.8.0)", "urllib3 (<2.0)"] +docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)", "watchdog"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] -quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] +inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "gradio", "jedi", "numpy", "pytest", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["torch"] -typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] [[package]] name = "identify" @@ -1489,6 +1514,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.3.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, + {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, +] + [[package]] name = "jupyter-client" version = "8.3.0" @@ -1532,16 +1568,6 @@ traitlets = ">=5.3" docs = ["myst-parser", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] -[[package]] -name = "lit" -version = "16.0.6" -description = "A Software Testing Tool" -optional = false -python-versions = "*" -files = [ - {file = "lit-16.0.6.tar.gz", hash = "sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a"}, -] - [[package]] name = "markdown" version = "3.4.3" @@ -1610,6 +1636,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + 
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1712,8 +1748,8 @@ files = [ [package.dependencies] numpy = [ {version = ">1.20", markers = "python_version <= \"3.9\""}, - {version = ">=1.23.3", markers = "python_version > \"3.10\""}, {version = ">=1.21.2", markers = "python_version > \"3.9\" and python_version <= \"3.10\""}, + {version = ">=1.23.3", markers = "python_version > \"3.10\""}, ] [package.extras] @@ -1957,6 +1993,58 @@ doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx- extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] +[[package]] +name = "ninja" +version = "1.11.1.1" +description = "Ninja is a small build system with a focus on speed" +optional = false +python-versions = "*" +files = [ + {file = "ninja-1.11.1.1-py2.py3-none-macosx_10_9_universal2.macosx_10_9_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:376889c76d87b95b5719fdd61dd7db193aa7fd4432e5d52d2e44e4c497bdbbee"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux1_i686.manylinux_2_5_i686.whl", hash = "sha256:ecf80cf5afd09f14dcceff28cb3f11dc90fb97c999c89307aea435889cb66877"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:84502ec98f02a037a169c4b0d5d86075eaf6afc55e1879003d6cab51ced2ea4b"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:73b93c14046447c7c5cc892433d4fae65d6364bec6685411cb97a8bcf815f93a"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:18302d96a5467ea98b68e1cae1ae4b4fb2b2a56a82b955193c637557c7273dbd"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:aad34a70ef15b12519946c5633344bc775a7656d789d9ed5fdb0d456383716ef"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_aarch64.whl", hash = 
"sha256:d491fc8d89cdcb416107c349ad1e3a735d4c4af5e1cb8f5f727baca6350fdaea"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:7563ce1d9fe6ed5af0b8dd9ab4a214bf4ff1f2f6fd6dc29f480981f0f8b8b249"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:9df724344202b83018abb45cb1efc22efd337a1496514e7e6b3b59655be85205"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:3e0f9be5bb20d74d58c66cc1c414c3e6aeb45c35b0d0e41e8d739c2c0d57784f"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:76482ba746a2618eecf89d5253c0d1e4f1da1270d41e9f54dfbd91831b0f6885"}, + {file = "ninja-1.11.1.1-py2.py3-none-win32.whl", hash = "sha256:fa2ba9d74acfdfbfbcf06fad1b8282de8a7a8c481d9dee45c859a8c93fcc1082"}, + {file = "ninja-1.11.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:95da904130bfa02ea74ff9c0116b4ad266174fafb1c707aa50212bc7859aebf1"}, + {file = "ninja-1.11.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:185e0641bde601e53841525c4196278e9aaf4463758da6dd1e752c0a0f54136a"}, + {file = "ninja-1.11.1.1.tar.gz", hash = "sha256:9d793b08dd857e38d0b6ffe9e6b7145d7c485a42dcfea04905ca0cdb6017cc3c"}, +] + +[package.extras] +test = ["codecov (>=2.0.5)", "coverage (>=4.2)", "flake8 (>=3.0.4)", "pytest (>=4.5.0)", "pytest-cov (>=2.7.1)", "pytest-runner (>=5.1)", "pytest-virtualenv (>=1.7.0)", "virtualenv (>=15.0.3)"] + +[[package]] +name = "nltk" +version = "3.8.1" +description = "Natural Language Toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, + {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"}, +] + +[package.dependencies] +click = "*" +joblib = "*" +regex = ">=2021.8.3" +tqdm = "*" + +[package.extras] +all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"] +corenlp = ["requests"] +machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"] +plot = ["matplotlib"] +tgrep = ["pyparsing"] +twitter = ["twython"] + [[package]] name = "nodeenv" version = "1.8.0" @@ -2009,137 +2097,113 @@ files = [ ] [[package]] -name = "nvidia-cublas-cu11" -version = "11.10.3.66" +name = "nvidia-cublas-cu12" +version = "12.1.3.1" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"}, - {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cuda-cupti-cu11" -version = "11.7.101" +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" description = "CUDA profiling tools runtime libs." 
optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:e0cfd9854e1f2edaa36ca20d21cd0bdd5dcfca4e3b9e130a082e05b33b6c5895"}, - {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-win_amd64.whl", hash = "sha256:7cc5b8f91ae5e1389c3c0ad8866b3b016a175e827ea8f162a672990a402ab2b0"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cuda-nvrtc-cu11" -version = "11.7.99" +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"}, - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"}, - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cuda-runtime-cu11" -version = "11.7.99" +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"}, - {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cudnn-cu11" -version = "8.5.0.96" +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" description = "cuDNN runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, - {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, ] [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-cublas-cu12 = "*" [[package]] -name = "nvidia-cufft-cu11" -version = "10.9.0.58" +name = "nvidia-cufft-cu12" +version = 
"11.0.2.54" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl", hash = "sha256:222f9da70c80384632fd6035e4c3f16762d64ea7a843829cb278f98b3cb7dd81"}, - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-win_amd64.whl", hash = "sha256:c4d316f17c745ec9c728e30409612eaf77a8404c3733cdf6c9c1569634d1ca03"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, ] [[package]] -name = "nvidia-curand-cu11" -version = "10.2.10.91" +name = "nvidia-curand-cu12" +version = "10.3.2.106" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:eecb269c970fa599a2660c9232fa46aaccbf90d9170b96c462e13bcb4d129e2c"}, - {file = "nvidia_curand_cu11-10.2.10.91-py3-none-win_amd64.whl", hash = "sha256:f742052af0e1e75523bde18895a9ed016ecf1e5aa0ecddfcc3658fd11a1ff417"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cusolver-cu11" -version = "11.4.0.1" +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:72fa7261d755ed55c0074960df5904b65e2326f7adce364cbe4945063c1be412"}, - {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:700b781bfefd57d161443aff9ace1878584b93e0b2cfef3d6e9296d96febbf99"}, - {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-win_amd64.whl", hash = "sha256:00f70b256add65f8c1eb3b6a65308795a93e7740f6df9e273eccbba770d370c4"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, ] [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" [[package]] -name = "nvidia-cusparse-cu11" -version = "11.7.4.91" +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:a3389de714db63321aa11fbec3919271f415ef19fda58aed7f2ede488c32733d"}, - {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-win_amd64.whl", hash = "sha256:304a01599534f5186a8ed1c3756879282c72c118bc77dd890dc1ff868cad25b9"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = 
"sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, ] [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-ml-py" @@ -2153,29 +2217,36 @@ files = [ ] [[package]] -name = "nvidia-nccl-cu11" -version = "2.14.3" +name = "nvidia-nccl-cu12" +version = "2.18.1" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:5e5534257d1284b8e825bc3a182c6f06acd6eb405e9f89d49340e98cd8f136eb"}, + {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, ] [[package]] -name = "nvidia-nvtx-cu11" -version = "11.7.91" -description = "NVIDIA Tools Extension" +name = "nvidia-nvjitlink-cu12" +version = "12.3.101" +description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:b22c64eee426a62fc00952b507d6d29cf62b4c9df7a480fcc417e540e05fd5ac"}, - {file = "nvidia_nvtx_cu11-11.7.91-py3-none-win_amd64.whl", hash = "sha256:dfd7fcb2a91742513027d63a26b757f38dd8b07fecac282c4d132a9d373ff064"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] [[package]] name = "nvitop" @@ -2564,6 +2635,17 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +description = "Get CPU info with pure Python" +optional = false +python-versions = "*" +files = [ + {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"}, + {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, +] + [[package]] name = "pyarrow" version = "12.0.0" @@ -2637,6 +2719,142 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +[[package]] +name = "pydantic" +version = "2.5.2" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic-2.5.2-py3-none-any.whl", hash = "sha256:80c50fb8e3dcecfddae1adbcc00ec5822918490c99ab31f6cf6140ca1c1429f0"}, + {file = "pydantic-2.5.2.tar.gz", hash = "sha256:ff177ba64c6faf73d7afa2e8cad38fd456c0dbe01c9954e71038001cd15a6edd"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.14.5" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.14.5" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file 
= "pydantic_core-2.14.5-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:7e88f5696153dc516ba6e79f82cc4747e87027205f0e02390c21f7cb3bd8abfd"}, + {file = "pydantic_core-2.14.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4641e8ad4efb697f38a9b64ca0523b557c7931c5f84e0fd377a9a3b05121f0de"}, + {file = "pydantic_core-2.14.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:774de879d212db5ce02dfbf5b0da9a0ea386aeba12b0b95674a4ce0593df3d07"}, + {file = "pydantic_core-2.14.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ebb4e035e28f49b6f1a7032920bb9a0c064aedbbabe52c543343d39341a5b2a3"}, + {file = "pydantic_core-2.14.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b53e9ad053cd064f7e473a5f29b37fc4cc9dc6d35f341e6afc0155ea257fc911"}, + {file = "pydantic_core-2.14.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aa1768c151cf562a9992462239dfc356b3d1037cc5a3ac829bb7f3bda7cc1f9"}, + {file = "pydantic_core-2.14.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eac5c82fc632c599f4639a5886f96867ffced74458c7db61bc9a66ccb8ee3113"}, + {file = "pydantic_core-2.14.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d2ae91f50ccc5810b2f1b6b858257c9ad2e08da70bf890dee02de1775a387c66"}, + {file = "pydantic_core-2.14.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6b9ff467ffbab9110e80e8c8de3bcfce8e8b0fd5661ac44a09ae5901668ba997"}, + {file = "pydantic_core-2.14.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:61ea96a78378e3bd5a0be99b0e5ed00057b71f66115f5404d0dae4819f495093"}, + {file = "pydantic_core-2.14.5-cp310-none-win32.whl", hash = "sha256:bb4c2eda937a5e74c38a41b33d8c77220380a388d689bcdb9b187cf6224c9720"}, + {file = "pydantic_core-2.14.5-cp310-none-win_amd64.whl", hash = "sha256:b7851992faf25eac90bfcb7bfd19e1f5ffa00afd57daec8a0042e63c74a4551b"}, + {file = "pydantic_core-2.14.5-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:4e40f2bd0d57dac3feb3a3aed50f17d83436c9e6b09b16af271b6230a2915459"}, + {file = "pydantic_core-2.14.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ab1cdb0f14dc161ebc268c09db04d2c9e6f70027f3b42446fa11c153521c0e88"}, + {file = "pydantic_core-2.14.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aae7ea3a1c5bb40c93cad361b3e869b180ac174656120c42b9fadebf685d121b"}, + {file = "pydantic_core-2.14.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:60b7607753ba62cf0739177913b858140f11b8af72f22860c28eabb2f0a61937"}, + {file = "pydantic_core-2.14.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2248485b0322c75aee7565d95ad0e16f1c67403a470d02f94da7344184be770f"}, + {file = "pydantic_core-2.14.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:823fcc638f67035137a5cd3f1584a4542d35a951c3cc68c6ead1df7dac825c26"}, + {file = "pydantic_core-2.14.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96581cfefa9123accc465a5fd0cc833ac4d75d55cc30b633b402e00e7ced00a6"}, + {file = "pydantic_core-2.14.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a33324437018bf6ba1bb0f921788788641439e0ed654b233285b9c69704c27b4"}, + {file = "pydantic_core-2.14.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9bd18fee0923ca10f9a3ff67d4851c9d3e22b7bc63d1eddc12f439f436f2aada"}, + {file = "pydantic_core-2.14.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:853a2295c00f1d4429db4c0fb9475958543ee80cfd310814b5c0ef502de24dda"}, + {file = "pydantic_core-2.14.5-cp311-none-win32.whl", hash = "sha256:cb774298da62aea5c80a89bd58c40205ab4c2abf4834453b5de207d59d2e1651"}, + {file = "pydantic_core-2.14.5-cp311-none-win_amd64.whl", hash = "sha256:e87fc540c6cac7f29ede02e0f989d4233f88ad439c5cdee56f693cc9c1c78077"}, + {file = "pydantic_core-2.14.5-cp311-none-win_arm64.whl", hash = "sha256:57d52fa717ff445cb0a5ab5237db502e6be50809b43a596fb569630c665abddf"}, + {file = "pydantic_core-2.14.5-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:e60f112ac88db9261ad3a52032ea46388378034f3279c643499edb982536a093"}, + {file = "pydantic_core-2.14.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6e227c40c02fd873c2a73a98c1280c10315cbebe26734c196ef4514776120aeb"}, + {file = "pydantic_core-2.14.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0cbc7fff06a90bbd875cc201f94ef0ee3929dfbd5c55a06674b60857b8b85ed"}, + {file = "pydantic_core-2.14.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:103ef8d5b58596a731b690112819501ba1db7a36f4ee99f7892c40da02c3e189"}, + {file = "pydantic_core-2.14.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c949f04ecad823f81b1ba94e7d189d9dfb81edbb94ed3f8acfce41e682e48cef"}, + {file = "pydantic_core-2.14.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1452a1acdf914d194159439eb21e56b89aa903f2e1c65c60b9d874f9b950e5d"}, + {file = "pydantic_core-2.14.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb4679d4c2b089e5ef89756bc73e1926745e995d76e11925e3e96a76d5fa51fc"}, + {file = "pydantic_core-2.14.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cf9d3fe53b1ee360e2421be95e62ca9b3296bf3f2fb2d3b83ca49ad3f925835e"}, + {file = "pydantic_core-2.14.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:70f4b4851dbb500129681d04cc955be2a90b2248d69273a787dda120d5cf1f69"}, + {file = "pydantic_core-2.14.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:59986de5710ad9613ff61dd9b02bdd2f615f1a7052304b79cc8fa2eb4e336d2d"}, + {file = "pydantic_core-2.14.5-cp312-none-win32.whl", hash = "sha256:699156034181e2ce106c89ddb4b6504c30db8caa86e0c30de47b3e0654543260"}, + {file = "pydantic_core-2.14.5-cp312-none-win_amd64.whl", hash = "sha256:5baab5455c7a538ac7e8bf1feec4278a66436197592a9bed538160a2e7d11e36"}, + {file = "pydantic_core-2.14.5-cp312-none-win_arm64.whl", hash = "sha256:e47e9a08bcc04d20975b6434cc50bf82665fbc751bcce739d04a3120428f3e27"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:af36f36538418f3806048f3b242a1777e2540ff9efaa667c27da63d2749dbce0"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:45e95333b8418ded64745f14574aa9bfc212cb4fbeed7a687b0c6e53b5e188cd"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e47a76848f92529879ecfc417ff88a2806438f57be4a6a8bf2961e8f9ca9ec7"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d81e6987b27bc7d101c8597e1cd2bcaa2fee5e8e0f356735c7ed34368c471550"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:34708cc82c330e303f4ce87758828ef6e457681b58ce0e921b6e97937dd1e2a3"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:652c1988019752138b974c28f43751528116bcceadad85f33a258869e641d753"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e4d090e73e0725b2904fdbdd8d73b8802ddd691ef9254577b708d413bf3006e"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5c7d5b5005f177764e96bd584d7bf28d6e26e96f2a541fdddb934c486e36fd59"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:a71891847f0a73b1b9eb86d089baee301477abef45f7eaf303495cd1473613e4"}, + {file = "pydantic_core-2.14.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a717aef6971208f0851a2420b075338e33083111d92041157bbe0e2713b37325"}, + {file = "pydantic_core-2.14.5-cp37-none-win32.whl", hash = "sha256:de790a3b5aa2124b8b78ae5faa033937a72da8efe74b9231698b5a1dd9be3405"}, + {file = "pydantic_core-2.14.5-cp37-none-win_amd64.whl", hash = "sha256:6c327e9cd849b564b234da821236e6bcbe4f359a42ee05050dc79d8ed2a91588"}, + {file = "pydantic_core-2.14.5-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:ef98ca7d5995a82f43ec0ab39c4caf6a9b994cb0b53648ff61716370eadc43cf"}, + {file = "pydantic_core-2.14.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6eae413494a1c3f89055da7a5515f32e05ebc1a234c27674a6956755fb2236f"}, + {file = "pydantic_core-2.14.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcf4e6d85614f7a4956c2de5a56531f44efb973d2fe4a444d7251df5d5c4dcfd"}, + {file = "pydantic_core-2.14.5-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6637560562134b0e17de333d18e69e312e0458ee4455bdad12c37100b7cad706"}, + {file = "pydantic_core-2.14.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:77fa384d8e118b3077cccfcaf91bf83c31fe4dc850b5e6ee3dc14dc3d61bdba1"}, + {file = "pydantic_core-2.14.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16e29bad40bcf97aac682a58861249ca9dcc57c3f6be22f506501833ddb8939c"}, + {file = "pydantic_core-2.14.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:531f4b4252fac6ca476fbe0e6f60f16f5b65d3e6b583bc4d87645e4e5ddde331"}, + {file = "pydantic_core-2.14.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:074f3d86f081ce61414d2dc44901f4f83617329c6f3ab49d2bc6c96948b2c26b"}, + {file = "pydantic_core-2.14.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c2adbe22ab4babbca99c75c5d07aaf74f43c3195384ec07ccbd2f9e3bddaecec"}, + {file = "pydantic_core-2.14.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0f6116a558fd06d1b7c2902d1c4cf64a5bd49d67c3540e61eccca93f41418124"}, + {file = "pydantic_core-2.14.5-cp38-none-win32.whl", hash = "sha256:fe0a5a1025eb797752136ac8b4fa21aa891e3d74fd340f864ff982d649691867"}, + {file = "pydantic_core-2.14.5-cp38-none-win_amd64.whl", hash = "sha256:079206491c435b60778cf2b0ee5fd645e61ffd6e70c47806c9ed51fc75af078d"}, + {file = "pydantic_core-2.14.5-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:a6a16f4a527aae4f49c875da3cdc9508ac7eef26e7977952608610104244e1b7"}, + {file = "pydantic_core-2.14.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:abf058be9517dc877227ec3223f0300034bd0e9f53aebd63cf4456c8cb1e0863"}, + {file = "pydantic_core-2.14.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49b08aae5013640a3bfa25a8eebbd95638ec3f4b2eaf6ed82cf0c7047133f03b"}, + {file = "pydantic_core-2.14.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:c2d97e906b4ff36eb464d52a3bc7d720bd6261f64bc4bcdbcd2c557c02081ed2"}, + {file = "pydantic_core-2.14.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3128e0bbc8c091ec4375a1828d6118bc20404883169ac95ffa8d983b293611e6"}, + {file = "pydantic_core-2.14.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88e74ab0cdd84ad0614e2750f903bb0d610cc8af2cc17f72c28163acfcf372a4"}, + {file = "pydantic_core-2.14.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c339dabd8ee15f8259ee0f202679b6324926e5bc9e9a40bf981ce77c038553db"}, + {file = "pydantic_core-2.14.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3387277f1bf659caf1724e1afe8ee7dbc9952a82d90f858ebb931880216ea955"}, + {file = "pydantic_core-2.14.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ba6b6b3846cfc10fdb4c971980a954e49d447cd215ed5a77ec8190bc93dd7bc5"}, + {file = "pydantic_core-2.14.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ca61d858e4107ce5e1330a74724fe757fc7135190eb5ce5c9d0191729f033209"}, + {file = "pydantic_core-2.14.5-cp39-none-win32.whl", hash = "sha256:ec1e72d6412f7126eb7b2e3bfca42b15e6e389e1bc88ea0069d0cc1742f477c6"}, + {file = "pydantic_core-2.14.5-cp39-none-win_amd64.whl", hash = "sha256:c0b97ec434041827935044bbbe52b03d6018c2897349670ff8fe11ed24d1d4ab"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:79e0a2cdbdc7af3f4aee3210b1172ab53d7ddb6a2d8c24119b5706e622b346d0"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:678265f7b14e138d9a541ddabbe033012a2953315739f8cfa6d754cc8063e8ca"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95b15e855ae44f0c6341ceb74df61b606e11f1087e87dcb7482377374aac6abe"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:09b0e985fbaf13e6b06a56d21694d12ebca6ce5414b9211edf6f17738d82b0f8"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3ad873900297bb36e4b6b3f7029d88ff9829ecdc15d5cf20161775ce12306f8a"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:2d0ae0d8670164e10accbeb31d5ad45adb71292032d0fdb9079912907f0085f4"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d37f8ec982ead9ba0a22a996129594938138a1503237b87318392a48882d50b7"}, + {file = "pydantic_core-2.14.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:35613015f0ba7e14c29ac6c2483a657ec740e5ac5758d993fdd5870b07a61d8b"}, + {file = "pydantic_core-2.14.5-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ab4ea451082e684198636565224bbb179575efc1658c48281b2c866bfd4ddf04"}, + {file = "pydantic_core-2.14.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ce601907e99ea5b4adb807ded3570ea62186b17f88e271569144e8cca4409c7"}, + {file = "pydantic_core-2.14.5-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb2ed8b3fe4bf4506d6dab3b93b83bbc22237e230cba03866d561c3577517d18"}, + {file = "pydantic_core-2.14.5-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:70f947628e074bb2526ba1b151cee10e4c3b9670af4dbb4d73bc8a89445916b5"}, + {file = "pydantic_core-2.14.5-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4bc536201426451f06f044dfbf341c09f540b4ebdb9fd8d2c6164d733de5e634"}, + {file = 
"pydantic_core-2.14.5-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f4791cf0f8c3104ac668797d8c514afb3431bc3305f5638add0ba1a5a37e0d88"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:038c9f763e650712b899f983076ce783175397c848da04985658e7628cbe873b"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:27548e16c79702f1e03f5628589c6057c9ae17c95b4c449de3c66b589ead0520"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c97bee68898f3f4344eb02fec316db93d9700fb1e6a5b760ffa20d71d9a46ce3"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9b759b77f5337b4ea024f03abc6464c9f35d9718de01cfe6bae9f2e139c397e"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:439c9afe34638ace43a49bf72d201e0ffc1a800295bed8420c2a9ca8d5e3dbb3"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ba39688799094c75ea8a16a6b544eb57b5b0f3328697084f3f2790892510d144"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ccd4d5702bb90b84df13bd491be8d900b92016c5a455b7e14630ad7449eb03f8"}, + {file = "pydantic_core-2.14.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:81982d78a45d1e5396819bbb4ece1fadfe5f079335dd28c4ab3427cd95389944"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:7f8210297b04e53bc3da35db08b7302a6a1f4889c79173af69b72ec9754796b8"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:8c8a8812fe6f43a3a5b054af6ac2d7b8605c7bcab2804a8a7d68b53f3cd86e00"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:206ed23aecd67c71daf5c02c3cd19c0501b01ef3cbf7782db9e4e051426b3d0d"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2027d05c8aebe61d898d4cffd774840a9cb82ed356ba47a90d99ad768f39789"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40180930807ce806aa71eda5a5a5447abb6b6a3c0b4b3b1b1962651906484d68"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:615a0a4bff11c45eb3c1996ceed5bdaa2f7b432425253a7c2eed33bb86d80abc"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5e412d717366e0677ef767eac93566582518fe8be923361a5c204c1a62eaafe"}, + {file = "pydantic_core-2.14.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:513b07e99c0a267b1d954243845d8a833758a6726a3b5d8948306e3fe14675e3"}, + {file = "pydantic_core-2.14.5.tar.gz", hash = "sha256:6d30226dfc816dd0fdf120cae611dd2215117e4f9b124af8c60ab9093b6e8e71"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pygments" version = "2.15.1" @@ -2651,6 +2869,17 @@ files = [ [package.extras] plugins = ["importlib-metadata"] +[[package]] +name = "pynvml" +version = "11.5.0" +description = "Python Bindings for the NVIDIA Management Library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pynvml-11.5.0-py3-none-any.whl", hash = "sha256:5cce014ac01b098d08f06178f86c37be409b80b2e903a5a03ce15eed60f55e25"}, + {file = "pynvml-11.5.0.tar.gz", hash = "sha256:d027b21b95b1088b9fc278117f9f61b7c67f8e33a787e9f83f735f0f71ac32d0"}, +] + [[package]] 
name = "pytest" version = "7.4.0" @@ -3048,6 +3277,22 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rouge-score" +version = "0.1.2" +description = "Pure python implementation of ROUGE-1.5.5." +optional = false +python-versions = ">=3.7" +files = [ + {file = "rouge_score-0.1.2.tar.gz", hash = "sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04"}, +] + +[package.dependencies] +absl-py = "*" +nltk = "*" +numpy = "*" +six = ">=1.14.0" + [[package]] name = "rsa" version = "4.9" @@ -3368,6 +3613,20 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + [[package]] name = "tensorboard" version = "2.13.0" @@ -3449,56 +3708,117 @@ tests = ["pytest", "pytest-cov"] [[package]] name = "tokenizers" -version = "0.13.3" -description = "Fast and Customizable Tokenizers" +version = "0.15.0" +description = "" optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, - {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, - {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, - {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, - 
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, - {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, - {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, - {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, - {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, - {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, - {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, - {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, - {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, + {file = "tokenizers-0.15.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cd3cd0299aaa312cd2988957598f80becd04d5a07338741eca076057a2b37d6e"}, + {file = "tokenizers-0.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a922c492c721744ee175f15b91704be2d305569d25f0547c77cd6c9f210f9dc"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:331dd786d02fc38698f835fff61c99480f98b73ce75a4c65bd110c9af5e4609a"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88dd0961c437d413ab027f8b115350c121d49902cfbadf08bb8f634b15fa1814"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6fdcc55339df7761cd52e1fbe8185d3b3963bc9e3f3545faa6c84f9e8818259a"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1480b0051d8ab5408e8e4db2dc832f7082ea24aa0722c427bde2418c6f3bd07"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9855e6c258918f9cf62792d4f6ddfa6c56dccd8c8118640f867f6393ecaf8bd7"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de9529fe75efcd54ba8d516aa725e1851df9199f0669b665c55e90df08f5af86"}, + {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8edcc90a36eab0705fe9121d6c77c6e42eeef25c7399864fd57dfb27173060bf"}, + {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ae17884aafb3e94f34fb7cfedc29054f5f54e142475ebf8a265a4e388fee3f8b"}, + {file = "tokenizers-0.15.0-cp310-none-win32.whl", hash = "sha256:9a3241acdc9b44cff6e95c4a55b9be943ef3658f8edb3686034d353734adba05"}, + {file = "tokenizers-0.15.0-cp310-none-win_amd64.whl", hash = "sha256:4b31807cb393d6ea31926b307911c89a1209d5e27629aa79553d1599c8ffdefe"}, + {file = "tokenizers-0.15.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:af7e9be8c05d30bb137b9fd20f9d99354816599e5fd3d58a4b1e28ba3b36171f"}, + {file = "tokenizers-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c3d7343fa562ea29661783344a2d83662db0d3d17a6fa6a403cac8e512d2d9fd"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:32371008788aeeb0309a9244809a23e4c0259625e6b74a103700f6421373f395"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca9db64c7c9954fbae698884c5bb089764edc549731e5f9b7fa1dd4e4d78d77f"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dbed5944c31195514669cf6381a0d8d47f164943000d10f93d6d02f0d45c25e0"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aab16c4a26d351d63e965b0c792f5da7227a37b69a6dc6d922ff70aa595b1b0c"}, + {file = 
"tokenizers-0.15.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c2b60b12fdd310bf85ce5d7d3f823456b9b65eed30f5438dd7761879c495983"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0344d6602740e44054a9e5bbe9775a5e149c4dddaff15959bb07dcce95a5a859"}, + {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4525f6997d81d9b6d9140088f4f5131f6627e4c960c2c87d0695ae7304233fc3"}, + {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:65975094fef8cc68919644936764efd2ce98cf1bacbe8db2687155d2b0625bee"}, + {file = "tokenizers-0.15.0-cp311-none-win32.whl", hash = "sha256:ff5d2159c5d93015f5a4542aac6c315506df31853123aa39042672031768c301"}, + {file = "tokenizers-0.15.0-cp311-none-win_amd64.whl", hash = "sha256:2dd681b53cf615e60a31a115a3fda3980e543d25ca183797f797a6c3600788a3"}, + {file = "tokenizers-0.15.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:c9cce6ee149a3d703f86877bc2a6d997e34874b2d5a2d7839e36b2273f31d3d9"}, + {file = "tokenizers-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a0a94bc3370e6f1cc8a07a8ae867ce13b7c1b4291432a773931a61f256d44ea"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:309cfcccfc7e502cb1f1de2c9c1c94680082a65bfd3a912d5a5b2c90c677eb60"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8413e994dd7d875ab13009127fc85633916c71213917daf64962bafd488f15dc"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d0ebf9430f901dbdc3dcb06b493ff24a3644c9f88c08e6a1d6d0ae2228b9b818"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10361e9c7864b22dd791ec5126327f6c9292fb1d23481d4895780688d5e298ac"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:babe42635b8a604c594bdc56d205755f73414fce17ba8479d142a963a6c25cbc"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3768829861e964c7a4556f5f23307fce6a23872c2ebf030eb9822dbbbf7e9b2a"}, + {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9c91588a630adc88065e1c03ac6831e3e2112558869b9ebcb2b8afd8a14c944d"}, + {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:77606994e793ca54ecf3a3619adc8a906a28ca223d9354b38df41cb8766a0ed6"}, + {file = "tokenizers-0.15.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:6fe143939f3b596681922b2df12a591a5b010e7dcfbee2202482cd0c1c2f2459"}, + {file = "tokenizers-0.15.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:b7bee0f1795e3e3561e9a557061b1539e5255b8221e3f928f58100282407e090"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5d37e7f4439b4c46192ab4f2ff38ab815e4420f153caa13dec9272ef14403d34"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caadf255cf7f951b38d10097836d1f3bcff4aeaaffadfdf748bab780bf5bff95"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05accb9162bf711a941b1460b743d62fec61c160daf25e53c5eea52c74d77814"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26a2ef890740127cb115ee5260878f4a677e36a12831795fd7e85887c53b430b"}, + {file = 
"tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e54c5f26df14913620046b33e822cb3bcd091a332a55230c0e63cc77135e2169"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669b8ed653a578bcff919566631156f5da3aab84c66f3c0b11a6281e8b4731c7"}, + {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0ea480d943297df26f06f508dab6e012b07f42bf3dffdd36e70799368a5f5229"}, + {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bc80a0a565ebfc7cd89de7dd581da8c2b3238addfca6280572d27d763f135f2f"}, + {file = "tokenizers-0.15.0-cp37-none-win32.whl", hash = "sha256:cdd945e678bbdf4517d5d8de66578a5030aeefecdb46f5320b034de9cad8d4dd"}, + {file = "tokenizers-0.15.0-cp37-none-win_amd64.whl", hash = "sha256:1ab96ab7dc706e002c32b2ea211a94c1c04b4f4de48354728c3a6e22401af322"}, + {file = "tokenizers-0.15.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:f21c9eb71c9a671e2a42f18b456a3d118e50c7f0fc4dd9fa8f4eb727fea529bf"}, + {file = "tokenizers-0.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a5f4543a35889679fc3052086e69e81880b2a5a28ff2a52c5a604be94b77a3f"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f8aa81afec893e952bd39692b2d9ef60575ed8c86fce1fd876a06d2e73e82dca"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1574a5a4af22c3def93fe8fe4adcc90a39bf5797ed01686a4c46d1c3bc677d2f"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c7982fd0ec9e9122d03b209dac48cebfea3de0479335100ef379a9a959b9a5a"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d16b647032df2ce2c1f9097236e046ea9fedd969b25637b9d5d734d78aa53b"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b3cdf29e6f9653da330515dc8fa414be5a93aae79e57f8acc50d4028dd843edf"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7286f3df10de840867372e3e64b99ef58c677210e3ceb653cd0e740a5c53fe78"}, + {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aabc83028baa5a36ce7a94e7659250f0309c47fa4a639e5c2c38e6d5ea0de564"}, + {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:72f78b0e0e276b1fc14a672fa73f3acca034ba8db4e782124a2996734a9ba9cf"}, + {file = "tokenizers-0.15.0-cp38-none-win32.whl", hash = "sha256:9680b0ecc26e7e42f16680c1aa62e924d58d1c2dd992707081cc10a374896ea2"}, + {file = "tokenizers-0.15.0-cp38-none-win_amd64.whl", hash = "sha256:f17cbd88dab695911cbdd385a5a7e3709cc61dff982351f5d1b5939f074a2466"}, + {file = "tokenizers-0.15.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:3661862df7382c5eb23ac4fbf7c75e69b02dc4f5784e4c5a734db406b5b24596"}, + {file = "tokenizers-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c3045d191dad49647f5a5039738ecf1c77087945c7a295f7bcf051c37067e883"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9fcaad9ab0801f14457d7c820d9f246b5ab590c407fc6b073819b1573097aa7"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79f17027f24fe9485701c8dbb269b9c713954ec3bdc1e7075a66086c0c0cd3c"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:01a3aa332abc4bee7640563949fcfedca4de8f52691b3b70f2fc6ca71bfc0f4e"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05b83896a893cdfedad8785250daa3ba9f0504848323471524d4783d7291661e"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbbf2489fcf25d809731ba2744ff278dd07d9eb3f8b7482726bd6cae607073a4"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab806ad521a5e9de38078b7add97589c313915f6f5fec6b2f9f289d14d607bd6"}, + {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4a522612d5c88a41563e3463226af64e2fa00629f65cdcc501d1995dd25d23f5"}, + {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e58a38c4e6075810bdfb861d9c005236a72a152ebc7005941cc90d1bbf16aca9"}, + {file = "tokenizers-0.15.0-cp39-none-win32.whl", hash = "sha256:b8034f1041fd2bd2b84ff9f4dc4ae2e1c3b71606820a9cd5c562ebd291a396d1"}, + {file = "tokenizers-0.15.0-cp39-none-win_amd64.whl", hash = "sha256:edde9aa964145d528d0e0dbf14f244b8a85ebf276fb76869bc02e2530fa37a96"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:309445d10d442b7521b98083dc9f0b5df14eca69dbbfebeb98d781ee2cef5d30"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d3125a6499226d4d48efc54f7498886b94c418e93a205b673bc59364eecf0804"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ed56ddf0d54877bb9c6d885177db79b41576e61b5ef6defeb579dcb803c04ad5"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b22cd714706cc5b18992a232b023f736e539495f5cc61d2d28d176e55046f6c"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fac2719b1e9bc8e8e7f6599b99d0a8e24f33d023eb8ef644c0366a596f0aa926"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:85ddae17570ec7e5bfaf51ffa78d044f444a8693e1316e1087ee6150596897ee"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:76f1bed992e396bf6f83e3df97b64ff47885e45e8365f8983afed8556a0bc51f"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3bb0f4df6dce41a1c7482087b60d18c372ef4463cb99aa8195100fcd41e0fd64"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:22c27672c27a059a5f39ff4e49feed8c7f2e1525577c8a7e3978bd428eb5869d"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78104f5d035c9991f92831fc0efe9e64a05d4032194f2a69f67aaa05a4d75bbb"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40b73dc19d82c3e3ffb40abdaacca8fbc95eeb26c66b7f9f860aebc07a73998"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d801d1368188c74552cd779b1286e67cb9fd96f4c57a9f9a2a09b6def9e1ab37"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82641ffb13a4da1293fcc9f437d457647e60ed0385a9216cd135953778b3f0a1"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:160f9d1810f2c18fffa94aa98bf17632f6bd2dabc67fcb01a698ca80c37d52ee"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash 
= "sha256:8d7d6eea831ed435fdeeb9bcd26476226401d7309d115a710c65da4088841948"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f6456bec6c557d63d8ec0023758c32f589e1889ed03c055702e84ce275488bed"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eef39a502fad3bf104b9e1906b4fb0cee20e44e755e51df9a98f8922c3bf6d4"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1e4664c5b797e093c19b794bbecc19d2367e782b4a577d8b7c1821db5dc150d"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ca003fb5f3995ff5cf676db6681b8ea5d54d3b30bea36af1120e78ee1a4a4cdf"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7f17363141eb0c53752c89e10650b85ef059a52765d0802ba9613dbd2d21d425"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:8a765db05581c7d7e1280170f2888cda351760d196cc059c37ea96f121125799"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2a0dd641a72604486cd7302dd8f87a12c8a9b45e1755e47d2682733f097c1af5"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0a1a3c973e4dc97797fc19e9f11546c95278ffc55c4492acb742f69e035490bc"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4fab75642aae4e604e729d6f78e0addb9d7e7d49e28c8f4d16b24da278e5263"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65f80be77f6327a86d8fd35a4467adcfe6174c159b4ab52a1a8dd4c6f2d7d9e1"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a8da7533dbe66b88afd430c56a2f2ce1fd82e2681868f857da38eeb3191d7498"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa8eb4584fc6cbe6a84d7a7864be3ed28e23e9fd2146aa8ef1814d579df91958"}, + {file = "tokenizers-0.15.0.tar.gz", hash = "sha256:10c7e6e7b4cabd757da59e93f5f8d1126291d16f8b54f28510825ef56a3e5d0e"}, ] +[package.dependencies] +huggingface_hub = ">=0.16.4,<1.0" + [package.extras] -dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +dev = ["tokenizers[testing]"] +docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] [[package]] @@ -3525,57 +3845,55 @@ files = [ [[package]] name = "torch" -version = "2.0.0" +version = "2.1.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.0.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c9090bda7d2eeeecd74f51b721420dbeb44f838d4536cc1b284e879417e3064a"}, - {file = "torch-2.0.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bd42db2a48a20574d2c33489e120e9f32789c4dc13c514b0c44272972d14a2d7"}, - {file = "torch-2.0.0-1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:8969aa8375bcbc0c2993e7ede0a7f889df9515f18b9b548433f412affed478d9"}, - {file = "torch-2.0.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ab2da16567cb55b67ae39e32d520d68ec736191d88ac79526ca5874754c32203"}, - {file = "torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7a9319a67294ef02459a19738bbfa8727bb5307b822dadd708bc2ccf6c901aca"}, - 
{file = "torch-2.0.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:9f01fe1f6263f31bd04e1757946fd63ad531ae37f28bb2dbf66f5c826ee089f4"}, - {file = "torch-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:527f4ae68df7b8301ee6b1158ca56350282ea633686537b30dbb5d7b4a52622a"}, - {file = "torch-2.0.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:ce9b5a49bd513dff7950a5a07d6e26594dd51989cee05ba388b03e8e366fd5d5"}, - {file = "torch-2.0.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:53e1c33c6896583cdb9a583693e22e99266444c4a43392dddc562640d39e542b"}, - {file = "torch-2.0.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:09651bff72e439d004c991f15add0c397c66f98ab36fe60d5514b44e4da722e8"}, - {file = "torch-2.0.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d439aec349c98f12819e8564b8c54008e4613dd4428582af0e6e14c24ca85870"}, - {file = "torch-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2802f84f021907deee7e9470ed10c0e78af7457ac9a08a6cd7d55adef835fede"}, - {file = "torch-2.0.0-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:01858620f25f25e7a9ec4b547ff38e5e27c92d38ec4ccba9cfbfb31d7071ed9c"}, - {file = "torch-2.0.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:9a2e53b5783ef5896a6af338b36d782f28e83c8ddfc2ac44b67b066d9d76f498"}, - {file = "torch-2.0.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ec5fff2447663e369682838ff0f82187b4d846057ef4d119a8dea7772a0b17dd"}, - {file = "torch-2.0.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:11b0384fe3c18c01b8fc5992e70fc519cde65e44c51cc87be1838c1803daf42f"}, - {file = "torch-2.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:e54846aa63855298cfb1195487f032e413e7ac9cbfa978fda32354cc39551475"}, - {file = "torch-2.0.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:cc788cbbbbc6eb4c90e52c550efd067586c2693092cf367c135b34893a64ae78"}, - {file = "torch-2.0.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:d292640f0fd72b7a31b2a6e3b635eb5065fcbedd4478f9cad1a1e7a9ec861d35"}, - {file = "torch-2.0.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:6befaad784004b7af357e3d87fa0863c1f642866291f12a4c2af2de435e8ac5c"}, - {file = "torch-2.0.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a83b26bd6ae36fbf5fee3d56973d9816e2002e8a3b7d9205531167c28aaa38a7"}, - {file = "torch-2.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:c7e67195e1c3e33da53954b026e89a8e1ff3bc1aeb9eb32b677172d4a9b5dcbf"}, - {file = "torch-2.0.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6e0b97beb037a165669c312591f242382e9109a240e20054d5a5782d9236cad0"}, - {file = "torch-2.0.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:297a4919aff1c0f98a58ebe969200f71350a1d4d4f986dbfd60c02ffce780e99"}, + {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, + {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, + {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, + {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, + {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, + {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, + {file = 
"torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, + {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, + {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, + {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, + {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, + {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, + {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, + {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, + {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, + {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, + {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, + {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, + {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, + {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, ] [package.dependencies] filelock = "*" +fsspec = "*" jinja2 = "*" networkx = "*" -nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu11 = {version = "11.7.101", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu11 = {version = "10.2.10.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu11 = {version = "11.4.0.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu11 = {version = "11.7.4.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu11 = {version = "2.14.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu11 = {version = "11.7.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", 
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} typing-extensions = "*" [package.extras] +dynamo = ["jinja2"] opt-einsum = ["opt-einsum (>=3.3)"] [[package]] @@ -3635,107 +3953,94 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] [[package]] name = "transformers" -version = "4.30.1" +version = "4.36.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" files = [ - {file = "transformers-4.30.1-py3-none-any.whl", hash = "sha256:9b12bd9d69f21b7c56cd512117fd52856b3def1c9bfc1da97ab0ee4e8bcbd797"}, - {file = "transformers-4.30.1.tar.gz", hash = "sha256:fa74fc271d0692f385d571ce83ec898e3350455f6076d21631f4eed4916e6ffd"}, + {file = "transformers-4.36.1-py3-none-any.whl", hash = "sha256:0e309d03634885f02d46801ec4f2c3fc1d614a5b9ebde608181f3e842bac53b8"}, + {file = "transformers-4.36.1.tar.gz", hash = "sha256:28e55952d9bed68f06cf45a3d29cc480679b528afe944e68f8cf6c799e428759"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.14.1,<1.0" +huggingface-hub = ">=0.19.3,<1.0" numpy = ">=1.17" packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" safetensors = ">=0.3.1" -tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14" +tokenizers = ">=0.14,<0.19" tqdm = ">=4.27" [package.extras] -accelerate = ["accelerate (>=0.20.2)"] -agents = ["Pillow", "accelerate (>=0.20.2)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"] -all = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +accelerate = ["accelerate (>=0.21.0)"] +agents = 
["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.20.2)", "deepspeed (>=0.8.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.8.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.6.9)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "tokenizers 
(>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "hf-doc-builder", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic 
(>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] -fairscale = ["fairscale (>0.3)"] -flax = ["flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "optax (>=0.0.8,<=0.1.4)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune]", "sigopt"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] modelcreation = ["cookiecutter (==1.7.3)"] natten = ["natten (>=0.14.6)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = 
["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] -ray = ["ray[tune]"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] sagemaker = ["sagemaker (>=2.31.0)"] -sentencepiece = ["protobuf (<=3.20.3)", "sentencepiece (>=0.1.91,!=0.1.92)"] -serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] -tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] -tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"] -torch = ["accelerate (>=0.20.2)", "torch (>=1.9,!=1.12.0)"] +tokenizers = ["tokenizers (>=0.14,<0.19)"] +torch = ["accelerate (>=0.21.0)", "torch (>=1.10,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.14.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.3)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] -vision = ["Pillow"] +vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.0.0" +version = "2.1.0" description = "A language and compiler for custom Deep Learning 
operations" optional = false python-versions = "*" files = [ - {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"}, - {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"}, - {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"}, - {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"}, - {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"}, - {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"}, - {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"}, - {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"}, - {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"}, - {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, - {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, - {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, - {file = "triton-2.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fedce6a381901b1547e0e7e1f2546e4f65dca6d91e2d8a7305a2d1f5551895be"}, - {file = "triton-2.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75834f27926eab6c7f00ce73aaf1ab5bfb9bec6eb57ab7c0bfc0a23fac803b4c"}, - {file = "triton-2.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0117722f8c2b579cd429e0bee80f7731ae05f63fe8e9414acd9a679885fcbf42"}, - {file = "triton-2.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcd9be5d0c2e45d2b7e6ddc6da20112b6862d69741576f9c3dbaf941d745ecae"}, - {file = "triton-2.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42a0d2c3fc2eab4ba71384f2e785fbfd47aa41ae05fa58bf12cb31dcbd0aeceb"}, - {file = "triton-2.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52c47b72c72693198163ece9d90a721299e4fb3b8e24fd13141e384ad952724f"}, + {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, + {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, + {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, + {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, + {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, + {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, + {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, + {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, ] [package.dependencies] -cmake = "*" filelock = "*" -lit = "*" -torch = "*" [package.extras] +build = ["cmake (>=3.18)", "lit"] tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] @@ -3752,23 +4057,24 @@ files = [ [[package]] name = "tyro" -version = "0.5.3" +version = "0.6.0" description = "Strongly typed, zero-effort CLI interfaces" optional = false -python-versions = ">=3.7,<4.0" +python-versions = ">=3.7" files = [ - {file = "tyro-0.5.3-py3-none-any.whl", hash = "sha256:5dc67b189694015ac7922124255c8c88274507910e1a02adadf510b730270eca"}, - {file = "tyro-0.5.3.tar.gz", hash = "sha256:ca074d911af86e30c31e2a17a0c58f67573421a98892a1b2bc0baf271dc1862b"}, + {file = "tyro-0.6.0-py3-none-any.whl", hash = "sha256:1ff3697dece8bcf0f0597f1dacfaa04bdbe96d3b5371dff95f3f1bc61429d7a4"}, + {file = "tyro-0.6.0.tar.gz", hash = "sha256:3e9892762d18b95869f8053bb12fb7999ea337148f34153b0d5ca9b97eafa28f"}, ] [package.dependencies] -colorama = {version = ">=0.4.0,<0.5.0", markers = "sys_platform == \"win32\""} -docstring-parser = ">=0.14.1,<0.15.0" -frozendict = ">=2.3.4,<3.0.0" -PyYAML = ">=6.0,<7.0" +colorama = {version = ">=0.4.0", markers = "platform_system == \"Windows\""} +docstring-parser = ">=0.14.1" rich = ">=11.1.0" -shtab = ">=1.5.6,<2.0.0" -typing-extensions = ">=4.3.0,<5.0.0" +shtab = ">=1.5.6" +typing-extensions = ">=4.3.0" + +[package.extras] +dev = ["PyYAML (>=6.0)", "attrs (>=21.4.0)", "coverage[toml] (>=6.5.0)", "flax (>=0.6.9)", "frozendict (>=2.3.4)", "mypy (>=1.4.1)", "numpy (>=1.20.0)", "omegaconf (>=2.2.2)", "pydantic (>=2.3.0)", "pyright (>=1.1.264)", "pytest (>=7.1.2)", "pytest-cov (>=3.0.0)", "torch (>=1.10.0)"] [[package]] name = "tzdata" @@ -4219,4 +4525,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "38feccf314a01e0cba0808d9cad747f035d1627ecbd852a6701bdd8228fb2ad5" +content-hash = "cc56f3606c4e024c6a1ceafe6d1544c50f6793978cfe1ccd5723fcc3fe82b10c" diff --git a/pyproject.toml b/pyproject.toml index ac3b7be..1b23ff3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,16 +8,16 @@ packages = [{include = "lm_human_preference_details"}] [tool.poetry.dependencies] python = "^3.8" -torch = "2.0.0" -tyro = "^0.5.3" +torch = "^2.1.2" +tyro = "^0.6.0" datasets = "^2.12.0" wandb = "^0.15.4" nvitop = "^1.1.2" ftfy = "^6.1.1" rich = "^13.4.2" -transformers = "^4.30.1" +transformers = "^4.36.1" tensorboard = "^2.13.0" -accelerate = "^0.22.0" +accelerate = "^0.25.0" jax = "0.4.8" flax = 
"0.6.8" optax = "0.1.4" @@ -26,6 +26,12 @@ orbax = "0.1.4" einops = "^0.6.1" black = "^23.7.0" clu = "^0.0.9" +tabulate = "^0.9.0" +deepspeed = "^0.12.5" +evaluate = "^0.4.1" +nltk = "^3.8.1" +rouge-score = "^0.1.2" +huggingface-hub = "^0.19.4" [tool.poetry.group.dev.dependencies] From 2a12638368422ef8c9d0717247b3388df63c18e1 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 18 Dec 2023 21:33:52 +0000 Subject: [PATCH 39/62] quick push --- .../train_reward_accelerate_summarize.py | 552 +++++++----------- .../train_sft_accelerate_summarize.py | 144 ++--- 2 files changed, 288 insertions(+), 408 deletions(-) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/train_reward_accelerate_summarize.py index e096f8b..e4be55a 100644 --- a/lm_human_preference_details/train_reward_accelerate_summarize.py +++ b/lm_human_preference_details/train_reward_accelerate_summarize.py @@ -12,7 +12,6 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -import transformers import tyro from accelerate import Accelerator from accelerate.state import AcceleratorState @@ -21,16 +20,11 @@ from rich.console import Console from rich.pretty import pprint from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) +from torch import optim from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader from tqdm import tqdm -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler +from transformers import AutoConfig, AutoModel, AutoTokenizer, get_scheduler, PreTrainedModel, PretrainedConfig @dataclass @@ -43,14 +37,27 @@ class LabelHParams: # a patch @dataclass -class TaskQueryHParams: - length: int = None - dataset: str = None - format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = None - truncate_text: Optional[str] = None - padding: Optional[str] = None # defaults to repeated spaces - pad_side: Optional[str] = None +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: int = 50256 # EOS token + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.01 @dataclass @@ -68,69 +75,66 @@ class Args: """the entity (team) of wandb's project""" cuda: bool = True """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" + run_name: Optional[str] = None + """a unique name of this run""" load_from_cache_file: bool = False """Whether to load data from the local cache file in `dataset.map`""" - - base_model: str = "EleutherAI/pythia-160m" - """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) - """Which layers to apply dropout to""" + push_to_hub: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" deepspeed: bool = False """Whether to use deepspeed to train the model""" - label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169" - """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" - local_batch_size: int = 8 - """per rank batch size""" - gradient_accumulation_steps: int = 1 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - local_eval_batch_size: int = 8 - """per rank eval batch size""" + print_sample_output_freq: int = 220 + """How often to print sample output""" + run_eval: bool = False + """Whether to run evaluation""" + + # optimizer args + eps: float = 1e-5 + """the epsilon value for the optimizer""" lr: float = 5e-6 """the learning rate""" - eps: float = 1e-5 - """the epsilon for AdamW""" - local_rollout_batch_size: int = 512 - """per rank rollout batch size""" - rollout_batch_size: tyro.conf.Suppress[int] = None - """rollout batch size""" - world_size: tyro.conf.Suppress[int] = None - """the number of processes to use""" - batch_size: tyro.conf.Suppress[int] = None - """the batch size across all ranks""" - local_normalize_samples: int = 256 - """Samples used to estimate reward mean and std""" - normalize_samples: tyro.conf.Suppress[int] = None - """Samples used to estimate reward mean and std across all ranks""" - debug_normalize: int = 0 - """Samples used to check that normalization worked""" - normalize_before: bool = False - """Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. 
(For comparisons, just use mean 0, var 1.)""" - normalize_after: bool = False - """Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 (so the KL coefficient always has the same meaning).""" - print_sample_output_freq: int = 300 - """How often to print sample output""" - sft_model_path: str = "" - """Where to load the SFT model""" - logsigmoid: bool = True - """Whether to use log-sigmoid loss instead of cross-entropy loss""" - trainable_param_percentage: float = 1.0 - """Percentage of parameters to train""" - num_epochs: int = 1 - """Number of epochs to train""" - num_updates: tyro.conf.Suppress[int] = None - """Number of updates to train""" - save_path: str = "models/reward" - """Where to save the model""" - optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" + optimizer: Literal["adam", "adamw"] = "adamw" """Which optimizer to use""" scheduler: str = "cosine" """Which scheduler to use""" warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" + + gradient_accumulation_steps: int = 8 + """The number of gradient accumulation steps""" + local_micro_batch_size: Optional[int] = 1 + """The micro batch size per GPU (HF's `per_device_train_batch_size`)""" + total_episodes: Optional[int] = None + """The total number of episodes in the dataset""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + local_eval_batch_size: int = 8 + """per rank eval batch size""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_train_epochs: int = 1 + """Number of epochs to train""" + num_updates: Optional[int] = None + """The number of updates to train""" + + # other args + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + output_dir: str = "models/reward_policy" + """Where to save the model""" + label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + logsigmoid: bool = True + """Whether to use log-sigmoid loss instead of cross-entropy loss""" + task: TaskHParams = field(default_factory=TaskHParams) labels: LabelHParams = field(default_factory=LabelHParams) @@ -153,170 +157,36 @@ def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: console.print(table) -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - 
step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss - - def layer_init(layer, std=np.sqrt(2), bias_const=0.0): torch.nn.init.normal_(layer.weight, std=std) torch.nn.init.constant_(layer.bias, val=bias_const) return layer -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone +class ScalarModelConfig(PretrainedConfig): + model_type = 'scalar_model' + def __init__(self, base_model: str = "gpt2", **kwargs): + super().__init__(**kwargs) + self.base_model = base_model + +class ScalarModel(PreTrainedModel): + config_class = ScalarModelConfig + def __init__(self, config: ScalarModelConfig): + super().__init__(config) + 
self.config = config + self.model_config = AutoConfig.from_pretrained( + config.base_model, + trust_remote_code=True, + ) + self.lm_backbone = AutoModel.from_pretrained( + config.base_model, + config=self.model_config, + trust_remote_code=True, + ) self.scalar_head = layer_init( - nn.Linear(lm_backbone.config.hidden_size, 1), - std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), + nn.Linear(self.model_config.hidden_size, 1), + std=1 / np.sqrt(self.model_config.hidden_size + 1), ) - # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) - # self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True) def forward(self, **kwargs): output = self.lm_backbone(**kwargs) @@ -324,30 +194,10 @@ def forward(self, **kwargs): return reward -# def left_padding_to_right_padding(tokens, pad_id): -# """Convert from left padding to right padding.""" -# assert tokens.ndim == 2 -# return torch.tensor( -# [[x for x in row if x != pad_id] + [pad_id] * (row == pad_id).sum() for row in tokens], -# device=tokens.device, -# ) - - -def ceil_div(a, b): - return (a - 1) // b + 1 - - -def exact_div(a, b): - q = a // b - if a != q * b: - raise ValueError(f"Inexact division: {a} / {b} = {a / b}") - return q - - -def get_reward(reward_model, query_responses, tokenizer): +def get_reward(model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - reward_logits = reward_model( + reward_logits = model( input_ids=input_ids, attention_mask=attention_mask, return_dict=True, @@ -361,8 +211,8 @@ def get_reward(reward_model, query_responses, tokenizer): return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths] -def evaluate(args, accelerator, tokenizer, reward_model, dataloader): - reward_model.eval() +def evaluate(args, accelerator, tokenizer, model, dataloader): + model.eval() with torch.no_grad(): items = defaultdict(list) for data in tqdm(dataloader): @@ -372,7 +222,7 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) # query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) - predicted_reward = get_reward(reward_model, query_responses, tokenizer) + predicted_reward = get_reward(model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.labels.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float() @@ -390,28 +240,33 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): items["response0_policy"].append(data["response0_policy"][i]) items["response1_policy"].append(data["response1_policy"][i]) items["accuracy"].append(accuracy[i].item()) - reward_model.train() + model.train() return pd.DataFrame(items) # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) - accelerator = Accelerator( - kwargs_handlers=[ - DistributedDataParallelKwargs( - broadcast_buffers=False, - # find_unused_parameters=True, - ) - ], # this is needed to avoid https://github.com/pytorch/pytorch/issues/22095#issuecomment-505099500 - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + local_seed = args.seed + accelerator.process_index * 100003 # Prime args.world_size = 
accelerator.num_processes - args.batch_size = int(args.local_batch_size * args.world_size) - args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) - args.local_micro_batch_size = exact_div(args.local_batch_size, args.gradient_accumulation_steps) + args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) - args.num_updates = args.labels.num_train // args.local_batch_size + args.batch_size = int(args.local_batch_size * args.world_size) + + # load dataset + dataset = load_dataset(args.label_dataset, "comparisons", split="train") + dataset = dataset.shuffle(seed=local_seed) + dataset = dataset.select(range(args.labels.num_train)) + dataset = dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"]) + dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) + validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra.confidence", "response0_policy", "response1_policy", "policies"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + accelerator.print("The number of samples in dataset", len(dataset)) + accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) + args.total_episodes = len(dataset) + args.num_updates = args.total_episodes // args.local_batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -437,7 +292,6 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): ) pprint(args) device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime random.seed(local_seed) np.random.seed(local_seed) torch.manual_seed(local_seed) @@ -449,58 +303,27 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - model_config = AutoConfig.from_pretrained(args.base_model) - configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout - if accelerator.is_main_process: - pprint(model_config) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained( - args.base_model, - config=model_config, - trust_remote_code=True, - ) - ) - # freeze the first 70% of layers - if args.trainable_param_percentage < 1.0: - layers = reward_model.lm_backbone.transformer.h - num_layers = len(layers) - num_unfrozen = int(args.trainable_param_percentage * num_layers) - for layer in layers[:-num_unfrozen]: - layer.requires_grad_(False) - - if args.sft_model_path: - reward_model.lm_backbone.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded SFT model from {args.sft_model_path}") - # make sure the `lm_head` or `embed_out` does not require gradients, otherwise - # pytorch DDP complains; see https://gist.github.com/vwxyzjn/45fc8706dfb3cf33695f0f57cc44a533 - if isinstance(reward_model.lm_backbone, transformers.GPTNeoXForCausalLM): - reward_model.lm_backbone.embed_out.requires_grad_(False) - if args.optimizer == "tf_adam": - optimizer = AdamTensorFlowStyle(reward_model.parameters(), lr=args.lr, eps=args.eps) 
- elif args.optimizer == "adam": - optimizer = optim.Adam(reward_model.parameters(), lr=args.lr, eps=args.eps) + # model_config = AutoConfig.from_pretrained(args.base_model) + # configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + # if accelerator.is_main_process: + # pprint(model_config) + model: PreTrainedModel = ScalarModel(ScalarModelConfig(base_model=args.base_model)) + if args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps) elif args.optimizer == "adamw": - optimizer = optim.AdamW(reward_model.parameters(), lr=args.lr, eps=args.eps) + optimizer = optim.AdamW(model.parameters(), lr=args.lr, eps=args.eps) scheduler = get_scheduler( args.scheduler, optimizer=optimizer, num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_updates * args.num_epochs, + num_training_steps=args.num_updates * args.num_train_epochs, ) if args.deepspeed: deepspeed_states = AcceleratorState().deepspeed_plugin deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - label = load_dataset(args.label_dataset, "comparisons", split="train") - label = label.shuffle(seed=local_seed) - label = label.select(range(args.labels.num_train)) - label = label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"]) - dataloader = DataLoader(label, batch_size=args.local_micro_batch_size) - reward_model, optimizer, dataloader, scheduler = accelerator.prepare(reward_model, optimizer, dataloader, scheduler) - validation_label = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() - validation_label = validation_label.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra.confidence", "response0_policy", "response1_policy", "policies"]) - validation_dataloader = DataLoader(validation_label, batch_size=args.local_eval_batch_size) + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) validation_dataloader = accelerator.prepare(validation_dataloader) accelerator.print("===training reward model===") @@ -508,11 +331,11 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) - reward_model.train() + model.train() gradient_accumulation_idx = 0 global_step = 0 update = 0 - for epoch in range(args.num_epochs): + for epoch in range(args.num_train_epochs): accelerator.print(f"epoch: {epoch}") for data in dataloader: update += 1 @@ -522,9 +345,8 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): mb_best = data["choice"] mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) - # query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) - with accelerator.accumulate(reward_model): - predicted_reward = get_reward(reward_model, query_responses, tokenizer) + with accelerator.accumulate(model): + predicted_reward = get_reward(model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.labels.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() reward_preferred = 
predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) @@ -534,11 +356,7 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): else: loss = F.cross_entropy(predicted_reward, mb_best) accelerator.backward(loss) - # for k, v in reward_model.named_parameters(): - # if v.requires_grad: - # if v.grad is None: - # print(f"found unused param: {k}") - optimizer.step() # accelerate handles gradient accumulation automatically + optimizer.step() optimizer.zero_grad() scheduler.step() losses[gradient_accumulation_idx] = loss @@ -546,6 +364,7 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): reward_preferreds[gradient_accumulation_idx] = reward_preferred.mean() reward_rejecteds[gradient_accumulation_idx] = reward_rejected.mean() gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps + break if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: train_accuracy = accelerator.gather(accuracies).mean().item() writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) @@ -556,35 +375,88 @@ def evaluate(args, accelerator, tokenizer, reward_model, dataloader): accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {update=}") # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - evaluate_df = evaluate(args, accelerator, tokenizer, reward_model, validation_dataloader) - for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): - writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step) - accelerator.print(f"{split} accuracy: {row['accuracy']}") - for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): - writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step) - accelerator.print(f"{batch} accuracy: {row['accuracy']}") - for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): - writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step) - accelerator.print(f"{confi} confidence: {row['accuracy']}") - writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step) - accelerator.print(f"eval accuracy: {evaluate_df['accuracy'].mean()}") - if accelerator.is_main_process: - os.makedirs(f"eval_tables/{run_name}", exist_ok=True) - evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") - if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) - - - torch.cuda.empty_cache() + # evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) + # for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): + # writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step) + # accelerator.print(f"{split} accuracy: {row['accuracy']}") + # for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): + # writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step) + # accelerator.print(f"{batch} accuracy: {row['accuracy']}") + # for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): + # writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step) + # accelerator.print(f"{confi} confidence: {row['accuracy']}") + # writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step) + # accelerator.print(f"eval accuracy: 
{evaluate_df['accuracy'].mean()}") + # if accelerator.is_main_process: + # os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + # evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") + # if args.track: + # wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + # torch.cuda.empty_cache() + + # norm_dataset = load_dataset(args.task.query_dataset, split="train") + # norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + # norm_dataset = norm_dataset.shuffle(seed=local_seed) + # norm_dataloader = DataLoader(norm_dataset, batch_size=args.local_eval_batch_size) + # items = defaultdict(list) + # norm_dataloader = accelerator.prepare(norm_dataloader) + # with torch.no_grad(): + # for data in tqdm(norm_dataloader): + # reference_responses = data["reference_response_token"].to(device, non_blocking=True) + # queries = data["query_token"].to(device, non_blocking=True) + # query_responses = torch.cat((queries, reference_responses), dim=1) + # predicted_reward = get_reward(model, query_responses, tokenizer) + # predicted_reward = accelerator.gather(predicted_reward) + # queries = accelerator.gather(queries) + # reference_responses = accelerator.gather(reference_responses) + # accelerator.print(predicted_reward.shape) + # for i in range(len(predicted_reward)): + # items["query"].append(tokenizer.decode(queries[i], skip_special_tokens=True)) + # items["reference_response"].append(tokenizer.decode(reference_responses[i])) + # items["predicted_reward"].append(predicted_reward[i].item()) + + # if accelerator.is_main_process: + # norm_df = pd.DataFrame(items) + # os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + # norm_df.to_csv(f"eval_tables/{run_name}/eval_{update}_normalized.csv") + # if args.track: + # wandb.log({"samples/normalized": wandb.Table(dataframe=norm_df)}, step=update) + # stats = { + # "mean": norm_df["predicted_reward"].mean(), + # "std": norm_df["predicted_reward"].std(), + # "max": norm_df["predicted_reward"].max(), + # "min": norm_df["predicted_reward"].min() + # } + # for stat_name, stat_value in stats.items(): + # writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step) + # accelerator.print(f"Normalized Reward {stat_name.capitalize()}: {stat_value}") # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - accelerator.save_model(reward_model, args.save_path, max_shard_size="1000GB") + if args.output_dir and args.num_train_epochs > 0: + os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) + time_tensor = torch.tensor(int(time.time()), device=device) + time_int = accelerator.gather(time_tensor).item() # avoid different timestamps across processes + repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr__seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - if accelerator.is_main_process and args.track: - wandb.finish() + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir, repo_id=repo_id) + if args.push_to_hub: + tokenizer.push_to_hub(repo_id, revision=str(time_int)) + unwrapped: PreTrainedModel = accelerator.unwrap_model(model) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + safe_serialization=False, + repo_id=repo_id, + ) + if 
args.push_to_hub: + unwrapped.push_to_hub(repo_id, revision=str(time_int), safe_serialization=False) if __name__ == "__main__": args = tyro.cli(Args) diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/train_sft_accelerate_summarize.py index 58b9234..d434688 100644 --- a/lm_human_preference_details/train_sft_accelerate_summarize.py +++ b/lm_human_preference_details/train_sft_accelerate_summarize.py @@ -73,44 +73,36 @@ class Args: cuda: bool = True """Whether to use cuda if available.""" run_name: Optional[str] = None - """TO BE FILLED: a unique name of this run""" + """a unique name of this run""" load_from_cache_file: bool = False """Whether to load data from the local cache file in `dataset.map`""" push_to_hub: bool = False "whether to upload the saved model to huggingface" hf_entity: str = "" "the user or org name of the model repository from the Hugging Face Hub" - - base_model: str = "EleutherAI/pythia-160m" - """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) - """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 220 """How often to print sample output""" - output_dir: str = "models/sft_policy" - """Where to save the model""" + run_eval: bool = False + """Whether to run evaluation""" + + # optimizer args + eps: float = 1e-5 + """the epsilon value for the optimizer""" + lr: float = 6.35e-5 + """the learning rate""" optimizer: Literal["adam", "adamw"] = "adamw" """Which optimizer to use""" scheduler: str = "cosine" """Which scheduler to use""" warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" - run_eval: bool = False - """Whether to run evaluation""" - local_micro_batch_size: int = 1 - """The micro batch size per GPU (HF's `per_device_train_batch_size`)""" gradient_accumulation_steps: int = 16 """The number of gradient accumulation steps""" - noptepochs: int = 1 - """The number of epochs to train""" - lr: float = 6.35e-5 - """The learning rate""" - eps: float = 1e-5 - """The epsilon value for the optimizer""" - + local_micro_batch_size: Optional[int] = 1 + """The micro batch size per GPU (HF's `per_device_train_batch_size`)""" total_episodes: Optional[int] = None """The total number of episodes in the dataset""" micro_batch_size: Optional[int] = None @@ -119,10 +111,22 @@ class Args: """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" batch_size: Optional[int] = None """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + local_eval_batch_size: int = 8 + """per rank eval batch size""" world_size: Optional[int] = None """The number of processes (GPUs) to use""" + num_train_epochs: int = 1 + """Number of epochs to train""" num_updates: Optional[int] = None """The number of updates to train""" + + # other args + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + """Which layers to apply dropout to""" + output_dir: str = "models/sft_model" + """Where to save the model""" task: TaskHParams = field(default_factory=TaskHParams) @@ -160,11 +164,11 @@ def generate(lm_backbone, queries, tokenizer, generation_config): return 
torch.cat((queries, output.sequences[:, context_length:]), dim=1) -def forward(policy, query_responses, tokenizer): +def forward(model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return policy( + return model( input_ids=input_ids, attention_mask=attention_mask, # position_ids=position_ids, @@ -176,12 +180,20 @@ def forward(policy, query_responses, tokenizer): if __name__ == "__main__": args = tyro.cli(Args) accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + local_seed = args.seed + accelerator.process_index * 100003 # Prime args.world_size = accelerator.num_processes args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) args.batch_size = int(args.local_batch_size * args.world_size) + + # load dataset dataset = load_dataset(args.task.query_dataset, split="train") + dataset = dataset.shuffle(seed=local_seed) + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) validation_dataset = load_dataset(args.task.query_dataset, split="validation") + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) accelerator.print("The number of samples in dataset", len(dataset)) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) args.total_episodes = len(dataset) @@ -211,7 +223,6 @@ def forward(policy, query_responses, tokenizer): ) pprint(args) device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime random.seed(local_seed) np.random.seed(local_seed) torch.manual_seed(local_seed) @@ -227,13 +238,13 @@ def forward(policy, query_responses, tokenizer): configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout if accelerator.is_main_process: pprint(model_config) - policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) - policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to - policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + model.generation_config.pad_token_id = None # generate tokens without truncation / padding if args.optimizer == "adam": - optimizer = optim.Adam(policy.parameters(), lr=args.lr, eps=args.eps) + optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps) elif args.optimizer == "adamw": - optimizer = optim.AdamW(policy.parameters(), lr=args.lr, eps=args.eps) + optimizer = optim.AdamW(model.parameters(), lr=args.lr, eps=args.eps) scheduler = get_scheduler( args.scheduler, optimizer=optimizer, @@ -241,13 +252,8 @@ def forward(policy, query_responses, tokenizer): num_training_steps=args.num_updates, ) - dataset = dataset.with_format("torch", columns=["query_token", 
"reference_response_token"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_micro_batch_size) - policy, optimizer, dataloader, scheduler = accelerator.prepare( - policy, optimizer, dataloader, scheduler + model, optimizer, dataloader, scheduler = accelerator.prepare( + model, optimizer, dataloader, scheduler ) validation_dataloader = accelerator.prepare(validation_dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated @@ -262,41 +268,43 @@ def forward(policy, query_responses, tokenizer): ) rouge = evaluate.load("rouge") - accelerator.print("===training policy===") + accelerator.print("===training model===") loss_stats = torch.zeros(args.gradient_accumulation_steps, device=device) - policy.train() + model.train() gradient_accumulation_idx = 0 global_step = 0 update = 0 - for data in dataloader: - update += 1 - global_step += args.micro_batch_size - reference_responses = data["reference_response_token"].to(device, non_blocking=True) - queries = data["query_token"].to(device, non_blocking=True) - query_responses = torch.cat((queries, reference_responses), dim=1) - with accelerator.accumulate(policy): - output = forward(policy, query_responses, tokenizer) - # mask out gradient effects on response padding tokens - labels = query_responses.masked_fill(query_responses == tokenizer.pad_token_id, -1) - lm_logits = output.logits - # hand-rolled transformer loss: Shift so that tokens < n predict n - # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-1) - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - scheduler.step() - loss_stats[gradient_accumulation_idx] = loss - gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps - if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: - writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) - writer.add_scalar("lr", scheduler.get_last_lr()[0], update) - accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") - break + for epoch in range(args.num_train_epochs): + accelerator.print(f"epoch: {epoch}") + for data in dataloader: + update += 1 + global_step += args.micro_batch_size + reference_responses = data["reference_response_token"].to(device, non_blocking=True) + queries = data["query_token"].to(device, non_blocking=True) + query_responses = torch.cat((queries, reference_responses), dim=1) + with accelerator.accumulate(model): + output = forward(model, query_responses, tokenizer) + # mask out gradient effects on response padding tokens + labels = query_responses.masked_fill(query_responses == tokenizer.pad_token_id, -1) + lm_logits = output.logits + # hand-rolled transformer loss: Shift so that tokens < n predict n + # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), 
shift_labels.view(-1), ignore_index=-1) + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + scheduler.step() + loss_stats[gradient_accumulation_idx] = loss + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps + if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: + writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) + writer.add_scalar("lr", scheduler.get_last_lr()[0], update) + accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") + break if args.run_eval: - policy.eval() + model.eval() rouge_scores = collections.defaultdict(list) all_decode_validation_queries = [] all_decode_validation_query_responses = [] @@ -311,7 +319,7 @@ def forward(policy, query_responses, tokenizer): (validation_queries, validation_reference_responses), dim=1 ) - validation_output = forward(policy, validation_query_reference_responses, tokenizer) + validation_output = forward(model, validation_query_reference_responses, tokenizer) validation_labels = validation_query_reference_responses.masked_fill( validation_query_reference_responses == tokenizer.pad_token_id, -1 ) @@ -329,7 +337,7 @@ def forward(policy, query_responses, tokenizer): all_validation_losses.append(validation_loss) generated_responses = generate( - accelerator.unwrap_model(policy), + accelerator.unwrap_model(model), validation_queries, tokenizer, generation_config, @@ -386,14 +394,14 @@ def forward(policy, query_responses, tokenizer): if args.push_to_hub: tokenizer.push_to_hub(repo_id, revision=str(time_int)) - unwrapped: PreTrainedModel = accelerator.unwrap_model(policy) + unwrapped: PreTrainedModel = accelerator.unwrap_model(model) accelerator.wait_for_everyone() if accelerator.is_main_process: unwrapped.save_pretrained( args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, - state_dict=accelerator.get_state_dict(policy), + state_dict=accelerator.get_state_dict(model), safe_serialization=False, repo_id=repo_id, ) From 7e1336fbff960223544cb511dc5b01914e6bbc89 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 18 Dec 2023 21:34:19 +0000 Subject: [PATCH 40/62] rename --- .../{train_reward_accelerate_summarize.py => summarize/reward.py} | 0 .../{train_sft_accelerate_summarize.py => summarize/sft.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename lm_human_preference_details/{train_reward_accelerate_summarize.py => summarize/reward.py} (100%) rename lm_human_preference_details/{train_sft_accelerate_summarize.py => summarize/sft.py} (100%) diff --git a/lm_human_preference_details/train_reward_accelerate_summarize.py b/lm_human_preference_details/summarize/reward.py similarity index 100% rename from lm_human_preference_details/train_reward_accelerate_summarize.py rename to lm_human_preference_details/summarize/reward.py diff --git a/lm_human_preference_details/train_sft_accelerate_summarize.py b/lm_human_preference_details/summarize/sft.py similarity index 100% rename from lm_human_preference_details/train_sft_accelerate_summarize.py rename to lm_human_preference_details/summarize/sft.py From 89ea1c5c652fcad3d20c2b0a2c5f0f8fb1feaec3 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 19 Dec 2023 01:28:14 +0000 Subject: [PATCH 41/62] quick change --- .../summarize/reward.py | 125 +++++++++--------- lm_human_preference_details/summarize/sft.py | 28 ++-- 2 files changed, 81 insertions(+), 72 deletions(-) diff --git a/lm_human_preference_details/summarize/reward.py 
b/lm_human_preference_details/summarize/reward.py index e4be55a..2f02ec5 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -375,74 +375,75 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {update=}") # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: - # evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) - # for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): - # writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step) - # accelerator.print(f"{split} accuracy: {row['accuracy']}") - # for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): - # writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step) - # accelerator.print(f"{batch} accuracy: {row['accuracy']}") - # for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): - # writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step) - # accelerator.print(f"{confi} confidence: {row['accuracy']}") - # writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step) - # accelerator.print(f"eval accuracy: {evaluate_df['accuracy'].mean()}") - # if accelerator.is_main_process: - # os.makedirs(f"eval_tables/{run_name}", exist_ok=True) - # evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") - # if args.track: - # wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) - # torch.cuda.empty_cache() - - # norm_dataset = load_dataset(args.task.query_dataset, split="train") - # norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - # norm_dataset = norm_dataset.shuffle(seed=local_seed) - # norm_dataloader = DataLoader(norm_dataset, batch_size=args.local_eval_batch_size) - # items = defaultdict(list) - # norm_dataloader = accelerator.prepare(norm_dataloader) - # with torch.no_grad(): - # for data in tqdm(norm_dataloader): - # reference_responses = data["reference_response_token"].to(device, non_blocking=True) - # queries = data["query_token"].to(device, non_blocking=True) - # query_responses = torch.cat((queries, reference_responses), dim=1) - # predicted_reward = get_reward(model, query_responses, tokenizer) - # predicted_reward = accelerator.gather(predicted_reward) - # queries = accelerator.gather(queries) - # reference_responses = accelerator.gather(reference_responses) - # accelerator.print(predicted_reward.shape) - # for i in range(len(predicted_reward)): - # items["query"].append(tokenizer.decode(queries[i], skip_special_tokens=True)) - # items["reference_response"].append(tokenizer.decode(reference_responses[i])) - # items["predicted_reward"].append(predicted_reward[i].item()) + if args.run_eval: + evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) + for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): + writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step) + accelerator.print(f"{split} accuracy: {row['accuracy']}") + for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): + writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step) + accelerator.print(f"{batch} accuracy: {row['accuracy']}") + for confi, row in 
evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): + writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step) + accelerator.print(f"{confi} confidence: {row['accuracy']}") + writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step) + accelerator.print(f"eval accuracy: {evaluate_df['accuracy'].mean()}") + if accelerator.is_main_process: + os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + torch.cuda.empty_cache() + + norm_dataset = load_dataset(args.task.query_dataset, split="train") + norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + norm_dataset = norm_dataset.shuffle(seed=local_seed) + norm_dataloader = DataLoader(norm_dataset, batch_size=args.local_eval_batch_size) + items = defaultdict(list) + norm_dataloader = accelerator.prepare(norm_dataloader) + with torch.no_grad(): + for data in tqdm(norm_dataloader): + reference_responses = data["reference_response_token"].to(device, non_blocking=True) + queries = data["query_token"].to(device, non_blocking=True) + query_responses = torch.cat((queries, reference_responses), dim=1) + predicted_reward = get_reward(model, query_responses, tokenizer) + predicted_reward = accelerator.gather(predicted_reward) + queries = accelerator.gather(queries) + reference_responses = accelerator.gather(reference_responses) + accelerator.print(predicted_reward.shape) + for i in range(len(predicted_reward)): + items["query"].append(tokenizer.decode(queries[i], skip_special_tokens=True)) + items["reference_response"].append(tokenizer.decode(reference_responses[i])) + items["predicted_reward"].append(predicted_reward[i].item()) - # if accelerator.is_main_process: - # norm_df = pd.DataFrame(items) - # os.makedirs(f"eval_tables/{run_name}", exist_ok=True) - # norm_df.to_csv(f"eval_tables/{run_name}/eval_{update}_normalized.csv") - # if args.track: - # wandb.log({"samples/normalized": wandb.Table(dataframe=norm_df)}, step=update) - # stats = { - # "mean": norm_df["predicted_reward"].mean(), - # "std": norm_df["predicted_reward"].std(), - # "max": norm_df["predicted_reward"].max(), - # "min": norm_df["predicted_reward"].min() - # } - # for stat_name, stat_value in stats.items(): - # writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step) - # accelerator.print(f"Normalized Reward {stat_name.capitalize()}: {stat_value}") + if accelerator.is_main_process: + norm_df = pd.DataFrame(items) + os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + norm_df.to_csv(f"eval_tables/{run_name}/eval_{update}_normalized.csv") + if args.track: + wandb.log({"samples/normalized": wandb.Table(dataframe=norm_df)}, step=update) + stats = { + "mean": norm_df["predicted_reward"].mean(), + "std": norm_df["predicted_reward"].std(), + "max": norm_df["predicted_reward"].max(), + "min": norm_df["predicted_reward"].min() + } + for stat_name, stat_value in stats.items(): + writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step) + accelerator.print(f"Normalized Reward {stat_name.capitalize()}: {stat_value}") # save model if args.output_dir and args.num_train_epochs > 0: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) - time_tensor = torch.tensor(int(time.time()), device=device) - time_int = accelerator.gather(time_tensor).item() # avoid different 
timestamps across processes - repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr__seed{args.seed}" + time_tensor = torch.tensor([int(time.time())], device=device) + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name if accelerator.is_main_process: tokenizer.save_pretrained(args.output_dir, repo_id=repo_id) if args.push_to_hub: - tokenizer.push_to_hub(repo_id, revision=str(time_int)) + tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}") unwrapped: PreTrainedModel = accelerator.unwrap_model(model) accelerator.wait_for_everyone() @@ -456,8 +457,8 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): repo_id=repo_id, ) if args.push_to_hub: - unwrapped.push_to_hub(repo_id, revision=str(time_int), safe_serialization=False) + unwrapped.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}", safe_serialization=False) -if __name__ == "__main__": - args = tyro.cli(Args) - # train(args) +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index d434688..32b1645 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -111,7 +111,7 @@ class Args: """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" batch_size: Optional[int] = None """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" - local_eval_batch_size: int = 8 + local_eval_batch_size: int = 4 """per rank eval batch size""" world_size: Optional[int] = None """The number of processes (GPUs) to use""" @@ -119,7 +119,6 @@ class Args: """Number of epochs to train""" num_updates: Optional[int] = None """The number of updates to train""" - # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" @@ -238,7 +237,11 @@ def forward(model, query_responses, tokenizer): configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout if accelerator.is_main_process: pprint(model_config) - model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) + model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + args.base_model, + config=model_config, + trust_remote_code=True, + ) model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to model.generation_config.pad_token_id = None # generate tokens without truncation / padding if args.optimizer == "adam": @@ -249,7 +252,7 @@ def forward(model, query_responses, tokenizer): args.scheduler, optimizer=optimizer, num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_updates, + num_training_steps=args.num_updates * args.num_train_epochs, ) model, optimizer, dataloader, scheduler = accelerator.prepare( @@ -303,6 +306,7 @@ def forward(model, query_responses, tokenizer): writer.add_scalar("lr", scheduler.get_last_lr()[0], update) accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") break + if args.run_eval: model.eval() rouge_scores = collections.defaultdict(list) @@ -345,9 +349,13 @@ def forward(model, query_responses, tokenizer): decode_validation_queries = 
tokenizer.batch_decode(accelerator.gather(validation_queries)) decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) decode_validation_reference_responses = tokenizer.batch_decode( - accelerator.gather(validation_reference_responses) + accelerator.gather(validation_reference_responses), + skip_special_tokens=True, + ) + decode_validation_responses = tokenizer.batch_decode( + accelerator.gather(generated_responses[:, -args.task.response_length:]), + skip_special_tokens=True, ) - decode_validation_responses = tokenizer.batch_decode(accelerator.gather(generated_responses[:, -args.task.response_length:])) rouge_score = rouge.compute( predictions=decode_validation_responses, references=decode_validation_reference_responses ) @@ -384,15 +392,15 @@ def forward(model, query_responses, tokenizer): # save model if args.output_dir: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) - time_tensor = torch.tensor(int(time.time()), device=device) + time_tensor = torch.tensor([int(time.time())], device=device) time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes - repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr__seed{args.seed}" + repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name if accelerator.is_main_process: tokenizer.save_pretrained(args.output_dir, repo_id=repo_id) if args.push_to_hub: - tokenizer.push_to_hub(repo_id, revision=str(time_int)) + tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}") unwrapped: PreTrainedModel = accelerator.unwrap_model(model) accelerator.wait_for_everyone() @@ -406,7 +414,7 @@ def forward(model, query_responses, tokenizer): repo_id=repo_id, ) if args.push_to_hub: - unwrapped.push_to_hub(repo_id, revision=str(time_int), safe_serialization=False) + unwrapped.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}", safe_serialization=False) # if __name__ == "__main__": # args = tyro.cli(Args) From fceceaf9fc53f15cebfc43cd2185b677445f464c Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 19 Dec 2023 14:20:33 +0000 Subject: [PATCH 42/62] precommit --- .../summarize/reward.py | 53 +- lm_human_preference_details/summarize/sft.py | 20 +- lm_human_preference_details/tldr_dataset.py | 40 +- ...in_policy_accelerate_summarize_separate.py | 42 +- ...n_policy_accelerate_summarize_separate1.py | 698 +++++++++++++++--- ...n_policy_accelerate_summarize_separate3.py | 41 +- ...n_policy_accelerate_summarize_separate4.py | 43 +- ...elerate_summarize_separate5_load_critic.py | 43 +- ...ummarize_separate6_correct_reward_index.py | 43 +- ...te7_correct_reward_index_no_load_critic.py | 43 +- ...parate8_correct_reward_index_deepspeed3.py | 47 +- 11 files changed, 832 insertions(+), 281 deletions(-) diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 2f02ec5..30e3893 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -1,7 +1,7 @@ -from collections import defaultdict import os import random import time +from collections import defaultdict from dataclasses import asdict, dataclass, field from types import SimpleNamespace from typing import List, Literal, Optional @@ -15,16 +15,23 @@ import tyro from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import 
DistributedDataParallelKwargs, gather_object +from accelerate.utils import gather_object from datasets import load_dataset from rich.console import Console from rich.pretty import pprint from rich.table import Table from torch import optim -from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import AutoConfig, AutoModel, AutoTokenizer, get_scheduler, PreTrainedModel, PretrainedConfig +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + PretrainedConfig, + PreTrainedModel, + get_scheduler, +) @dataclass @@ -126,7 +133,9 @@ class Args: # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" output_dir: str = "models/reward_policy" """Where to save the model""" @@ -164,13 +173,16 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class ScalarModelConfig(PretrainedConfig): - model_type = 'scalar_model' + model_type = "scalar_model" + def __init__(self, base_model: str = "gpt2", **kwargs): super().__init__(**kwargs) self.base_model = base_model + class ScalarModel(PreTrainedModel): config_class = ScalarModelConfig + def __init__(self, config: ScalarModelConfig): super().__init__(config) self.config = config @@ -203,10 +215,7 @@ def get_reward(model, query_responses, tokenizer): return_dict=True, output_hidden_states=True, ) - sequence_lengths = ( - torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( - query_responses.device - ) + sequence_lengths = (torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(query_responses.device) # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths] @@ -258,10 +267,26 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): dataset = load_dataset(args.label_dataset, "comparisons", split="train") dataset = dataset.shuffle(seed=local_seed) dataset = dataset.select(range(args.labels.num_train)) - dataset = dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"]) + dataset = dataset.with_format( + "torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"] + ) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra.confidence", "response0_policy", "response1_policy", "policies"]) + validation_dataset = validation_dataset.with_format( + "torch", + columns=[ + "query_token", + "choice", + "response0_token", + "response1_token", + "batch", + "split", + "extra.confidence", + "response0_policy", + "response1_policy", + "policies", + ], + ) validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) 
accelerator.print("The number of samples in dataset", len(dataset)) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) @@ -426,7 +451,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): "mean": norm_df["predicted_reward"].mean(), "std": norm_df["predicted_reward"].std(), "max": norm_df["predicted_reward"].max(), - "min": norm_df["predicted_reward"].min() + "min": norm_df["predicted_reward"].min(), } for stat_name, stat_value in stats.items(): writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step) @@ -436,7 +461,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): if args.output_dir and args.num_train_epochs > 0: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) time_tensor = torch.tensor([int(time.time())], device=device) - time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index 32b1645..8f84bc1 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -1,5 +1,4 @@ import collections -import functools import os import random import time @@ -13,7 +12,6 @@ import torch import torch.optim as optim import tyro -from tqdm import tqdm from accelerate import Accelerator from datasets import load_dataset from rich.console import Console @@ -23,6 +21,7 @@ from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -122,7 +121,9 @@ class Args: # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" output_dir: str = "models/sft_model" """Where to save the model""" @@ -255,9 +256,7 @@ def forward(model, query_responses, tokenizer): num_training_steps=args.num_updates * args.num_train_epochs, ) - model, optimizer, dataloader, scheduler = accelerator.prepare( - model, optimizer, dataloader, scheduler - ) + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) validation_dataloader = accelerator.prepare(validation_dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens @@ -305,7 +304,6 @@ def forward(model, query_responses, tokenizer): writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) writer.add_scalar("lr", scheduler.get_last_lr()[0], update) accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") - break if args.run_eval: model.eval() @@ -319,9 +317,7 @@ def forward(model, query_responses, tokenizer): with torch.no_grad(): validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) validation_queries = validation_data["query_token"].to(device, non_blocking=True) - validation_query_reference_responses = torch.cat( - (validation_queries, validation_reference_responses), dim=1 - ) + validation_query_reference_responses = torch.cat((validation_queries, validation_reference_responses), dim=1) validation_output = forward(model, validation_query_reference_responses, tokenizer) validation_labels = validation_query_reference_responses.masked_fill( @@ -353,7 +349,7 @@ def forward(model, query_responses, tokenizer): skip_special_tokens=True, ) decode_validation_responses = tokenizer.batch_decode( - accelerator.gather(generated_responses[:, -args.task.response_length:]), + accelerator.gather(generated_responses[:, -args.task.response_length :]), skip_special_tokens=True, ) rouge_score = rouge.compute( @@ -393,7 +389,7 @@ def forward(model, query_responses, tokenizer): if args.output_dir: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) time_tensor = torch.tensor([int(time.time())], device=device) - time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index bac3363..d945428 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -1,15 +1,16 @@ -from dataclasses import dataclass +import multiprocessing import os +from dataclasses import dataclass from typing import Dict, Optional -from datasets import load_dataset -from rich.pretty import pprint -from transformers import AutoTokenizer -import tyro -import multiprocessing import matplotlib.pyplot as plt import pandas as pd +import tyro +from datasets import load_dataset from huggingface_hub import HfApi +from rich.pretty import pprint +from transformers import AutoTokenizer + api = HfApi() @@ -20,11 +21,13 @@ --max-sft-response-length=53 \ --max-rm-response-length=169 """ + + @dataclass class Args: - base_model: str = "gpt2" # EleutherAI/pythia-160m - max_sft_response_length: int = 48 # 53 - max_rm_response_length: int = 153 # 169 + base_model: str = "gpt2" # EleutherAI/pythia-160m + max_sft_response_length: int = 48 # 53 + max_rm_response_length: int = 153 # 169 hf_entity: str = None @@ -36,7 +39,7 @@ class TaskQueryHParams: ] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily truncate_field: Optional[str] = "post" truncate_text: Optional[str] = "\n" - padding: Optional[str] = " " # empty spaces + padding: Optional[str] = " " # empty spaces pad_side: Optional[str] = "left" @@ -138,7 +141,9 @@ def 
process_query_data(x): } sft_ds = sft_ds.map(process_query_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) - sft_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}") + sft_ds.push_to_hub( + f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}" + ) label_ds = load_dataset("openai/summarize_from_feedback", "comparisons") @@ -168,7 +173,9 @@ def process_response_data(x): } label_ds = label_ds.map(process_response_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) - label_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}") + label_ds.push_to_hub( + f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}" + ) os.makedirs("dataset_visuals", exist_ok=True) # visualize token length distribution @@ -183,10 +190,10 @@ def process_response_data(x): offset = len(sft_ds) for i, key in enumerate(label_ds.keys()): df = label_ds[key].to_pandas() - axs[2*i + offset].hist(df["response0_token_len"], bins=100) - axs[2*i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") - axs[2*i + offset + 1].hist(df["response1_token_len"], bins=100) - axs[2*i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") + axs[2 * i + offset].hist(df["response0_token_len"], bins=100) + axs[2 * i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") + axs[2 * i + offset + 1].hist(df["response1_token_len"], bins=100) + axs[2 * i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") fig.suptitle(f"{args.base_model} Tokenizer: Token length distribution") fig.tight_layout() fig.savefig("dataset_visuals/token_len.png") @@ -244,4 +251,3 @@ def process_response_data(x): repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}", repo_type="dataset", ) - diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 0e28c49..da288a9 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -142,7 +141,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -356,8 +357,9 @@ def whiten(values, shift_mean=True): def masked_mean(x, mask): return (x.sum(-1) / (~mask).sum(-1)).mean() + def masked_var(x, mask): - return (x**2).sum(-1) / (~mask).sum(-1) - masked_mean(x, mask)**2 + return (x**2).sum(-1) / (~mask).sum(-1) - masked_mean(x, mask) ** 2 def masked_whiten(values, 
mask, shift_mean=True): @@ -367,7 +369,7 @@ def masked_whiten(values, mask, shift_mean=True): if not shift_mean: whitened += mean return whitened - + def masked_mean(values, mask, axis=None): """Compute mean of tensor with a masked values.""" @@ -376,6 +378,7 @@ def masked_mean(values, mask, axis=None): else: return (values * mask).sum() / mask.sum() + def masked_var(values, mask, unbiased=True): """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) @@ -497,7 +500,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -640,9 +647,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -826,7 +831,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -903,8 +910,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_losses1 = torch.square(vpred - mb_return) vf_losses2 = torch.square(vpredclipped - mb_return) vf_loss_max = torch.max(vf_losses1, vf_losses2) - - + # vf_loss = 0.5 * vf_loss_max.mean() vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask[micro_batch_inds]) vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~padding_mask[micro_batch_inds]) @@ -927,7 +933,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * masked_mean((logprobs_diff**2), ~padding_mask[micro_batch_inds]) # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -951,7 +957,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = masked_mean(entropy, padding_mask[micro_batch_inds]) + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = masked_mean( + entropy, padding_mask[micro_batch_inds] + ) ratio_stats[ppo_epoch_idx, minibatch_idx, 
gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 @@ -961,13 +969,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # breakpoint() with torch.no_grad(): diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py index 927c6bc..a8d03c4 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -142,7 +141,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -419,12 +420,13 @@ def get_reward(reward_model, query_responses, tokenizer): return_dict=True, output_hidden_states=True, ) - sequence_lengths = ( - torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( - query_responses.device - ) + sequence_lengths = (torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(query_responses.device) # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -463,6 +465,7 @@ def truncate_response(args, tokenizer, responses): def masked_mean(x, mask): return (x.sum(-1) / (~mask).sum(-1)).mean() + # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) @@ -640,7 +643,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_clipfrac_stats = torch.zeros(stats_shape, device=device) entropy_stats = torch.zeros(stats_shape, device=device) ratio_stats = torch.zeros(stats_shape, device=device) - + model.train() for update in range(1, args.ppo.num_updates + 1): global_step += 1 * args.ppo.batch_size @@ -661,60 +664,550 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well generation_config, ) if args.task.response_length != 53: - query_responses = torch.tensor([[ 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 
209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 6971, 7941, 1703, 37, - 1433, 27, 391, 16, 22842, 16458, 187, 187, 53, 43561, - 27, 3189, 544, 1348, 278, 62, 5816, 619, 806, 385, - 544, 1797, 269, 62, 846, 608, 2607, 273, 2740, 598, - 15, 187, 187, 15743, 27, 24387, 39714, 187, 6300, 15950, - 436, 1501, 562, 627, 816, 281, 1339, 352, 562, 273, - 619, 985, 15, 187, 2598, 309, 452, 644, 13597, 436, - 3226, 313, 2577, 806, 19609, 15, 309, 369, 617, 806, - 10, 323, 495, 1107, 15, 844, 574, 271, 13103, 673, - 285, 4536, 7227, 35267, 285, 37616, 15, 496, 253, 990, - 13, 352, 1904, 626, 789, 562, 15, 187, 42, 3260, - 309, 7636, 617, 285, 703, 7636, 479, 533, 1841, 816, - 1904, 626, 789, 562, 1955, 281, 1097, 4858, 4606, 15, - 187, 187, 2598, 352, 556, 644, 2761, 608, 2607, 15, - 309, 1694, 689, 253, 31056, 673, 273, 619, 1495, 534, - 369, 1501, 2740, 598, 273, 806, 374, 2607, 15, 209, - 187, 4125, 846, 608, 2607, 13, 309, 816, 2985, 617, - 15, 187, 42, 5476, 627, 11210, 626, 644, 247, 2014, - 835, 309, 6468, 626, 1869, 670, 617, 2568, 15, 23385, - 50276, 187, 42, 871, 309, 10095, 626, 3057, 617, 285, - 309, 1353, 3965, 2119, 703, 1912, 626, 3057, 479, 2057, - 534, 310, 323, 253, 1805, 15, 187, 1231, 6468, 626, - 13452, 323, 5046, 374, 2607, 32, 1633, 751, 326, 15, - 187, 43688, 13, 309, 816, 4571, 626, 6016, 352, 10542, - 285, 3261, 387, 776, 7963, 327, 619, 17899, 7963, 534, - 309, 1620, 755, 327, 15, 187, 1147, 369, 5322, 281, - 923, 617, 2454, 969, 285, 30774, 336, 253, 1711, 1897, - 15, 187, 1147, 369, 5322, 281, 923, 253, 9097, 359, - 1097, 2389, 1024, 3811, 342, 617, 2021, 15, 187, 34937, - 512, 608, 2607, 13, 619, 5249, 5055, 598, 15, 309, - 1694, 247, 14892, 209, 187, 36421, 598, 247, 2257, 273, - 2583, 285, 858, 1841, 1475, 253, 2419, 309, 6468, 626, - 644, 2104, 281, 3966, 3966, 15, 187, 1989, 309, 816, - 2985, 617, 15, 187, 42, 871, 703, 434, 2509, 973, - 13, 3164, 1805, 685, 1078, 15, 187, 2513, 352, 816, - 479, 32, 209, 187, 25954, 6701, 323, 634, 673, 4361, - 436, 285, 11435, 634, 5701, 15, 187, 187, 14135, 28, - 4976, 27, 6365, 619, 806, 19609, 13, 9377, 598, 13, - 309, 2985, 617, 533, 1053, 626, 3057, 617, 285, 12371, - 604, 352, 434, 816, 479, 15, 0,]], device=device) + query_responses = torch.tensor( + [ + [ + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 
209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 6971, + 7941, + 1703, + 37, + 1433, + 27, + 391, + 16, + 22842, + 16458, + 187, + 187, + 53, + 43561, + 27, + 3189, + 544, + 1348, + 278, + 62, + 5816, + 619, + 806, + 385, + 544, + 1797, + 269, + 62, + 846, + 608, + 2607, + 273, + 2740, + 598, + 15, + 187, + 187, + 15743, + 27, + 24387, + 39714, + 187, + 6300, + 15950, + 436, + 1501, + 562, + 627, + 816, + 281, + 1339, + 352, + 562, + 273, + 619, + 985, + 15, + 187, + 2598, + 309, + 452, + 644, + 13597, + 436, + 3226, + 313, + 2577, + 806, + 19609, + 15, + 309, + 369, + 617, + 806, + 10, + 323, + 495, + 1107, + 15, + 844, + 574, + 271, + 13103, + 673, + 285, + 4536, + 7227, + 35267, + 285, + 37616, + 15, + 496, + 253, + 990, + 13, + 352, + 1904, + 626, + 789, + 562, + 15, + 187, + 42, + 3260, + 309, + 7636, + 617, + 285, + 703, + 7636, + 479, + 533, + 1841, + 816, + 1904, + 626, + 789, + 562, + 1955, + 281, + 1097, + 4858, + 4606, + 15, + 187, + 187, + 2598, + 352, + 556, + 644, + 2761, + 608, + 2607, + 15, + 309, + 1694, + 689, + 253, + 31056, + 673, + 273, + 619, + 1495, + 534, + 369, + 1501, + 2740, + 598, + 273, + 806, + 374, + 2607, + 15, + 209, + 187, + 4125, + 846, + 608, + 2607, + 13, + 309, + 816, + 2985, + 617, + 15, + 187, + 42, + 5476, + 627, + 11210, + 626, + 644, + 247, + 2014, + 835, + 309, + 6468, + 626, + 1869, + 670, + 617, + 2568, + 15, + 23385, + 50276, + 187, + 42, + 871, + 309, + 10095, + 626, + 3057, + 617, + 285, + 309, + 1353, + 3965, + 2119, + 703, + 1912, + 626, + 3057, + 479, + 2057, + 534, + 310, + 323, + 253, + 1805, + 15, + 187, + 1231, + 6468, + 626, + 13452, + 323, + 5046, + 374, + 2607, + 32, + 1633, + 751, + 326, + 15, + 187, + 43688, + 13, + 309, + 816, + 4571, + 626, + 6016, + 352, + 10542, + 285, + 3261, + 387, + 776, + 7963, + 327, + 619, + 17899, + 7963, + 534, + 309, + 1620, + 755, + 327, + 15, + 187, + 1147, + 369, + 5322, + 281, + 923, + 617, + 2454, + 969, + 285, + 30774, + 336, + 253, + 1711, + 1897, + 15, + 187, + 1147, + 369, + 5322, + 281, + 923, + 253, + 9097, + 359, + 1097, + 2389, + 1024, + 3811, + 342, + 617, + 2021, + 15, + 187, + 34937, + 512, + 608, + 2607, + 13, + 619, + 5249, + 5055, + 598, + 15, + 309, + 1694, + 247, + 14892, + 209, + 187, + 36421, + 598, + 247, + 2257, + 273, + 2583, + 285, + 858, + 1841, + 1475, + 253, + 2419, + 309, + 6468, + 626, + 644, + 2104, + 281, + 3966, + 3966, + 15, + 187, + 1989, + 309, + 816, + 2985, + 617, + 15, + 187, + 42, + 871, + 703, + 434, + 2509, + 973, + 13, + 3164, + 1805, + 685, + 1078, + 15, + 187, + 2513, + 352, + 816, + 479, + 32, + 209, + 187, + 25954, + 6701, + 323, + 634, + 673, + 4361, + 436, + 285, + 11435, + 634, + 5701, + 15, + 187, + 187, + 14135, + 28, + 4976, + 27, + 6365, + 619, + 806, + 19609, + 13, + 9377, + 598, + 13, + 309, + 2985, + 617, + 533, + 1053, + 626, + 3057, + 617, + 285, + 12371, + 604, + 352, + 434, + 816, + 479, + 15, + 0, + ] + ], + device=device, + ) context_length = queries.shape[1] responses = query_responses[:, context_length:] @@ -778,15 +1271,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) - - - + # TODO: reverse it back # scores = torch.where(contain_pad_token, scores, 
torch.full_like(scores, args.task.penalty_reward_value)) - - - - + torch.cuda.empty_cache() # 4. compute rewards @@ -848,16 +1336,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well advantages = torch.stack(advantages_reversed[::-1], axis=1) returns = advantages + values - - - # TODO: reverse it back # advantages = whiten(advantages) - - - - return_mean, return_var = returns.mean(), returns.var() value_mean, value_var = values.mean(), values.var() @@ -896,8 +1377,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_losses1 = torch.square(vpred - mb_return) vf_losses2 = torch.square(vpredclipped - mb_return) vf_loss_max = torch.max(vf_losses1, vf_losses2) - - + vf_loss = 0.5 * vf_loss_max.mean() # vf_loss = 0.5 * masked_mean(vf_loss_max, padding_mask[micro_batch_inds]) @@ -920,22 +1400,24 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # approxkl = 0.5 * masked_mean((logprobs_diff**2), padding_mask[micro_batch_inds]) # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - pprint({ - "responses": responses, - "values": values, - "rewards": rewards, - "scores": scores, - "advantages": advantages, - "ratio": ratio, - "pg_losses": pg_losses, - "approxkl": approxkl, - "pg_loss": pg_loss, - "pg_clipfrac": pg_clipfrac, - "ratio": ratio.mean(), - "vf_loss": vf_loss, - "vf_clipfrac": vf_clipfrac, - "entropy": entropy.mean(), - }) + pprint( + { + "responses": responses, + "values": values, + "rewards": rewards, + "scores": scores, + "advantages": advantages, + "ratio": ratio, + "pg_losses": pg_losses, + "approxkl": approxkl, + "pg_loss": pg_loss, + "pg_clipfrac": pg_clipfrac, + "ratio": ratio.mean(), + "vf_loss": vf_loss, + "vf_clipfrac": vf_clipfrac, + "entropy": entropy.mean(), + } + ) with torch.no_grad(): approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac @@ -946,20 +1428,20 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 raise - # minibatch_idx += 1 - # if accelerator.is_main_process: - # console.print( - # f"ppo_epoch_idx", - # ppo_epoch_idx, - # "approxkl", - # approxkl_stats[:ppo_epoch_idx+1].mean().item(), - # "pg_loss", - # pg_loss_stats[:ppo_epoch_idx+1].mean().item(), - # "pg_clipfrac", - # pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), - # "ratio", - # ratio_stats[:ppo_epoch_idx+1].mean().item(), - # ) + # minibatch_idx += 1 + # if accelerator.is_main_process: + # console.print( + # f"ppo_epoch_idx", + # ppo_epoch_idx, + # "approxkl", + # approxkl_stats[:ppo_epoch_idx+1].mean().item(), + # "pg_loss", + # pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + # "pg_clipfrac", + # pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + # "ratio", + # ratio_stats[:ppo_epoch_idx+1].mean().item(), + # ) with torch.no_grad(): if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py index b5b6eef..0bf5f2e 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py @@ -33,7 
+33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. + dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -595,9 +600,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -687,7 +690,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -729,7 +732,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
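Note on the hunk above: the now-uncommented `scores = torch.where(contain_pad_token, ...)` line implements the filter described in the earlier comment, i.e. responses that never emit the truncate/pad token receive a fixed low score instead of the reward model's output. A minimal self-contained sketch of that step follows; the pad id and penalty value are invented here, while the script reads them from the tokenizer and `args.task.penalty_reward_value`.

import torch

pad_token_id = 50257         # assumed pad id, for illustration only
penalty_reward_value = -1.0  # stand-in for args.task.penalty_reward_value

# two fake post-processed responses: the first was truncated (so it contains padding),
# the second ran to the length limit without ever emitting the truncate token
postprocessed_responses = torch.tensor(
    [[11, 22, 33, pad_token_id, pad_token_id],
     [11, 22, 33, 44, 55]]
)
scores = torch.tensor([0.7, 0.9])  # reward-model scores for the two responses

contain_pad_token = torch.any(postprocessed_responses == pad_token_id, dim=-1)
scores = torch.where(contain_pad_token, scores, torch.full_like(scores, penalty_reward_value))
print(scores)  # tensor([ 0.7000, -1.0000])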
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -773,7 +776,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -833,7 +838,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -864,7 +869,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -897,13 +902,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py index 49d023a..ee18d2f 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -595,9 +600,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -642,7 +645,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -696,7 +699,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -704,7 +707,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -738,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
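The recurring `logits /= args.task.temperature + 1e-7` edits in these hunks only remove redundant parentheses; the computation is unchanged. For reference, this is roughly what that block does: it gathers per-token log-probabilities of the sampled response under the same temperature scaling used at generation time. Shapes and values below are placeholders rather than the script's real tensors.

import torch
import torch.nn.functional as F

temperature = 0.7  # stand-in for args.task.temperature
batch, response_len, vocab = 2, 3, 11
logits = torch.randn(batch, response_len, vocab)            # logits aligned with response positions
responses = torch.randint(0, vocab, (batch, response_len))  # sampled token ids

logits = logits / (temperature + 1e-7)  # same temperature scaling applied during generation
all_logprobs = F.log_softmax(logits, dim=-1)
# pick out log pi(token_t | context) for the tokens that were actually sampled
logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
print(logprobs.shape)  # torch.Size([2, 3])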
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -782,7 +785,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -842,7 +847,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -873,7 +878,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -906,13 +911,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py index ea30982..ae59f50 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -594,9 +599,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -641,7 +644,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -703,7 +706,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -737,7 +740,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
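The `approxkl_stats[: ppo_epoch_idx + 1].mean()`-style slices printed above are running averages of standard PPO diagnostics. A hedged sketch of how those quantities are derived from old and new log-probabilities; the tensors are toy stand-ins for the script's minibatch variables, and `cliprange` plays the role of `args.ppo.cliprange`.

import torch

cliprange = 0.2
new_logprobs = torch.tensor([-1.2, -0.8, -2.0])
old_logprobs = torch.tensor([-1.0, -1.0, -1.5])
advantages = torch.tensor([0.5, -0.3, 1.0])

logprobs_diff = new_logprobs - old_logprobs
ratio = torch.exp(logprobs_diff)                       # pi_new / pi_old per sample
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()      # clipped policy-gradient loss
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()  # fraction of samples where clipping bites
approxkl = 0.5 * (logprobs_diff**2).mean()             # quadratic approximation of the KL
print(pg_loss.item(), pg_clipfrac.item(), approxkl.item())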
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -781,7 +784,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -841,7 +846,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -872,7 +877,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -905,13 +910,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py index 7331a66..cc7fef7 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -594,9 +599,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -641,7 +644,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -703,7 +706,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -738,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
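`get_reward` (reformatted above with no behavioral change) reads the scalar reward at each sequence's last non-pad position. A simplified sketch of that indexing: the real code uses its `first_true_indices` helper, which also handles sequences containing no pad token, whereas this toy version uses a plain `argmax` and made-up token ids.

import torch

pad_token_id = 50257  # made-up pad id for illustration
query_responses = torch.tensor([[5, 6, 7, pad_token_id, pad_token_id],
                                [5, 6, 7, 8, pad_token_id]])
reward_logits = torch.randn(2, 5, 1)  # one scalar per position, as the reward head emits

# index of the first pad token, minus one == last non-pad position
sequence_lengths = (query_responses == pad_token_id).long().argmax(-1) - 1
rewards = reward_logits[torch.arange(reward_logits.size(0)), sequence_lengths].squeeze(-1)
print(sequence_lengths)  # tensor([2, 3])
print(rewards.shape)     # torch.Size([2])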
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -784,7 +787,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -844,7 +849,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -875,7 +880,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -908,13 +913,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py index a316656..41622b5 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -594,9 +599,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -641,7 +644,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -703,7 +706,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -738,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
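The same training loop (see the `vf_losses1` / `vf_losses2` hunk earlier in this patch) trains the critic with PPO's clipped value loss. A short sketch of that objective, under the assumption that `vpredclipped` is the new value prediction clamped to a `cliprange_value` band around the rollout-time value; the clamp itself is not shown in this excerpt, so treat it as illustrative.

import torch

cliprange_value = 0.2  # plays the role of args.ppo.cliprange_value
vpred = torch.tensor([0.9, 0.1, 0.4])      # new value predictions
mb_values = torch.tensor([0.5, 0.2, 0.3])  # value predictions recorded at rollout time
mb_return = torch.tensor([1.0, 0.0, 0.5])  # empirical returns

# assumed definition: keep the new prediction within +/- cliprange_value of the old one
vpredclipped = torch.clamp(vpred, mb_values - cliprange_value, mb_values + cliprange_value)
vf_losses1 = torch.square(vpred - mb_return)
vf_losses2 = torch.square(vpredclipped - mb_return)
vf_loss_max = torch.max(vf_losses1, vf_losses2)  # take the pessimistic (larger) error
vf_loss = 0.5 * vf_loss_max.mean()
vf_clipfrac = (vf_losses2 > vf_losses1).float().mean()
print(vf_loss.item(), vf_clipfrac.item())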
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -784,7 +787,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -844,7 +849,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -875,7 +880,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -908,13 +913,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py index a1e860a..bb84138 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -569,6 +574,7 @@ def forward(policy, query_responses, tokenizer): if args.deepspeed: deepspeed_states = AcceleratorState().deepspeed_plugin from deepspeed.ops.adam import DeepSpeedCPUAdam + # if deepspeed_states.deepspeed_config['zero_optimization']['offload_optimizer']['device'] in ('none', None): # return optim.AdamW(params, eps=self.opt.eps, betas=(self.opt.beta1, self.opt.beta2)) optimizer = DeepSpeedCPUAdam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) @@ -590,7 +596,6 @@ def forward(policy, query_responses, tokenizer): # deepspeed_states = AcceleratorState().deepspeed_plugin # # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size # # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - # offload = False # eval_ds_config = { # "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], @@ -644,6 +649,7 @@ def forward(policy, query_responses, tokenizer): def repeat_generator(): # TODO: ideally we shuffle the dataloader as well while True: yield from dataloader + iter_dataloader = iter(repeat_generator()) sample_validation_inds = np.arange(args.ppo.batch_size) @@ -662,12 +668,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well data = next(iter_dataloader) queries = data["query_token"].to(device) accelerator.print(f"==={queries.shape=}, {queries.dtype}") - accelerator.print(f"==={sample_validation_query_reference_responses.shape=}, {sample_validation_query_reference_responses.dtype}") + accelerator.print( + f"==={sample_validation_query_reference_responses.shape=}, {sample_validation_query_reference_responses.dtype}" + ) _, sample_validation_reference_scores, _ = get_reward( reward_model, sample_validation_query_reference_responses, tokenizer ) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens @@ -680,7 +687,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -734,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -742,7 +749,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -777,7 +784,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
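In the DeepSpeed path, the frozen reward model is wrapped with a separate inference-oriented ZeRO-3 configuration (the `eval_ds_config` touched above) so its parameters can be sharded and offloaded while only the policy/critic wrapper is optimized. A sketch of that configuration as a plain dict; the micro-batch size is normally read from the accelerate DeepSpeed plugin, so the literal value below is a placeholder.

# hypothetical stand-in for the value normally read from the accelerate DeepSpeed plugin
train_micro_batch_size_per_gpu = 1

eval_ds_config = {
    "train_micro_batch_size_per_gpu": train_micro_batch_size_per_gpu,
    "zero_optimization": {
        "stage": 3,                                 # ZeRO stage 3: shard parameters across ranks
        "stage3_param_persistence_threshold": 1e4,  # keep small params resident to avoid thrashing
        "offload_param": {"device": "cpu"},         # park the frozen reward model's weights on CPU
    },
}
# the patch then hands the frozen model to DeepSpeed for inference-style sharding:
# reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config)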
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -823,7 +830,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) # print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -883,7 +892,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -914,7 +923,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -947,13 +956,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() From b84c23797d4415a44da485ce695b4e1574bf7682 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 19 Dec 2023 21:35:27 +0000 Subject: [PATCH 43/62] rename --- .../ppo.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename lm_human_preference_details/{train_policy_accelerate_summarize_separate6_correct_reward_index.py => summarize/ppo.py} (100%) diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py b/lm_human_preference_details/summarize/ppo.py similarity index 100% rename from lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py rename to lm_human_preference_details/summarize/ppo.py From f686c51b9baecaebfdba943355e489168351515a Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 20 Dec 2023 16:34:29 +0000 Subject: [PATCH 44/62] push --- .pre-commit-config.yaml | 2 +- lm_human_preference_details/summarize/ppo.py | 503 +++++++----------- .../summarize/reward.py | 107 ++-- lm_human_preference_details/summarize/sft.py | 171 +++--- 4 files changed, 344 insertions(+), 439 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d44ff3a..526eda1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,5 +35,5 @@ repos: hooks: - id: codespell args: - - --ignore-words-list=nd,reacher,thist,ths,magent,ba,rouge + - 
--ignore-words-list=nd,reacher,thist,ths,magent,ba,rouge,hist - --skip=docs/css/termynal.css,docs/js/termynal.js \ No newline at end of file diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index cc7fef7..341e940 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -18,19 +18,17 @@ from rich.console import Console from rich.pretty import pprint from rich.table import Table -from torch import Tensor, optim -from torch.optim.optimizer import ( - _dispatch_sqrt, - _get_value, - _use_grad_for_differentiable, -) +from torch import optim from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from transformers import ( AutoConfig, + AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, + PretrainedConfig, + PreTrainedModel, ) INVALID_LOGPROB = 1.0 @@ -44,34 +42,18 @@ class AdaptiveKLParams: @dataclass class RewardHParams: - use_adaptive_kl: bool = True + use_adaptive_kl: bool = False adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - trained_model: Optional[str] = "" - label_dataset: tyro.conf.Suppress[Optional[str]] = None dataset_mean: float = 0.0 dataset_std: float = 1.0 - kl_coef: float = 0.15 + kl_coef: float = 0.05 @dataclass class PpoHParams: - total_episodes: int = 1000000 - local_batch_size: int = 64 - local_mini_batch_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - mini_batch_size: tyro.conf.Suppress[int] = None - gradient_accumulation_steps: int = 64 - """gradient accumulation steps""" - local_micro_batch_size: tyro.conf.Suppress[int] = None - """per rank micro batch size""" - world_size: tyro.conf.Suppress[int] = None - batch_size: tyro.conf.Suppress[int] = None - minibatch_size: tyro.conf.Suppress[int] = None num_updates: tyro.conf.Suppress[int] = None nminibatches: int = 1 noptepochs: int = 4 - lr: float = 0.00001 - eps: float = 1e-5 vf_coef: float = 0.1 cliprange: float = 0.2 cliprange_value: float = 0.2 @@ -132,33 +114,75 @@ class Args: """the entity (team) of wandb's project""" cuda: bool = True """Whether to use cuda if available.""" - run_name: tyro.conf.Suppress[str] = None - """TO BE FILLED: a unique name of this run""" + run_name: Optional[str] = None + """a unique name of this run""" load_from_cache_file: bool = False """Whether to load data from the local cache file in `dataset.map`""" - upload_model: bool = False + push_to_hub: bool = False "whether to upload the saved model to huggingface" hf_entity: str = "" "the user or org name of the model repository from the Hugging Face Hub" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 220 + """How often to print sample output""" + # run_eval: bool = False + # """Whether to run evaluation""" + # optimizer args + eps: float = 1e-5 + """the epsilon value for the optimizer""" + lr: float = 0.00001 + """the learning rate""" + optimizer: Literal["adam", "adamw"] = "adamw" + """Which optimizer to use""" + scheduler: str = "cosine" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_train_epochs: int = 1 + """Number of epochs to train""" + num_updates: Optional[int] = None + """The number of updates to train""" + gradient_accumulation_steps: int = 64 + """The number of gradient accumulation steps""" + 
local_micro_batch_size: Optional[int] = 1 + """The micro batch size per GPU (HF's `per_device_train_batch_size`)""" + total_episodes: Optional[int] = 1000000 + """The total number of episodes in the dataset""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + nminibatches: int = 1 + """Number of minibatches to split a batch into""" + local_mini_batch_size: Optional[int] = None + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + local_eval_batch_size: int = 8 + """per rank eval batch size""" + + # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" + reward_model_path: str = "" + """the name of the pretrained model to use""" + sft_model_path: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" dropout_layer_keys: List[str] = field( default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] ) """Which layers to apply dropout to""" - deepspeed: bool = False - """Whether to use deepspeed to train the model""" - print_sample_output_freq: int = 10 - """How often to print sample output""" - save_path: str = "models/ppo_policy" + output_dir: str = "models/ppo_model" """Where to save the model""" - optimizer: Literal["tf_adam", "adam", "adamw"] = "adamw" - """Which optimizer to use""" - sft_model_path: str = "" - """Where to load the SFT model""" task: TaskHParams = field(default_factory=TaskHParams) - rewards: RewardHParams = field(default_factory=RewardHParams) + reward: RewardHParams = field(default_factory=RewardHParams) ppo: PpoHParams = field(default_factory=PpoHParams) @@ -181,152 +205,10 @@ def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: console.print(table) -def _single_tensor_adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, -): - assert grad_scale is None and found_inf is None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - # update step - step_t += 1 - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) - step = _get_value(step_t) - - ### pytorch adam implementation: - # bias_correction1 = 1 - beta1 ** step - # bias_correction2 = 1 - beta2 ** step - # step_size = lr / bias_correction1 - # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) - # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - # param.addcdiv_(exp_avg, denom, value=-step_size) - - ### tensorflow adam implementation: - lr_t = lr * _dispatch_sqrt(1 - beta2**step) / (1 - beta1**step) - denom = exp_avg_sq.sqrt().add_(eps) - 
param.addcdiv_(exp_avg, denom, value=-lr_t) - - -def adam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 - # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - maximize: bool, -): - func = _single_tensor_adam - - func( - params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - ) - - -class AdamTensorFlowStyle(optim.Adam): - @_use_grad_for_differentiable - def step(self, closure=None): - self._cuda_graph_capture_health_check() - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group["betas"] - - self._init_group( - group, - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group["amsgrad"], - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - ) - - return loss +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer class AdaptiveKLController: @@ -341,12 +223,6 @@ def update(self, current, n_steps): self.value *= mult -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.normal_(layer.weight, std=std) - torch.nn.init.constant_(layer.bias, val=bias_const) - return layer - - def whiten(values, shift_mean=True): # `unbiased=False` matches TF `tf.nn.moments`'s setting mean, var = torch.mean(values), torch.var(values, unbiased=False) @@ -356,24 +232,62 @@ def whiten(values, shift_mean=True): return whitened -class AutoModelForCausalLMWithRewardHead(nn.Module): - def __init__(self, lm_backbone): - super().__init__() - self.lm_backbone = lm_backbone - # self.scalar_head = layer_init( - # nn.Linear(lm_backbone.config.hidden_size, 1), - # std=1 / np.sqrt(lm_backbone.config.hidden_size + 1), - # ) - self.scalar_head = layer_init(nn.Linear(lm_backbone.config.hidden_size, 1), std=0) - # self.reward_gain = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False) - self.reward_bias = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False) +class ScalarModelConfig(PretrainedConfig): + def __init__( + self, + base_model: str = "EleutherAI/pythia-160m", + base_config: PretrainedConfig = 
AutoConfig.from_pretrained("EleutherAI/pythia-160m"), + hidden_size: int = 768, + bias: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.base_model = base_model + self.base_config = base_config + self.hidden_size = hidden_size + self.bias = bias + + +class ScalarModel(PreTrainedModel): + config_class = ScalarModelConfig + + def __init__(self, config: ScalarModelConfig): + super().__init__(config) + self.config = config + self.lm_backbone = AutoModel.from_pretrained( + config.base_model, + config=self.config.base_config, + trust_remote_code=True, + ) + self.scalar_head = layer_init( + nn.Linear(self.config.hidden_size, 1), + std=1 / np.sqrt(self.config.hidden_size + 1), + ) def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - reward = self.scalar_head(output.hidden_states[-1]) - self.reward_bias + reward = self.scalar_head(output.hidden_states[-1]) - self.config.bias return reward +def get_reward(model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) + + # taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 # we did this we can do a single `model = accelerator.prepare(model)` class PolicyAndValueWrapper(nn.Module): @@ -386,10 +300,6 @@ def forward(self, **kwargs): return self.policy(**kwargs), self.critic(**kwargs) -def ceil_div(a, b): - return (a - 1) // b + 1 - - def exact_div(a, b): q = a // b if a != q * b: @@ -432,39 +342,12 @@ def truncate_response(args, tokenizer, responses): return postprocessed_responses -def get_reward(reward_model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - reward_logits = reward_model( - input_ids=input_ids, - attention_mask=attention_mask, - # position_ids=position_ids, - return_dict=True, - output_hidden_states=True, - ) - sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - # sequence_lengths1 = ( - # torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( - # query_responses.device - # ) - # print(f"======={sequence_lengths1=} {sequence_lengths=}") - # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return ( - reward_logits, - reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), - sequence_lengths, - ) - - -def forward(policy, query_responses, tokenizer): +def forward(model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - return 
policy( + return model( input_ids=input_ids, attention_mask=attention_mask, - # position_ids=position_ids, return_dict=True, output_hidden_states=True, ) @@ -473,19 +356,21 @@ def forward(policy, query_responses, tokenizer): # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) - accelerator = Accelerator(gradient_accumulation_steps=args.ppo.gradient_accumulation_steps) - args.ppo.world_size = accelerator.num_processes - args.ppo.batch_size = int(args.ppo.local_batch_size * args.ppo.world_size) - args.ppo.minibatch_size = exact_div(args.ppo.batch_size, args.ppo.nminibatches) - args.ppo.local_mini_batch_size = exact_div(args.ppo.local_batch_size, args.ppo.nminibatches) - args.ppo.local_micro_batch_size = exact_div(args.ppo.local_mini_batch_size, args.ppo.gradient_accumulation_steps) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + local_seed = args.seed + accelerator.process_index * 100003 # Prime + args.world_size = accelerator.num_processes + args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div(args.batch_size, args.nminibatches) + args.local_mini_batch_size = exact_div(args.local_batch_size, args.nminibatches) if args.ppo.whiten_rewards: assert ( - args.ppo.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.ppo.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.ppo.local_batch_size` - # `per_rank_minibatch_size` is our `args.ppo.local_mini_batch_size` - args.ppo.num_updates = args.ppo.total_episodes // args.ppo.batch_size + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.ppo.num_updates = args.total_episodes // args.batch_size tokenizer = AutoTokenizer.from_pretrained( args.base_model, padding_side="right", @@ -521,7 +406,6 @@ def forward(policy, query_responses, tokenizer): ) pprint(args) device = accelerator.device - local_seed = args.seed + accelerator.process_index * 100003 # Prime random.seed(local_seed) np.random.seed(local_seed) torch.manual_seed(local_seed) @@ -529,63 +413,55 @@ def forward(policy, query_responses, tokenizer): model_config = AutoConfig.from_pretrained(args.base_model) configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout - if accelerator.is_main_process: - pprint(model_config) - critic = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained( - args.base_model, - config=model_config, + scalar_model_config = ScalarModelConfig( + base_model=args.base_model, + base_config=model_config, + hidden_size=model_config.hidden_size, + ) + if not args.reward_model_path: + critic: PreTrainedModel = ScalarModel(scalar_model_config) + reward_model: PreTrainedModel = ScalarModel(scalar_model_config) + else: + critic: PreTrainedModel = ScalarModel.from_pretrained( + args.reward_model_path, trust_remote_code=True, ) - ) - reward_model = AutoModelForCausalLMWithRewardHead( - AutoModelForCausalLM.from_pretrained( - args.base_model, - config=model_config, + reward_model: PreTrainedModel = ScalarModel.from_pretrained( + args.reward_model_path, trust_remote_code=True, ) - ) - if 
args.rewards.trained_model: - critic.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) - critic.reward_bias.data = torch.tensor(args.rewards.dataset_mean) - reward_model.load_state_dict(torch.load(args.rewards.trained_model, map_location=device), strict=False) - reward_model.reward_bias.data = torch.tensor(args.rewards.dataset_mean) - print(f"loaded pretrained reward model from {args.rewards.trained_model}") + if accelerator.is_main_process: + pprint(model_config) + pprint(reward_model.config) # each class should have a separate pretrained model that do not share weights - ref_policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) - policy = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True) - policy.gradient_checkpointing_enable() + ref_policy = AutoModelForCausalLM.from_pretrained(args.sft_model_path, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.sft_model_path, config=model_config, trust_remote_code=True) + # critic.lm_backbone.gradient_checkpointing_enable() + # policy.gradient_checkpointing_enable() accelerator.print(policy) - critic.lm_backbone.gradient_checkpointing_enable() accelerator.print(critic) - if args.sft_model_path: - policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - ref_policy.load_state_dict(torch.load(args.sft_model_path, map_location=device)) - print(f"loaded pretrained policy from {args.sft_model_path}") policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to policy.generation_config.pad_token_id = None # generate tokens without truncation / padding model = PolicyAndValueWrapper(policy, critic) - if args.optimizer == "tf_adam": - optimizer = AdamTensorFlowStyle(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) - elif args.optimizer == "adam": - optimizer = optim.Adam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + if args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps) elif args.optimizer == "adamw": - optimizer = optim.AdamW(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) + optimizer = optim.AdamW(model.parameters(), lr=args.lr, eps=args.eps) dataset = load_dataset(args.task.query_dataset, split="train") validation_dataset = load_dataset(args.task.query_dataset, split="validation") dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.ppo.local_batch_size) + dataloader = DataLoader(dataset, batch_size=args.local_batch_size) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.ppo.local_batch_size) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_batch_size) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) validation_dataloader = accelerator.prepare(validation_dataloader) if args.deepspeed: import deepspeed deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size + # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size # 
deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} offload = False @@ -614,25 +490,21 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well while True: yield from dataloader - sample_validation_inds = np.arange(args.ppo.batch_size) + sample_validation_inds = np.arange(args.batch_size) local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] sample_validation = validation_dataset[local_sample_validation_inds] sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) with torch.no_grad(): - # sample_validation_queries = shift_pad_id_left(sample_validation_queries, tokenizer.pad_token_id) sample_validation_reference_response = torch.Tensor(sample_validation["reference_response_token"]).to(device) sample_validation_query_reference_responses = torch.cat( (sample_validation_queries, sample_validation_reference_response), dim=1 ) - # sample_validation_query_reference_responses = shift_pad_id_left( - # sample_validation_query_reference_responses, tokenizer.pad_token_id - # ) _, sample_validation_reference_scores, _ = get_reward( reward_model, sample_validation_query_reference_responses, tokenizer ) iter_dataloader = iter(repeat_generator()) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) + kl_ctl = AdaptiveKLController(args.reward.kl_coef, hparams=args.reward.adaptive_kl) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens generation_config = GenerationConfig( @@ -653,9 +525,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) - print("===training policy===") + accelerator.print("===training policy===") global_step = 0 - stats_shape = (args.ppo.noptepochs, args.ppo.nminibatches, args.ppo.gradient_accumulation_steps) + stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) pg_loss_stats = torch.zeros(stats_shape, device=device) @@ -665,9 +537,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ratio_stats = torch.zeros(stats_shape, device=device) model.train() for update in range(1, args.ppo.num_updates + 1): - global_step += 1 * args.ppo.batch_size + global_step += 1 * args.batch_size frac = 1.0 - (update - 1.0) / args.ppo.num_updates - lrnow = frac * args.ppo.lr + lrnow = frac * args.lr optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) with torch.no_grad(): @@ -790,7 +662,8 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well wandb.log( {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update ) - print_rich_table("stuff", all_sample_validation_df[:4], console) + else: + print_rich_table(f"Sample Output at Step {update}", all_sample_validation_df[:1], console) except Exception as e: print(e) @@ -830,15 +703,15 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # breakpoint() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for ppo_epoch_idx in range(args.ppo.noptepochs): - b_inds = np.random.permutation(args.ppo.local_batch_size) + b_inds = np.random.permutation(args.local_batch_size) minibatch_idx = 0 - for 
mini_batch_start in range(0, args.ppo.local_batch_size, args.ppo.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.ppo.local_mini_batch_size + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.ppo.local_mini_batch_size, args.ppo.local_micro_batch_size): + for micro_batch_start in range(0, args.local_mini_batch_size, args.local_micro_batch_size): with accelerator.accumulate(policy): - micro_batch_end = micro_batch_start + args.ppo.local_micro_batch_size + micro_batch_end = micro_batch_start + args.local_micro_batch_size micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] mb_return = returns[micro_batch_inds] mb_advantage = advantages[micro_batch_inds] @@ -964,20 +837,36 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) writer.add_scalar("ppo/lr", lrnow, update) writer.add_scalar("ppo/episode", global_step, update) - if args.rewards.use_adaptive_kl: - kl_ctl.update(mean_kl.item(), args.ppo.batch_size) + if args.reward.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.batch_size) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores # save model - if args.save_path: - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - accelerator.save_model(policy, args.save_path, max_shard_size="1000GB") - - if args.upload_model and accelerator.is_main_process: - repo_name = f"{args.exp_name}__{args.rewards.label_dataset}__seed{args.seed}__{int(time.time())}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - policy.save_pretrained(repo_id, safe_serialization=True, push_to_hub=True) - tokenizer.save_pretrained(repo_id, push_to_hub=True) + if args.output_dir and args.num_train_epochs > 0: + os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) + time_tensor = torch.tensor([int(time.time())], device=device) + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir, repo_id=repo_id) + if args.push_to_hub: + tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}") + + unwrapped: PreTrainedModel = accelerator.unwrap_model(policy) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(policy), + safe_serialization=False, + repo_id=repo_id, + ) + if args.push_to_hub: + unwrapped.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}", safe_serialization=False) # if __name__ == "__main__": # args = tyro.cli(Args) diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 30e3893..7b13a42 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -109,6 +109,12 @@ class Args: warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" + world_size: Optional[int] = None + 
"""The number of processes (GPUs) to use""" + num_train_epochs: int = 1 + """Number of epochs to train""" + num_updates: Optional[int] = None + """The number of updates to train""" gradient_accumulation_steps: int = 8 """The number of gradient accumulation steps""" local_micro_batch_size: Optional[int] = 1 @@ -123,12 +129,6 @@ class Args: """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" local_eval_batch_size: int = 8 """per rank eval batch size""" - world_size: Optional[int] = None - """The number of processes (GPUs) to use""" - num_train_epochs: int = 1 - """Number of epochs to train""" - num_updates: Optional[int] = None - """The number of updates to train""" # other args base_model: str = "EleutherAI/pythia-160m" @@ -137,14 +137,14 @@ class Args: default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] ) """Which layers to apply dropout to""" - output_dir: str = "models/reward_policy" + output_dir: str = "models/reward_model" """Where to save the model""" label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169" """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" logsigmoid: bool = True """Whether to use log-sigmoid loss instead of cross-entropy loss""" task: TaskHParams = field(default_factory=TaskHParams) - labels: LabelHParams = field(default_factory=LabelHParams) + label: LabelHParams = field(default_factory=LabelHParams) # taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 @@ -173,11 +173,19 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class ScalarModelConfig(PretrainedConfig): - model_type = "scalar_model" - - def __init__(self, base_model: str = "gpt2", **kwargs): + def __init__( + self, + base_model: str = "EleutherAI/pythia-160m", + base_config: PretrainedConfig = AutoConfig.from_pretrained("EleutherAI/pythia-160m"), + hidden_size: int = 768, + bias: float = 0.0, + **kwargs, + ): super().__init__(**kwargs) self.base_model = base_model + self.base_config = base_config + self.hidden_size = hidden_size + self.bias = bias class ScalarModel(PreTrainedModel): @@ -186,23 +194,19 @@ class ScalarModel(PreTrainedModel): def __init__(self, config: ScalarModelConfig): super().__init__(config) self.config = config - self.model_config = AutoConfig.from_pretrained( - config.base_model, - trust_remote_code=True, - ) self.lm_backbone = AutoModel.from_pretrained( config.base_model, - config=self.model_config, + config=self.config.base_config, trust_remote_code=True, ) self.scalar_head = layer_init( - nn.Linear(self.model_config.hidden_size, 1), - std=1 / np.sqrt(self.model_config.hidden_size + 1), + nn.Linear(self.config.hidden_size, 1), + std=1 / np.sqrt(self.config.hidden_size + 1), ) def forward(self, **kwargs): output = self.lm_backbone(**kwargs) - reward = self.scalar_head(output.hidden_states[-1]) + reward = self.scalar_head(output.hidden_states[-1]) - self.config.bias return reward @@ -220,7 +224,7 @@ def get_reward(model, query_responses, tokenizer): return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths] -def evaluate(args, accelerator, tokenizer, model, dataloader): +def evaluate(args: Args, accelerator, tokenizer, model, dataloader): model.eval() with torch.no_grad(): items = defaultdict(list) @@ 
-228,11 +232,11 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): mb_query = data["query_token"] mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) mb_best = data["choice"] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.label.num_labels, 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) # query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) predicted_reward = get_reward(model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, args.labels.num_labels) + predicted_reward = predicted_reward.view(-1, args.label.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float() for k in data: @@ -266,7 +270,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): # load dataset dataset = load_dataset(args.label_dataset, "comparisons", split="train") dataset = dataset.shuffle(seed=local_seed) - dataset = dataset.select(range(args.labels.num_train)) + dataset = dataset.select(range(args.label.num_train)) dataset = dataset.with_format( "torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"] ) @@ -328,11 +332,16 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): ) # we use the padding token manually but do not resize the token embedding of the model tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - # model_config = AutoConfig.from_pretrained(args.base_model) - # configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout - # if accelerator.is_main_process: - # pprint(model_config) - model: PreTrainedModel = ScalarModel(ScalarModelConfig(base_model=args.base_model)) + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + scalar_model_config = ScalarModelConfig( + base_model=args.base_model, + base_config=model_config, + hidden_size=model_config.hidden_size, + ) + model: PreTrainedModel = ScalarModel(scalar_model_config) + if accelerator.is_main_process: + pprint(model_config) if args.optimizer == "adam": optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps) elif args.optimizer == "adamw": @@ -348,10 +357,11 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): deepspeed_states = AcceleratorState().deepspeed_plugin deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + # scheduler = accelerator.prepare(scheduler) # breaks with accelerate@0.25.0 validation_dataloader = accelerator.prepare(validation_dataloader) - accelerator.print("===training reward model===") + accelerator.print("===training model===") losses = torch.zeros((args.gradient_accumulation_steps,), device=device) accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) @@ -368,11 +378,11 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): mb_query = data["query_token"] mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) mb_best = data["choice"] - mb_query_tiled = 
mb_query.unsqueeze(1).repeat(1, args.labels.num_labels, 1) + mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.label.num_labels, 1) query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) with accelerator.accumulate(model): predicted_reward = get_reward(model, query_responses, tokenizer) - predicted_reward = predicted_reward.view(-1, args.labels.num_labels) + predicted_reward = predicted_reward.view(-1, args.label.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float().mean() reward_preferred = predicted_reward.gather(1, mb_best.view(-1, 1)).view(-1) reward_rejected = predicted_reward.gather(1, (1 - mb_best).view(-1, 1)).view(-1) @@ -389,30 +399,30 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): reward_preferreds[gradient_accumulation_idx] = reward_preferred.mean() reward_rejecteds[gradient_accumulation_idx] = reward_rejected.mean() gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps - break if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: train_accuracy = accelerator.gather(accuracies).mean().item() - writer.add_scalar("train/loss", accelerator.gather(losses).mean().item(), global_step) - writer.add_scalar("train/accuracy", train_accuracy, global_step) - writer.add_scalar("train/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step) - writer.add_scalar("train/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) - writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step) + writer.add_scalar("train/rm/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/rm/accuracy", train_accuracy, global_step) + writer.add_scalar( + "train/rm/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step + ) + writer.add_scalar("train/rm/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) + writer.add_scalar("train/rm/lr", scheduler.get_last_lr()[0], global_step) accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {update=}") - # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0: if args.run_eval: evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): - writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step) - accelerator.print(f"{split} accuracy: {row['accuracy']}") + writer.add_scalar(f"eval/rm/accuracy/split/{split}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/accuracy/split/{split}: {row['accuracy']}") for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): - writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step) - accelerator.print(f"{batch} accuracy: {row['accuracy']}") + writer.add_scalar(f"eval/rm/accuracy/batch/{batch}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/accuracy/batch/{batch}: {row['accuracy']}") for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): - writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step) - accelerator.print(f"{confi} confidence: {row['accuracy']}") - writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step) - accelerator.print(f"eval accuracy: {evaluate_df['accuracy'].mean()}") + 
writer.add_scalar(f"eval/rm/accuracy/confidence/{confi}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/accuracy/confidence/{confi}: {row['accuracy']}") + writer.add_scalar("eval/rm/accuracy", evaluate_df["accuracy"].mean(), global_step) + accelerator.print(f"eval/rm/accuracy: {evaluate_df['accuracy'].mean()}") if accelerator.is_main_process: os.makedirs(f"eval_tables/{run_name}", exist_ok=True) evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") @@ -454,7 +464,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): "min": norm_df["predicted_reward"].min(), } for stat_name, stat_value in stats.items(): - writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step) + writer.add_scalar(f"eval/rm/normalized_{stat_name}", stat_value, global_step) accelerator.print(f"Normalized Reward {stat_name.capitalize()}: {stat_value}") # save model @@ -473,6 +483,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): unwrapped: PreTrainedModel = accelerator.unwrap_model(model) accelerator.wait_for_everyone() if accelerator.is_main_process: + unwrapped.config.bias = norm_df["predicted_reward"].mean() unwrapped.save_pretrained( args.output_dir, is_main_process=accelerator.is_main_process, diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index 8f84bc1..f0c49aa 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -6,7 +6,7 @@ from types import SimpleNamespace from typing import List, Literal, Optional -import evaluate +import evaluate as hf_evaluate import numpy as np import pandas as pd import torch @@ -31,6 +31,8 @@ get_scheduler, ) +rouge = hf_evaluate.load("rouge") + @dataclass class TaskHParams: @@ -98,6 +100,12 @@ class Args: warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_train_epochs: int = 1 + """Number of epochs to train""" + num_updates: Optional[int] = None + """The number of updates to train""" gradient_accumulation_steps: int = 16 """The number of gradient accumulation steps""" local_micro_batch_size: Optional[int] = 1 @@ -112,12 +120,7 @@ class Args: """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" local_eval_batch_size: int = 4 """per rank eval batch size""" - world_size: Optional[int] = None - """The number of processes (GPUs) to use""" - num_train_epochs: int = 1 - """Number of epochs to train""" - num_updates: Optional[int] = None - """The number of updates to train""" + # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" @@ -166,16 +169,80 @@ def generate(lm_backbone, queries, tokenizer, generation_config): def forward(model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id - # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return model( input_ids=input_ids, attention_mask=attention_mask, - # position_ids=position_ids, return_dict=True, ) +def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_config): + model.eval() + rouge_scores = collections.defaultdict(list) + all_decode_queries = [] + all_decode_query_responses = [] + all_decode_responses = [] + all_decode_reference_responses = [] + all_losses = [] + for _, data in 
tqdm(enumerate(dataloader)): + with torch.no_grad(): + reference_responses = data["reference_response_token"] + queries = data["query_token"] + query_reference_responses = torch.cat((queries, reference_responses), dim=1) + output = forward(model, query_reference_responses, tokenizer) + labels = query_reference_responses.masked_fill(query_reference_responses == tokenizer.pad_token_id, -1) + lm_logits = output.logits + # hand-rolled transformer loss: Shift so that tokens < n predict n + # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ignore_index=-1, + ) + loss = accelerator.gather(loss) + all_losses.append(loss) + + generated_responses = generate( + accelerator.unwrap_model(model), + queries, + tokenizer, + generation_config, + ) + decode_queries = tokenizer.batch_decode(accelerator.gather(queries)) + decode_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) + decode_reference_responses = tokenizer.batch_decode( + accelerator.gather(reference_responses), + skip_special_tokens=True, + ) + decode_responses = tokenizer.batch_decode( + accelerator.gather(generated_responses[:, -args.task.response_length :]), + skip_special_tokens=True, + ) + rouge_score = rouge.compute(predictions=decode_responses, references=decode_reference_responses) + rouge_scores["rouge1"].append(rouge_score["rouge1"]) + rouge_scores["rouge2"].append(rouge_score["rouge2"]) + rouge_scores["rougeL"].append(rouge_score["rougeL"]) + + all_decode_queries.extend(decode_queries) + all_decode_query_responses.extend(decode_query_responses) + all_decode_responses.extend(decode_responses) + all_decode_reference_responses.extend(decode_reference_responses) + return ( + pd.DataFrame( + { + "query": all_decode_queries, + "response": all_decode_responses, + "reference": all_decode_reference_responses, + } + ), + rouge_scores, + all_losses, + ) + + # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) @@ -256,7 +323,8 @@ def forward(model, query_responses, tokenizer): num_training_steps=args.num_updates * args.num_train_epochs, ) - model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + # scheduler = accelerator.prepare(scheduler) # breaks with accelerate@0.25.0 validation_dataloader = accelerator.prepare(validation_dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
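For reference, the hand-rolled next-token loss used in the `evaluate` helper above can be reduced to a small standalone function. The sketch below is illustrative only: it assumes a Hugging Face causal-LM logits tensor and right-padded `input_ids`, and it keeps this patch's convention of masking padding with `ignore_index=-1` rather than the more common `-100`.

import torch
import torch.nn.functional as F

def shifted_lm_loss(lm_logits: torch.Tensor, input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Tokens at positions < n predict token n: drop the last logit and the first
    # label, then ignore padding positions via ignore_index=-1.
    labels = input_ids.masked_fill(input_ids == pad_token_id, -1)
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-1,
    )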
TODO: investigate further, we just want to generate a fixed number of tokens @@ -268,7 +336,6 @@ def forward(model, query_responses, tokenizer): top_p=1.0, do_sample=True, ) - rouge = evaluate.load("rouge") accelerator.print("===training model===") loss_stats = torch.zeros(args.gradient_accumulation_steps, device=device) @@ -301,83 +368,21 @@ def forward(model, query_responses, tokenizer): loss_stats[gradient_accumulation_idx] = loss gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: - writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) - writer.add_scalar("lr", scheduler.get_last_lr()[0], update) + writer.add_scalar("train/sft/loss", accelerator.gather(loss_stats).mean().item(), update) + writer.add_scalar("train/sft/lr", scheduler.get_last_lr()[0], update) accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") if args.run_eval: - model.eval() - rouge_scores = collections.defaultdict(list) - all_decode_validation_queries = [] - all_decode_validation_query_responses = [] - all_decode_validation_responses = [] - all_decode_validation_reference_responses = [] - all_validation_losses = [] - for validation_idx, validation_data in tqdm(enumerate(validation_dataloader)): - with torch.no_grad(): - validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) - validation_queries = validation_data["query_token"].to(device, non_blocking=True) - validation_query_reference_responses = torch.cat((validation_queries, validation_reference_responses), dim=1) - - validation_output = forward(model, validation_query_reference_responses, tokenizer) - validation_labels = validation_query_reference_responses.masked_fill( - validation_query_reference_responses == tokenizer.pad_token_id, -1 - ) - validation_lm_logits = validation_output.logits - # hand-rolled transformer loss: Shift so that tokens < n predict n - # but unlike `transformers` we mask the padding tokens via `ignore_index=-1` - validation_shift_logits = validation_lm_logits[..., :-1, :].contiguous() - validation_shift_labels = validation_labels[..., 1:].contiguous() - validation_loss = F.cross_entropy( - validation_shift_logits.view(-1, validation_shift_logits.size(-1)), - validation_shift_labels.view(-1), - ignore_index=-1, - ) - validation_loss = accelerator.gather(validation_loss) - all_validation_losses.append(validation_loss) - - generated_responses = generate( - accelerator.unwrap_model(model), - validation_queries, - tokenizer, - generation_config, - ) - decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries)) - decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) - decode_validation_reference_responses = tokenizer.batch_decode( - accelerator.gather(validation_reference_responses), - skip_special_tokens=True, - ) - decode_validation_responses = tokenizer.batch_decode( - accelerator.gather(generated_responses[:, -args.task.response_length :]), - skip_special_tokens=True, - ) - rouge_score = rouge.compute( - predictions=decode_validation_responses, references=decode_validation_reference_responses - ) - rouge_scores["rouge1"].append(rouge_score["rouge1"]) - rouge_scores["rouge2"].append(rouge_score["rouge2"]) - rouge_scores["rougeL"].append(rouge_score["rougeL"]) - - all_decode_validation_queries.extend(decode_validation_queries) - 
all_decode_validation_query_responses.extend(decode_validation_query_responses) - all_decode_validation_responses.extend(decode_validation_responses) - all_decode_validation_reference_responses.extend(decode_validation_reference_responses) + evaluate_df, rouge_scores, all_validation_losses = evaluate( + args, accelerator, tokenizer, model, dataloader, generation_config + ) + if accelerator.is_main_process and args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) try: - all_df = pd.DataFrame( - { - "query": all_decode_validation_queries, - "response": all_decode_validation_responses, - "reference": all_decode_validation_reference_responses, - } - ) - accelerator.print(all_df) - if accelerator.is_main_process and args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_df)}, step=update) - print_rich_table(f"Sample Output at Step {update}", all_df[:4], console) + if accelerator.is_main_process: + print_rich_table(f"Sample Output at Step {update}", evaluate_df[:4], console) except Exception as e: print(e) - for k, v in rouge_scores.items(): rouge_metric = torch.tensor(v, device=device) rouge_metric = accelerator.gather(rouge_metric) @@ -386,7 +391,7 @@ def forward(model, query_responses, tokenizer): writer.add_scalar("validation_loss", torch.stack(all_validation_losses).mean().item(), update) # save model - if args.output_dir: + if args.output_dir and args.num_train_epochs > 0: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) time_tensor = torch.tensor([int(time.time())], device=device) time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes From 22aa0d182400f2bd20e8c3c979a2e9b55bfa9f49 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 21 Dec 2023 21:34:55 +0000 Subject: [PATCH 45/62] change lr scheduler stuff --- lm_human_preference_details/summarize/ppo.py | 39 ++---------- .../summarize/reward.py | 5 +- lm_human_preference_details/summarize/sft.py | 39 +++++++----- train_ppo.sbatch | 37 +++++++++++ train_pythia.sbatch | 62 +++++++++++++++++++ train_reward.sbatch | 35 +++++++++++ train_sft.sbatch | 35 +++++++++++ 7 files changed, 198 insertions(+), 54 deletions(-) create mode 100644 train_ppo.sbatch create mode 100644 train_pythia.sbatch create mode 100644 train_reward.sbatch create mode 100644 train_sft.sbatch diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index 341e940..8236c72 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -44,7 +44,6 @@ class AdaptiveKLParams: class RewardHParams: use_adaptive_kl: bool = False adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) - dataset_mean: float = 0.0 dataset_std: float = 1.0 kl_coef: float = 0.05 @@ -590,22 +589,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # 2. 
run reward model on the truncated responses postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - # sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 - # actual_start = torch.arange(postprocessed_responses.size(0), device=postprocessed_responses.device) - # actual_end = sequence_lengths - # padding_mask = postprocessed_responses == tokenizer.pad_token_id sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 - full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) - # values_mask = postprocessed_responses != args.task.truncate_token_id - # values = torch.masked_fill(values, values_mask, 0) - # values = torch.masked_fill(values, padding_mask, 0) - - # logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - # ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) - _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) # 3. filter response. Ensure that the sample contains truncate_token_id @@ -614,16 +601,11 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - # TODO: do we need to deal with penalty values? - # penalty_values = torch.full_like(values, 0) - # penalty_values[:,-1] += args.task.penalty_reward_value - # values = torch.where(contain_pad_token, values, penalty_values) accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") # torch.cuda.empty_cache() # 4. compute rewards kl = logprobs - ref_logprobs - # kl = torch.masked_fill(kl, padding_mask, 0) non_score_reward = -kl_ctl.value * kl rewards = non_score_reward.clone() actual_start = torch.arange(rewards.size(0), device=rewards.device) @@ -673,8 +655,8 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well all_sample_validation_reference_responses, all_sample_validation_df, ) - # del postprocessed_query_responses - # torch.cuda.empty_cache() + del postprocessed_query_responses + torch.cuda.empty_cache() # 6. 
compute advantages and returns lastgaelam = 0 @@ -694,13 +676,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well writer.add_histogram("advantages", advantages[0].float(), global_step) accelerator.print("rewards====", rewards[0]) accelerator.print("advantages====", advantages[0]) - # raise - # pprint({ - # "rewards": rewards, - # "returns": returns, - # "advantages": advantages, - # }) - # breakpoint() + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for ppo_epoch_idx in range(args.ppo.noptepochs): b_inds = np.random.permutation(args.local_batch_size) @@ -725,10 +701,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - # vpred = torch.masked_fill(vpred, padding_mask[micro_batch_inds], 0) - # vpred = torch.masked_fill(vpred, values_mask[micro_batch_inds], 0) vpredclipped = torch.clamp( vpred, mb_values - args.ppo.cliprange_value, @@ -794,8 +767,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well "ratio", ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) - # raise - # breakpoint() with torch.no_grad(): if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` writer.add_histogram("ppo/val/ratio_hist", ratio, update) @@ -854,14 +825,14 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.push_to_hub: tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}") - unwrapped: PreTrainedModel = accelerator.unwrap_model(policy) + unwrapped: PreTrainedModel = accelerator.unwrap_model(model).policy accelerator.wait_for_everyone() if accelerator.is_main_process: unwrapped.save_pretrained( args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, - state_dict=accelerator.get_state_dict(policy), + state_dict=accelerator.get_state_dict(unwrapped), safe_serialization=False, repo_id=repo_id, ) diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 7b13a42..215dd35 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -295,7 +295,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): accelerator.print("The number of samples in dataset", len(dataset)) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) args.total_episodes = len(dataset) - args.num_updates = args.total_episodes // args.local_batch_size + args.num_updates = args.total_episodes // args.batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -358,7 +358,6 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - # scheduler = accelerator.prepare(scheduler) # breaks with accelerate@0.25.0 validation_dataloader = accelerator.prepare(validation_dataloader) accelerator.print("===training model===") @@ -393,13 +392,13 @@ def evaluate(args: Args, accelerator, 
tokenizer, model, dataloader): accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - scheduler.step() losses[gradient_accumulation_idx] = loss accuracies[gradient_accumulation_idx] = accuracy reward_preferreds[gradient_accumulation_idx] = reward_preferred.mean() reward_rejecteds[gradient_accumulation_idx] = reward_rejected.mean() gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: + scheduler.step() train_accuracy = accelerator.gather(accuracies).mean().item() writer.add_scalar("train/rm/loss", accelerator.gather(losses).mean().item(), global_step) writer.add_scalar("train/rm/accuracy", train_accuracy, global_step) diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index f0c49aa..fba21fd 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -13,6 +13,8 @@ import torch.optim as optim import tyro from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import gather_object from datasets import load_dataset from rich.console import Console from rich.pretty import pprint @@ -181,10 +183,10 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c model.eval() rouge_scores = collections.defaultdict(list) all_decode_queries = [] - all_decode_query_responses = [] all_decode_responses = [] all_decode_reference_responses = [] all_losses = [] + unwrapped = accelerator.unwrap_model(model) for _, data in tqdm(enumerate(dataloader)): with torch.no_grad(): reference_responses = data["reference_response_token"] @@ -206,28 +208,28 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c all_losses.append(loss) generated_responses = generate( - accelerator.unwrap_model(model), + unwrapped, queries, tokenizer, generation_config, ) - decode_queries = tokenizer.batch_decode(accelerator.gather(queries)) - decode_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses)) + decode_queries = tokenizer.batch_decode(queries) decode_reference_responses = tokenizer.batch_decode( - accelerator.gather(reference_responses), + reference_responses, skip_special_tokens=True, ) decode_responses = tokenizer.batch_decode( - accelerator.gather(generated_responses[:, -args.task.response_length :]), + generated_responses[:, -args.task.response_length :], skip_special_tokens=True, ) rouge_score = rouge.compute(predictions=decode_responses, references=decode_reference_responses) - rouge_scores["rouge1"].append(rouge_score["rouge1"]) - rouge_scores["rouge2"].append(rouge_score["rouge2"]) - rouge_scores["rougeL"].append(rouge_score["rougeL"]) - + decode_queries = gather_object(decode_queries) + decode_responses = gather_object(decode_responses) + decode_reference_responses = gather_object(decode_reference_responses) + rouge_scores["rouge1"].append(np.mean(gather_object([rouge_score["rouge1"]]))) + rouge_scores["rouge2"].append(np.mean(gather_object([rouge_score["rouge2"]]))) + rouge_scores["rougeL"].append(np.mean(gather_object([rouge_score["rougeL"]]))) all_decode_queries.extend(decode_queries) - all_decode_query_responses.extend(decode_query_responses) all_decode_responses.extend(decode_responses) all_decode_reference_responses.extend(decode_reference_responses) return ( @@ -264,7 +266,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c 
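The evaluation refactor above decodes each rank's generations locally and then gathers plain Python objects across processes instead of gathering padded token tensors. A minimal sketch of that pattern follows; `rouge` is assumed to be the ROUGE metric loaded via the `evaluate` library at module import, and the helper name is hypothetical.

import numpy as np
from accelerate.utils import gather_object

def gather_rouge(rouge, decode_responses, decode_references):
    # Score only the local slice, then object-gather the decoded strings and the
    # per-rank metric values so every process sees the full table and the mean.
    local_scores = rouge.compute(predictions=decode_responses, references=decode_references)
    all_responses = gather_object(decode_responses)
    all_references = gather_object(decode_references)
    mean_scores = {k: float(np.mean(gather_object([v]))) for k, v in local_scores.items()}
    return all_responses, all_references, mean_scores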
accelerator.print("The number of samples in dataset", len(dataset)) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) args.total_episodes = len(dataset) - args.num_updates = args.total_episodes // args.local_batch_size + args.num_updates = args.total_episodes // args.batch_size console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -303,8 +305,6 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c tokenizer.add_special_tokens({"pad_token": "[PAD]"}) model_config = AutoConfig.from_pretrained(args.base_model) configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout - if accelerator.is_main_process: - pprint(model_config) model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( args.base_model, config=model_config, @@ -312,6 +312,8 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c ) model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to model.generation_config.pad_token_id = None # generate tokens without truncation / padding + if accelerator.is_main_process: + pprint(model_config) if args.optimizer == "adam": optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps) elif args.optimizer == "adamw": @@ -323,8 +325,11 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c num_training_steps=args.num_updates * args.num_train_epochs, ) + if args.deepspeed: + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - # scheduler = accelerator.prepare(scheduler) # breaks with accelerate@0.25.0 validation_dataloader = accelerator.prepare(validation_dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens @@ -364,17 +369,17 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - scheduler.step() loss_stats[gradient_accumulation_idx] = loss gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: + scheduler.step() writer.add_scalar("train/sft/loss", accelerator.gather(loss_stats).mean().item(), update) writer.add_scalar("train/sft/lr", scheduler.get_last_lr()[0], update) accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") if args.run_eval: evaluate_df, rouge_scores, all_validation_losses = evaluate( - args, accelerator, tokenizer, model, dataloader, generation_config + args, accelerator, tokenizer, model, validation_dataloader, generation_config ) if accelerator.is_main_process and args.track: wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) diff --git a/train_ppo.sbatch b/train_ppo.sbatch new file mode 100644 index 0000000..4df08ad --- /dev/null +++ b/train_ppo.sbatch @@ -0,0 +1,37 @@ +#!/bin/bash +#SBATCH --partition=hopper-prod +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-gpu=10 +#SBATCH --ntasks=1 +#SBATCH --output=slurm/logs/%x_%j.out +#SBATCH --exclusive + +module load cuda/12.2 + +if [ -z "$SEED" ]; then + SEED=1 +fi +if [ -z "$MODEL" ]; then + MODEL=EleutherAI/pythia-2.8b-deduped +fi +if [ -z "$LR" ]; then + LR=3e-6 +fi + +REWARD_MODEL_PATH=models/$MODEL/reward_model_$SEED +SFT_MODEL_PATH=models/$MODEL/sft_model_$SEED +POLICY_MODEL_PATH=models/$MODEL/policy_model_$SEED +srun poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/summarize/ppo.py \ + --task.query_dataset=vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53 \ + --base_model=$MODEL \ + --sft_model_path=$SFT_MODEL_PATH \ + --reward_model_path=$REWARD_MODEL_PATH \ + --reward.dataset_mean=0.5226626396179199 \ + --lr=$LR \ + --deepspeed \ + --push_to_hub \ + --track \ + --output_dir=$POLICY_MODEL_PATH \ + --seed=$SEED + \ No newline at end of file diff --git a/train_pythia.sbatch b/train_pythia.sbatch new file mode 100644 index 0000000..86a8f55 --- /dev/null +++ b/train_pythia.sbatch @@ -0,0 +1,62 @@ +#!/bin/bash +#SBATCH --partition=hopper-prod +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-gpu=10 +#SBATCH --ntasks=1 +#SBATCH --output=slurm/logs/%x_%j.out +#SBATCH --exclusive + +module load cuda/12.2 + +if [ -z "$SEED" ]; then + SEED=3 +fi +if [ -z "$MODEL" ]; then + MODEL=EleutherAI/pythia-2.8b-deduped + # MODEL=EleutherAI/pythia-1b-deduped + # MODEL=EleutherAI/pythia-160m +fi +if [ -z "$LR" ]; then + LR=3e-6 +fi + +REWARD_MODEL_PATH=models/$MODEL/reward_model_$SEED +SFT_MODEL_PATH=models/$MODEL/sft_model_$SEED +POLICY_MODEL_PATH=models/$MODEL/policy_model_$SEED +srun poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/summarize/sft.py \ + --task.query_dataset=vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53 \ + --base_model=$MODEL \ + --lr=$LR \ + --deepspeed \ + --run_eval \ + --push_to_hub \ + --track \ + --output_dir=$SFT_MODEL_PATH \ + --seed=$SEED + + srun poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/summarize/reward.py \ + 
--label_dataset=vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169 \ + --base_model=$SFT_MODEL_PATH \ + --lr=$LR \ + --deepspeed \ + --run_eval \ + --push_to_hub \ + --track \ + --output_dir=$REWARD_MODEL_PATH \ + --seed=$SEED + + srun poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/summarize/ppo.py \ + --task.query_dataset=vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53 \ + --base_model=$MODEL \ + --sft_model_path=$SFT_MODEL_PATH \ + --reward_model_path=$REWARD_MODEL_PATH \ + --lr=$LR \ + --deepspeed \ + --push_to_hub \ + --track \ + --output_dir=$POLICY_MODEL_PATH \ + --seed=$SEED + \ No newline at end of file diff --git a/train_reward.sbatch b/train_reward.sbatch new file mode 100644 index 0000000..e488f38 --- /dev/null +++ b/train_reward.sbatch @@ -0,0 +1,35 @@ +#!/bin/bash +#SBATCH --partition=hopper-prod +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-gpu=10 +#SBATCH --ntasks=1 +#SBATCH --output=slurm/logs/%x_%j.out +#SBATCH --exclusive + +module load cuda/12.2 + +if [ -z "$SEED" ]; then + SEED=1 +fi +if [ -z "$MODEL" ]; then + MODEL=EleutherAI/pythia-2.8b-deduped +fi +if [ -z "$LR" ]; then + LR=3e-6 +fi + +REWARD_MODEL_PATH=models/$MODEL/reward_model_$SEED +SFT_MODEL_PATH=models/$MODEL/sft_model_$SEED +POLICY_MODEL_PATH=models/$MODEL/policy_model_$SEED +srun poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/summarize/reward.py \ + --label_dataset=vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169 \ + --base_model=$SFT_MODEL_PATH \ + --lr=$LR \ + --deepspeed \ + --run_eval \ + --push_to_hub \ + --track \ + --output_dir=$REWARD_MODEL_PATH \ + --seed=$SEED + \ No newline at end of file diff --git a/train_sft.sbatch b/train_sft.sbatch new file mode 100644 index 0000000..4b59fdc --- /dev/null +++ b/train_sft.sbatch @@ -0,0 +1,35 @@ +#!/bin/bash +#SBATCH --partition=hopper-prod +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-gpu=10 +#SBATCH --ntasks=1 +#SBATCH --output=slurm/logs/%x_%j.out +#SBATCH --exclusive + +module load cuda/12.2 + +if [ -z "$SEED" ]; then + SEED=1 +fi +if [ -z "$MODEL" ]; then + MODEL=EleutherAI/pythia-2.8b-deduped +fi +if [ -z "$LR" ]; then + LR=3e-6 +fi + +REWARD_MODEL_PATH=models/$MODEL/reward_model_$SEED +SFT_MODEL_PATH=models/$MODEL/sft_model_$SEED +POLICY_MODEL_PATH=models/$MODEL/policy_model_$SEED +srun poetry run accelerate launch --config_file deepspeed.yaml \ + lm_human_preference_details/summarize/sft.py \ + --task.query_dataset=vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53 \ + --base_model=$MODEL \ + --lr=$LR \ + --deepspeed \ + --run_eval \ + --push_to_hub \ + --track \ + --output_dir=$SFT_MODEL_PATH \ + --seed=$SEED + \ No newline at end of file From 0d4ddfacec17dc55780218ac8ab718c479f30bff Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 26 Dec 2023 05:13:49 +0000 Subject: [PATCH 46/62] support offload / 6.9b model --- lm_human_preference_details/summarize/ppo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index 8236c72..69685eb 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -170,6 +170,8 @@ class Args: # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" + offload: bool = False + """Whether to offload ref policy and reward model to CPU""" 
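The new `offload` flag above is meant to let the frozen reference policy and reward model live in a ZeRO-3 DeepSpeed engine that parks parameters on CPU. The sketch below shows roughly how such an inference-only engine could be built: the base keys mirror the `eval_ds_config` constructed a little further down, while the `offload_param` entry and the helper name are assumptions rather than code taken from this patch.

import deepspeed

def prepare_frozen_model(model, micro_batch_size: int, offload: bool):
    # Inference-only engine: bf16, no optimizer, optional ZeRO-3 CPU parameter offload.
    ds_config = {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "bf16": {"enabled": True},
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
    }
    if offload:
        ds_config["zero_optimization"] = {
            "stage": 3,
            "stage3_param_persistence_threshold": 1e4,
            "offload_param": {"device": "cpu"},
        }
    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()
    return engine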
reward_model_path: str = "" """the name of the pretrained model to use""" sft_model_path: str = "EleutherAI/pythia-160m" @@ -460,17 +462,16 @@ def forward(model, query_responses, tokenizer): import deepspeed deepspeed_states = AcceleratorState().deepspeed_plugin - # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size - offload = False eval_ds_config = { "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], "bf16": {"enabled": True}, "prescale_gradients": False, "wall_clock_breakdown": False, } - if offload: + if args.offload: + deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, From 0472f4de941ba795862d95218877cc75d56bcbb1 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 26 Dec 2023 05:31:08 +0000 Subject: [PATCH 47/62] sft / reward without padding --- .../summarize/reward.py | 13 ++- lm_human_preference_details/summarize/sft.py | 10 +- lm_human_preference_details/tldr_dataset.py | 108 +++++++++++++++--- 3 files changed, 105 insertions(+), 26 deletions(-) diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 215dd35..8fcba8f 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -272,7 +272,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): dataset = dataset.shuffle(seed=local_seed) dataset = dataset.select(range(args.label.num_train)) dataset = dataset.with_format( - "torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"] + "torch", columns=["query_token", "choice", "response0_token", "query_response0_token", "response1_token", "query_response1_token", "batch", "split"] ) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() @@ -374,11 +374,12 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): for data in dataloader: update += 1 global_step += args.micro_batch_size - mb_query = data["query_token"] - mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) + query_responses = torch.cat([data['query_response0_token'].unsqueeze(1), data['query_response1_token'].unsqueeze(1)], dim=1).flatten(0, 1) mb_best = data["choice"] - mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.label.num_labels, 1) - query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) + # mb_query = data["query_token"] + # mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) + # mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.label.num_labels, 1) + # query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) with accelerator.accumulate(model): predicted_reward = get_reward(model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.label.num_labels) @@ -407,7 +408,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): ) writer.add_scalar("train/rm/reward_rejected", 
accelerator.gather(reward_rejecteds).mean().item(), global_step) writer.add_scalar("train/rm/lr", scheduler.get_last_lr()[0], global_step) - accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {update=}") + accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {optimizer.param_groups[0]['lr']=}, {update=}") if args.run_eval: evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index fba21fd..420955a 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -258,7 +258,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c # load dataset dataset = load_dataset(args.task.query_dataset, split="train") dataset = dataset.shuffle(seed=local_seed) - dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataset = dataset.with_format("torch", columns=["query_reference_response_token"]) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) validation_dataset = load_dataset(args.task.query_dataset, split="validation") validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) @@ -353,9 +353,9 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c for data in dataloader: update += 1 global_step += args.micro_batch_size - reference_responses = data["reference_response_token"].to(device, non_blocking=True) - queries = data["query_token"].to(device, non_blocking=True) - query_responses = torch.cat((queries, reference_responses), dim=1) + # reference_responses = data["reference_response_token"].to(device, non_blocking=True) + # queries = data["query_token"].to(device, non_blocking=True) + query_responses = data["query_reference_response_token"] with accelerator.accumulate(model): output = forward(model, query_responses, tokenizer) # mask out gradient effects on response padding tokens @@ -375,7 +375,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c scheduler.step() writer.add_scalar("train/sft/loss", accelerator.gather(loss_stats).mean().item(), update) writer.add_scalar("train/sft/lr", scheduler.get_last_lr()[0], update) - accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") + accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {optimizer.param_groups[0]['lr']=}, {update=}") if args.run_eval: evaluate_df, rouge_scores, all_validation_losses = evaluate( diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index d945428..6d406da 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -10,6 +10,9 @@ from huggingface_hub import HfApi from rich.pretty import pprint from transformers import AutoTokenizer +from huggingface_hub.repocard import RepoCard +from pprint import pformat + api = HfApi() @@ -17,9 +20,17 @@ """ poetry run python lm_human_preference_details/tldr_dataset.py poetry run python lm_human_preference_details/tldr_dataset.py \ - --base-model=EleutherAI/pythia-160m \ - --max-sft-response-length=53 \ - --max-rm-response-length=169 + --base_model=EleutherAI/pythia-160m \ + --max_sft_response_length=53 \ + --max_sft_query_response_length=562 \ + --max-rm-response-length=169 \ + --max_rm_query_response_length=638 +poetry run python 
lm_human_preference_details/tldr_dataset.py \ + --base_model=EleutherAI/pythia-160m \ + --max_sft_response_length=48 \ + --max_sft_query_response_length=560 \ + --max-rm-response-length=48 \ + --max_rm_query_response_length=560 """ @@ -27,7 +38,9 @@ class Args: base_model: str = "gpt2" # EleutherAI/pythia-160m max_sft_response_length: int = 48 # 53 + max_sft_query_response_length: int = 512 + 48 # 565 max_rm_response_length: int = 153 # 169 + max_rm_query_response_length: int = 512 + 153 # 665 hf_entity: str = None @@ -128,7 +141,7 @@ def process_query_data(x): # DOES NOT HAVE a leading space so we are adding the leading space and # `<|endoftext|>` token reference_response = f" {x['summary']}<|endoftext|>" - return { + y = { **process_query(x, encoder=tokenizer, hparams=oai_h), "reference_response": reference_response, "reference_response_token": tokenizer.encode( @@ -139,11 +152,56 @@ def process_query_data(x): ), "reference_response_token_len": len(tokenizer.encode(reference_response)), } + y["query_reference_response"] = y["query"].strip() + y["reference_response"] + y["query_reference_response_token"] = tokenizer.encode( + y["query_reference_response"], + padding="max_length", + max_length=args.max_sft_query_response_length, + truncation=True, + ) + y["query_reference_response_token_len"] = len(tokenizer.encode(y["query_reference_response"])) + return y sft_ds = sft_ds.map(process_query_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) sft_ds.push_to_hub( f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}" ) + sft_card = RepoCard.load(f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}", repo_type="dataset") + sft_card.text = f"""\ +# TL;DR SFT Dataset for OpenAI's [Summarize from Feedback](https://openai.com/blog/summarization/) task + +The dataset is directly taken from https://github.com/openai/summarize-from-feedback/tree/700967448d10004279f138666442bf1497d0e705#reddit-tldr-dataset + +These columns are taken directly from the aforementioned dataset: + +* **id**: unique identifier for the post +* **subreddit**: subreddit the post was taken from +* **title**: title of the post +* **post**: body of the post +* **summary**: summary of the post +* **reference_response**: reference response for the post + +These columns are added by this preprocessing script: +* **query**: length-limited query for summarization: OAI pre-processes the main text (title + subreddit + post), ensuring it has only 512 tokens; if the main text is too long, then it tries to truncate at the last `\n`. If it's too short it pads the main text ([summarize_from_feedback/tasks.py#L98-L165](https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/summarize_from_feedback/tasks.py#L98-L165)). Padding is either space or `[PAD]` token (see Args below). 
+* **query_token**: tokenized version of `query` +* **reference_response_token**: tokenized version of `reference_response` +* **reference_response_token_len**: length of `reference_response_token` +* **query_reference_response**: concatenation of `query.strip()` and `reference_response` +* **query_reference_response_token**: tokenized version of `query_reference_response`, up to `max_sft_query_response_length` tokens +* **query_reference_response_token_len**: length of `query_reference_response_token` + + +# Args + +```python +{pformat(vars(args))} +{pformat(vars(oai_h))} +``` +""" + sft_card.push_to_hub( + f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}", + repo_type="dataset", + ) label_ds = load_dataset("openai/summarize_from_feedback", "comparisons") @@ -155,7 +213,7 @@ def process_response_data(x): response0_policy = x["summaries"][0]["policy"] response1_policy = x["summaries"][1]["policy"] policies = "--".join(sorted([response0_policy, response1_policy])) - return { + y = { **process_query(x["info"], encoder=tokenizer, hparams=oai_h), "response0": response0, "response0_token": tokenizer.encode( @@ -171,6 +229,17 @@ def process_response_data(x): "response1_policy": response1_policy, "policies": policies, } + y["query_response0"] = y["query"].strip() + y["response0"] + y["query_response0_token"] = tokenizer.encode( + y["query_response0"], padding="max_length", max_length=args.max_rm_query_response_length, truncation=True + ) + y["query_response0_token_len"] = len(tokenizer.encode(y["query_response0"])) + y["query_response1"] = y["query"].strip() + y["response1"] + y["query_response1_token"] = tokenizer.encode( + y["query_response1"], padding="max_length", max_length=args.max_rm_query_response_length, truncation=True + ) + y["query_response1_token_len"] = len(tokenizer.encode(y["query_response1"])) + return y label_ds = label_ds.map(process_response_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) label_ds.push_to_hub( @@ -179,21 +248,30 @@ def process_response_data(x): os.makedirs("dataset_visuals", exist_ok=True) # visualize token length distribution - num_subplots = len(sft_ds) + len(label_ds) * 2 + num_subplots = len(sft_ds) * 2 + len(label_ds) * 4 print(f"{num_subplots=}") - fig, axs = plt.subplots(3, 3, figsize=(16, 16)) + fig, axs = plt.subplots(5, 3, figsize=(16, 16)) axs = axs.flatten() - for i, key in enumerate(sft_ds.keys()): + j = 0 + for _, key in enumerate(sft_ds.keys()): df = sft_ds[key].to_pandas() - axs[i].hist(df["reference_response_token_len"], bins=100) - axs[i].set_title(f"{key} split: reference response token length\nmax_length={max(df['reference_response_token_len'])}") + axs[j].hist(df["reference_response_token_len"], bins=100) + axs[j].set_title(f"{key} split: reference response token length\nmax_length={max(df['reference_response_token_len'])}") + axs[j + 1].hist(df["query_reference_response_token_len"], bins=100) + axs[j + 1].set_title(f"{key} split: query.strip() + reference response token length\nmax_length={max(df['query_reference_response_token_len'])}") + j += 2 offset = len(sft_ds) - for i, key in enumerate(label_ds.keys()): + for _, key in enumerate(label_ds.keys()): df = label_ds[key].to_pandas() - axs[2 * i + offset].hist(df["response0_token_len"], bins=100) - axs[2 * i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") - axs[2 * i + offset + 1].hist(df["response1_token_len"], 
bins=100) - axs[2 * i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") + axs[j].hist(df["response0_token_len"], bins=100) + axs[j].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") + axs[j + 1].hist(df["response1_token_len"], bins=100) + axs[j + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") + axs[j + 2].hist(df["query_response0_token_len"], bins=100) + axs[j + 2].set_title(f"{key} split: query.strip() + response0 token length\nmax_length={max(df['query_response0_token_len'])}") + axs[j + 3].hist(df["query_response1_token_len"], bins=100) + axs[j + 3].set_title(f"{key} split: query.strip() + response1 token length\nmax_length={max(df['query_response1_token_len'])}") + j += 4 fig.suptitle(f"{args.base_model} Tokenizer: Token length distribution") fig.tight_layout() fig.savefig("dataset_visuals/token_len.png") From 451ec856a5273257a00ddc78b3edbe9371e78345 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 26 Dec 2023 05:31:41 +0000 Subject: [PATCH 48/62] update benchmark.py --- benchmark/benchmark.py | 96 ++++++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 5b433d2..fb36c0d 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -1,66 +1,77 @@ -import argparse import math import os import shlex import subprocess import uuid -from distutils.util import strtobool +from dataclasses import dataclass +from typing import Optional import requests - - -def parse_args(): - # fmt: off - parser = argparse.ArgumentParser() - parser.add_argument("--command", type=str, default="", - help="the command to run") - parser.add_argument("--num-seeds", type=int, default=3, - help="the number of random seeds") - parser.add_argument("--start-seed", type=int, default=1, - help="the number of the starting seed") - parser.add_argument("--workers", type=int, default=0, - help="the number of workers to run benchmark experimenets") - parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible") - parser.add_argument("--slurm-template-path", type=str, default=None, - help="the path to the slurm template file (see docs for more details)") - parser.add_argument("--slurm-gpus-per-task", type=int, default=1, - help="the number of gpus per task to use for slurm jobs") - parser.add_argument("--slurm-total-cpus", type=int, default=50, - help="the number of gpus per task to use for slurm jobs") - parser.add_argument("--slurm-ntasks", type=int, default=1, - help="the number of tasks to use for slurm jobs") - parser.add_argument("--slurm-nodes", type=int, default=None, - help="the number of nodes to use for slurm jobs") - args = parser.parse_args() - # fmt: on - return args +import tyro + + +@dataclass +class Args: + command: str + """the command to run""" + num_seeds: int = 3 + """the number of random seeds""" + start_seed: int = 1 + """the number of the starting seed""" + workers: int = 0 + """the number of workers to run benchmark experimenets""" + auto_tag: bool = True + """if toggled, the runs will be tagged with git tags, commit, and pull request number if possible""" + slurm_template_path: Optional[str] = None + """the path to the slurm template file (see docs for more details)""" + slurm_gpus_per_task: 
Optional[int] = None + """the number of gpus per task to use for slurm jobs""" + slurm_total_cpus: Optional[int] = None + """the number of gpus per task to use for slurm jobs""" + slurm_ntasks: Optional[int] = None + """the number of tasks to use for slurm jobs""" + slurm_nodes: Optional[int] = None + """the number of nodes to use for slurm jobs""" def run_experiment(command: str): command_list = shlex.split(command) print(f"running {command}") - fd = subprocess.Popen(command_list) - return_code = fd.wait() - assert return_code == 0 + + # Use subprocess.PIPE to capture the output + fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, errors = fd.communicate() + + return_code = fd.returncode + assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}" + + # Convert bytes to string and strip leading/trailing whitespaces + return output.decode("utf-8").strip() def autotag() -> str: wandb_tag = "" print("autotag feature is enabled") + git_tag = "" try: git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() - wandb_tag = f"{git_tag}" print(f"identified git tag: {git_tag}") - except subprocess.CalledProcessError: - return wandb_tag + except subprocess.CalledProcessError as e: + print(e) + if len(git_tag) == 0: + try: + count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip()) + hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() + git_tag = f"no-tag-{count}-g{hash}" + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError as e: + print(e) + wandb_tag = git_tag git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip() try: # try finding the pull request number on github - prs = requests.get( - f"https://api.github.com/search/issues?q=repo:vwxyzjn/lm-human-preference-details+is:pr+{git_commit}" - ) + prs = requests.get(f"https://api.github.com/search/issues?q=repo:vwxyzjn/cleanrl+is:pr+{git_commit}") if prs.status_code == 200: prs = prs.json() if len(prs["items"]) > 0: @@ -75,7 +86,7 @@ def autotag() -> str: if __name__ == "__main__": - args = parse_args() + args = tyro.cli(Args) if args.auto_tag: existing_wandb_tag = os.environ.get("WANDB_TAGS", "") wandb_tag = autotag() @@ -84,7 +95,7 @@ def autotag() -> str: os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) else: os.environ["WANDB_TAGS"] = wandb_tag - + print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", "")) commands = [] for seed in range(0, args.num_seeds): commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])] @@ -133,4 +144,5 @@ def autotag() -> str: slurm_path = os.path.join("slurm", f"{filename}.slurm") print(f"saving command in {slurm_path}") if args.workers > 0: - run_experiment(f"sbatch {slurm_path}") + job_id = run_experiment(f"sbatch --parsable {slurm_path}") + print(f"Job ID: {job_id}") From 2cbb1f7dfd17399db00248921bce2953a25362e5 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 26 Dec 2023 05:31:55 +0000 Subject: [PATCH 49/62] precommit --- .../summarize/reward.py | 20 +++++++++++--- lm_human_preference_details/tldr_dataset.py | 26 ++++++++++++------- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 8fcba8f..1a3054b 100644 --- a/lm_human_preference_details/summarize/reward.py +++ 
b/lm_human_preference_details/summarize/reward.py @@ -272,7 +272,17 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): dataset = dataset.shuffle(seed=local_seed) dataset = dataset.select(range(args.label.num_train)) dataset = dataset.with_format( - "torch", columns=["query_token", "choice", "response0_token", "query_response0_token", "response1_token", "query_response1_token", "batch", "split"] + "torch", + columns=[ + "query_token", + "choice", + "response0_token", + "query_response0_token", + "response1_token", + "query_response1_token", + "batch", + "split", + ], ) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() @@ -374,7 +384,9 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): for data in dataloader: update += 1 global_step += args.micro_batch_size - query_responses = torch.cat([data['query_response0_token'].unsqueeze(1), data['query_response1_token'].unsqueeze(1)], dim=1).flatten(0, 1) + query_responses = torch.cat( + [data["query_response0_token"].unsqueeze(1), data["query_response1_token"].unsqueeze(1)], dim=1 + ).flatten(0, 1) mb_best = data["choice"] # mb_query = data["query_token"] # mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) @@ -408,7 +420,9 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): ) writer.add_scalar("train/rm/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) writer.add_scalar("train/rm/lr", scheduler.get_last_lr()[0], global_step) - accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {optimizer.param_groups[0]['lr']=}, {update=}") + accelerator.print( + f"{train_accuracy=}, {scheduler.get_last_lr()=}, {optimizer.param_groups[0]['lr']=}, {update=}" + ) if args.run_eval: evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index 6d406da..30f47ea 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -1,6 +1,7 @@ import multiprocessing import os from dataclasses import dataclass +from pprint import pformat from typing import Dict, Optional import matplotlib.pyplot as plt @@ -8,11 +9,9 @@ import tyro from datasets import load_dataset from huggingface_hub import HfApi +from huggingface_hub.repocard import RepoCard from rich.pretty import pprint from transformers import AutoTokenizer -from huggingface_hub.repocard import RepoCard -from pprint import pformat - api = HfApi() @@ -38,9 +37,9 @@ class Args: base_model: str = "gpt2" # EleutherAI/pythia-160m max_sft_response_length: int = 48 # 53 - max_sft_query_response_length: int = 512 + 48 # 565 + max_sft_query_response_length: int = 512 + 48 # 565 max_rm_response_length: int = 153 # 169 - max_rm_query_response_length: int = 512 + 153 # 665 + max_rm_query_response_length: int = 512 + 153 # 665 hf_entity: str = None @@ -166,7 +165,10 @@ def process_query_data(x): sft_ds.push_to_hub( f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}" ) - sft_card = RepoCard.load(f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}", repo_type="dataset") + sft_card = RepoCard.load( + 
f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}", + repo_type="dataset", + ) sft_card.text = f"""\ # TL;DR SFT Dataset for OpenAI's [Summarize from Feedback](https://openai.com/blog/summarization/) task @@ -258,7 +260,9 @@ def process_response_data(x): axs[j].hist(df["reference_response_token_len"], bins=100) axs[j].set_title(f"{key} split: reference response token length\nmax_length={max(df['reference_response_token_len'])}") axs[j + 1].hist(df["query_reference_response_token_len"], bins=100) - axs[j + 1].set_title(f"{key} split: query.strip() + reference response token length\nmax_length={max(df['query_reference_response_token_len'])}") + axs[j + 1].set_title( + f"{key} split: query.strip() + reference response token length\nmax_length={max(df['query_reference_response_token_len'])}" + ) j += 2 offset = len(sft_ds) for _, key in enumerate(label_ds.keys()): @@ -268,9 +272,13 @@ def process_response_data(x): axs[j + 1].hist(df["response1_token_len"], bins=100) axs[j + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") axs[j + 2].hist(df["query_response0_token_len"], bins=100) - axs[j + 2].set_title(f"{key} split: query.strip() + response0 token length\nmax_length={max(df['query_response0_token_len'])}") + axs[j + 2].set_title( + f"{key} split: query.strip() + response0 token length\nmax_length={max(df['query_response0_token_len'])}" + ) axs[j + 3].hist(df["query_response1_token_len"], bins=100) - axs[j + 3].set_title(f"{key} split: query.strip() + response1 token length\nmax_length={max(df['query_response1_token_len'])}") + axs[j + 3].set_title( + f"{key} split: query.strip() + response1 token length\nmax_length={max(df['query_response1_token_len'])}" + ) j += 4 fig.suptitle(f"{args.base_model} Tokenizer: Token length distribution") fig.tight_layout() From 82ea918660917af192cee204b1409266b5833373 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 29 Dec 2023 03:45:47 +0000 Subject: [PATCH 50/62] test --- lm_human_preference_details/summarize/ppo.py | 160 ++++++++++--------- oai.py | 101 ++++++++++++ 2 files changed, 189 insertions(+), 72 deletions(-) create mode 100644 oai.py diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index 69685eb..eaeab9a 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -13,6 +13,7 @@ import torch.optim as optim import tyro from accelerate import Accelerator +from accelerate.utils import gather_object from accelerate.state import AcceleratorState from datasets import load_dataset from rich.console import Console @@ -30,6 +31,8 @@ PretrainedConfig, PreTrainedModel, ) +from tqdm import tqdm + INVALID_LOGPROB = 1.0 @@ -125,8 +128,8 @@ class Args: """Whether to use deepspeed to train the model""" print_sample_output_freq: int = 220 """How often to print sample output""" - # run_eval: bool = False - # """Whether to run evaluation""" + run_eval: bool = False + """Whether to run evaluation""" # optimizer args eps: float = 1e-5 @@ -354,6 +357,65 @@ def forward(model, query_responses, tokenizer): ) +@dataclass +class EvalStorage: + query_token: List[str] = field(default_factory=list) + postprocessed_response_token: List[str] = field(default_factory=list) + reference_response_token: List[str] = field(default_factory=list) + score: List[float] = field(default_factory=list) + reference_score: List[float] = 
field(default_factory=list) + + query: List[str] = field(default_factory=list) + postprocessed_response: List[str] = field(default_factory=list) + reference_response: List[str] = field(default_factory=list) + + +def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation_config, sampling=True): + eval_storage = EvalStorage() + with torch.no_grad(): + for data in tqdm(dataloader): + queries = data["query_token"] + reference_response_token = data["reference_response_token"] + query_reference_responses = torch.cat((data["query_token"], data["reference_response_token"]), dim=1) + _, reference_score, _ = get_reward(reward_model, query_reference_responses, tokenizer) + + context_length = queries.shape[1] + query_responses = generate( + policy, + queries, + tokenizer, + generation_config, + ) + responses = query_responses[:, context_length:] + postprocessed_responses = truncate_response(args, tokenizer, responses) + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + _, score, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) + + eval_storage.query_token.extend(queries) + eval_storage.reference_response_token.extend(reference_response_token) + eval_storage.reference_score.append(reference_score) + eval_storage.postprocessed_response_token.extend(postprocessed_responses) + eval_storage.score.append(score) + if sampling: + break + + eval_storage.query = tokenizer.batch_decode(eval_storage.query_token, skip_special_tokens=True) + eval_storage.reference_response = tokenizer.batch_decode(eval_storage.reference_response_token) + eval_storage.postprocessed_response = tokenizer.batch_decode(eval_storage.postprocessed_response_token, skip_special_tokens=True) + eval_score = torch.cat(eval_storage.score).cpu().numpy().tolist() + eval_reference_score = torch.cat(eval_storage.reference_score).cpu().numpy().tolist() + eval_df = pd.DataFrame( + { + "query": gather_object(eval_storage.query), + "postprocessed_response": gather_object(eval_storage.postprocessed_response), + "reference_responses": gather_object(eval_storage.reference_response), + "scores": gather_object(eval_score), + "reference_scores": gather_object(eval_reference_score), + } + ) + return eval_storage, eval_df + + # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) @@ -455,7 +517,7 @@ def forward(model, query_responses, tokenizer): dataset = dataset.shuffle(seed=local_seed) dataloader = DataLoader(dataset, batch_size=args.local_batch_size) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_batch_size) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) validation_dataloader = accelerator.prepare(validation_dataloader) if args.deepspeed: @@ -490,19 +552,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well while True: yield from dataloader - sample_validation_inds = np.arange(args.batch_size) - local_sample_validation_inds = sample_validation_inds[accelerator.process_index :: accelerator.num_processes] - sample_validation = validation_dataset[local_sample_validation_inds] - sample_validation_queries = torch.Tensor(sample_validation["query_token"]).to(device) - with torch.no_grad(): - sample_validation_reference_response = 
torch.Tensor(sample_validation["reference_response_token"]).to(device) - sample_validation_query_reference_responses = torch.cat( - (sample_validation_queries, sample_validation_reference_response), dim=1 - ) - _, sample_validation_reference_scores, _ = get_reward( - reward_model, sample_validation_query_reference_responses, tokenizer - ) - iter_dataloader = iter(repeat_generator()) kl_ctl = AdaptiveKLController(args.reward.kl_coef, hparams=args.reward.adaptive_kl) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated @@ -553,18 +602,21 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well context_length = queries.shape[1] responses = query_responses[:, context_length:] - # validation - sample_validation_query_responses = generate( - accelerator.unwrap_model(model).policy, - sample_validation_queries, - tokenizer, - validation_generation_config, - ) - sample_validation_responses = sample_validation_query_responses[:, context_length:] - postprocessed_sample_validation_responses = truncate_response(args, tokenizer, sample_validation_responses) - postprocessed_sample_validation_query_responses = torch.cat( - (sample_validation_queries, postprocessed_sample_validation_responses), 1 - ) + eval_storage, eval_df = evaluate(args, reward_model, accelerator.unwrap_model(model).policy, tokenizer, validation_dataloader, validation_generation_config) + validation_score = eval_storage.score[0] + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + if accelerator.is_main_process: + eval_df.to_csv(f"runs/{run_name}_{global_step}/table.csv") + if args.track: + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=eval_df)}, step=update + ) + else: + try: + print_rich_table(f"Sample Output at Step {update}", eval_df[:1], console) + except Exception as e: + print(e) + del eval_storage, eval_df torch.cuda.empty_cache() # TODO: do I do this with query response or post-processed query response? @@ -594,14 +646,12 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) values = full_values[:, context_length - 1 : -1].squeeze(-1) _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) - _, validation_score, _ = get_reward(reward_model, postprocessed_sample_validation_query_responses, tokenizer) # 3. filter response. 
Ensure that the sample contains truncate_token_id # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") # torch.cuda.empty_cache() @@ -617,48 +667,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if args.ppo.whiten_rewards: rewards = whiten(rewards, shift_mean=False) - if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: - try: - all_decode_validation_queries = tokenizer.batch_decode(sample_validation_queries, skip_special_tokens=True) - all_sample_validation_responses = tokenizer.batch_decode(sample_validation_responses) - all_sample_validation_query_responses_postprocessed = tokenizer.batch_decode( - postprocessed_sample_validation_query_responses, skip_special_tokens=True - ) - all_sample_validation_postprocessed_responses = [ - x[len(y) :] - for x, y in zip(all_sample_validation_query_responses_postprocessed, all_decode_validation_queries) - ] - all_sample_validation_reference_responses = tokenizer.batch_decode(sample_validation_reference_response) - all_sample_validation_df = pd.DataFrame( - { - "query": all_decode_validation_queries, - "response": all_sample_validation_responses, - "postprocessed_response": all_sample_validation_postprocessed_responses, - "reference_responses": all_sample_validation_reference_responses, - "scores": validation_score.float().cpu().numpy(), - "reference_scores": sample_validation_reference_scores.float().cpu().numpy(), - } - ) - if accelerator.is_main_process: - all_sample_validation_df.to_json(f"runs/{run_name}/table.json") - if args.track: - wandb.log( - {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update - ) - else: - print_rich_table(f"Sample Output at Step {update}", all_sample_validation_df[:1], console) - - except Exception as e: - print(e) - del ( - all_decode_validation_queries, - all_sample_validation_responses, - all_sample_validation_reference_responses, - all_sample_validation_df, - ) - del postprocessed_query_responses - torch.cuda.empty_cache() - # 6. 
compute advantages and returns lastgaelam = 0 advantages_reversed = [] @@ -813,6 +821,14 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well kl_ctl.update(mean_kl.item(), args.batch_size) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + if args.run_eval: + eval_storage, eval_df = evaluate(args, reward_model, accelerator.unwrap_model(model).policy, tokenizer, validation_dataloader, validation_generation_config, sampling=False) + eval_df.to_csv(f"runs/{run_name}/table.csv") + if accelerator.is_main_process and args.track: + wandb.log( + {"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update + ) + # save model if args.output_dir and args.num_train_epochs > 0: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) diff --git a/oai.py b/oai.py new file mode 100644 index 0000000..0a197ec --- /dev/null +++ b/oai.py @@ -0,0 +1,101 @@ +# you can download the CSV from https://wandb.ai/costa-huang/tldr_summarize/runs/gb2dian5 + +from dataclasses import dataclass +import random +from openai import AsyncOpenAI +import pandas as pd +import asyncio +import tyro +from tqdm.asyncio import tqdm_asyncio +from aiolimiter import AsyncLimiter + +limiter = AsyncLimiter(1000, 60) + +@dataclass +class Args: + csv_path: str = "trained_response.csv" + max_samples: int = 64 + +# client = OpenAI() +async_client = AsyncOpenAI() + + +template = r""" +Which of the following summaries does a better job of summarizing the most important points in the given forum post, without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence. +### Post: +{{post}} +### Summary A: +{{summarya}} +### Summary B: +{{summaryb}} +### Instructions: +FIRST provide a one-sentence comparison of the two summaries, explaining which \ +you prefer and why. SECOND, on a new line, state only "A" or "B" to indicate your choice. 
Your response should use the format: +Comparison: +Preferred: <"A" or "B"> +""" + +async def process_text(post, summary_a, summary_b, i): + text = template.replace('{{post}}', post) + text = text.replace('{{summarya}}', summary_a) + text = text.replace('{{summaryb}}', summary_b) # Ensure this split logic is correct for your data + + async with limiter: + response = await async_client.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": text}, + ] + ) + r = response.choices[0].message.content + try: + comparison = r.split('Comparison:')[1].split('Preferred:')[0].strip() + preferred = r.split('Preferred:')[1].strip() + return comparison, preferred, i + except: + print(f"error in {i}") + return "", random.choice(["A", "B"]), i + +async def main(args: Args): + num_trails = 2 + for j in range(num_trails): + print(j) + tasks = [] + df = pd.read_csv(args.csv_path) + df["explanation"] = [None for _ in range(len(df))] + df["prefered"] = [None for _ in range(len(df))] + df["shuffled_index"] = [None for _ in range(len(df))] + r = range(min(args.max_samples, len(df))) + if args.max_samples == -1: + r = range(len(df)) + for i in r: + post = df['query'].iloc[i].strip() + # shuffled the index to avoid GPT4's preference bias in the content's order + shuffled_index = random.randint(0, 1) + df.at[i, 'shuffled_index'] = shuffled_index + summaries = [ + df['postprocessed_response'].iloc[i].strip(), + df['reference_responses'].iloc[i].split('<|endoftext|>')[0].strip(), + ] + summary_a = summaries[shuffled_index] + summary_b = summaries[1 - shuffled_index] + task = asyncio.create_task(process_text(post, summary_a, summary_b, i)) + tasks.append(task) + + results = await tqdm_asyncio.gather(*tasks) + + for _, (comparison, preferred, i) in enumerate(results): + df.at[i, 'explanation'] = comparison + preferred_label = "ours" if (df.at[i, 'shuffled_index'] == 0 and preferred == "A") or \ + (df.at[i, 'shuffled_index'] == 1 and preferred == "B") else "reference" + df.at[i, 'prefered'] = preferred_label + + + print(df['prefered'].value_counts()) + df.to_csv(f'{args.csv_path}_judged.csv') + # return df + +if __name__ == "__main__": + args = tyro.cli(Args) + asyncio.run(main(args)) \ No newline at end of file From f97df9f6660781a495a664c0aefadea42f5f5a44 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 29 Dec 2023 03:46:58 +0000 Subject: [PATCH 51/62] various fix; ppo repeat shuffle --- lm_human_preference_details/summarize/ppo.py | 68 ++++++++++++------- .../summarize/reward.py | 14 ++-- lm_human_preference_details/summarize/sft.py | 48 ++++++++++--- oai.py | 59 +++++++++------- 4 files changed, 123 insertions(+), 66 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index eaeab9a..2ba9f0b 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -13,8 +13,8 @@ import torch.optim as optim import tyro from accelerate import Accelerator -from accelerate.utils import gather_object from accelerate.state import AcceleratorState +from accelerate.utils import gather_object from datasets import load_dataset from rich.console import Console from rich.pretty import pprint @@ -22,6 +22,7 @@ from torch import optim from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm from transformers import ( AutoConfig, AutoModel, @@ -31,8 +32,6 @@ PretrainedConfig, 
PreTrainedModel, ) -from tqdm import tqdm - INVALID_LOGPROB = 1.0 @@ -364,11 +363,11 @@ class EvalStorage: reference_response_token: List[str] = field(default_factory=list) score: List[float] = field(default_factory=list) reference_score: List[float] = field(default_factory=list) - + query: List[str] = field(default_factory=list) postprocessed_response: List[str] = field(default_factory=list) reference_response: List[str] = field(default_factory=list) - + def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation_config, sampling=True): eval_storage = EvalStorage() @@ -376,10 +375,10 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation for data in tqdm(dataloader): queries = data["query_token"] reference_response_token = data["reference_response_token"] + context_length = queries.shape[1] query_reference_responses = torch.cat((data["query_token"], data["reference_response_token"]), dim=1) _, reference_score, _ = get_reward(reward_model, query_reference_responses, tokenizer) - context_length = queries.shape[1] query_responses = generate( policy, queries, @@ -401,9 +400,11 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation eval_storage.query = tokenizer.batch_decode(eval_storage.query_token, skip_special_tokens=True) eval_storage.reference_response = tokenizer.batch_decode(eval_storage.reference_response_token) - eval_storage.postprocessed_response = tokenizer.batch_decode(eval_storage.postprocessed_response_token, skip_special_tokens=True) - eval_score = torch.cat(eval_storage.score).cpu().numpy().tolist() - eval_reference_score = torch.cat(eval_storage.reference_score).cpu().numpy().tolist() + eval_storage.postprocessed_response = tokenizer.batch_decode( + eval_storage.postprocessed_response_token, skip_special_tokens=True + ) + eval_score = torch.cat(eval_storage.score).float().cpu().numpy().tolist() + eval_reference_score = torch.cat(eval_storage.reference_score).float().cpu().numpy().tolist() eval_df = pd.DataFrame( { "query": gather_object(eval_storage.query), @@ -514,11 +515,21 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation dataset = load_dataset(args.task.query_dataset, split="train") validation_dataset = load_dataset(args.task.query_dataset, split="validation") dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - dataset = dataset.shuffle(seed=local_seed) - dataloader = DataLoader(dataset, batch_size=args.local_batch_size) + dataloader = DataLoader(dataset, batch_size=args.local_batch_size, shuffle=True) validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + torch.manual_seed(local_seed) # reset the local seed again + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) validation_dataloader = accelerator.prepare(validation_dataloader) if args.deepspeed: import deepspeed @@ -548,11 +559,6 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation ref_policy = ref_policy.to(device) reward_model = reward_model.to(device) - def 
repeat_generator(): # TODO: ideally we shuffle the dataloader as well - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) kl_ctl = AdaptiveKLController(args.reward.kl_coef, hparams=args.reward.adaptive_kl) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens @@ -602,15 +608,20 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well context_length = queries.shape[1] responses = query_responses[:, context_length:] - eval_storage, eval_df = evaluate(args, reward_model, accelerator.unwrap_model(model).policy, tokenizer, validation_dataloader, validation_generation_config) + eval_storage, eval_df = evaluate( + args, + reward_model, + accelerator.unwrap_model(model).policy, + tokenizer, + validation_dataloader, + validation_generation_config, + ) validation_score = eval_storage.score[0] if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: if accelerator.is_main_process: - eval_df.to_csv(f"runs/{run_name}_{global_step}/table.csv") + eval_df.to_csv(f"runs/{run_name}/table_{global_step}.csv") if args.track: - wandb.log( - {"samples/query_responses": wandb.Table(dataframe=eval_df)}, step=update - ) + wandb.log({"samples/query_responses": wandb.Table(dataframe=eval_df)}, step=update) else: try: print_rich_table(f"Sample Output at Step {update}", eval_df[:1], console) @@ -619,7 +630,6 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well del eval_storage, eval_df torch.cuda.empty_cache() - # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] logits /= args.task.temperature + 1e-7 @@ -822,12 +832,18 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well del kl, mean_kl, mean_entropy, mean_non_score_reward, scores if args.run_eval: - eval_storage, eval_df = evaluate(args, reward_model, accelerator.unwrap_model(model).policy, tokenizer, validation_dataloader, validation_generation_config, sampling=False) + eval_storage, eval_df = evaluate( + args, + reward_model, + accelerator.unwrap_model(model).policy, + tokenizer, + validation_dataloader, + validation_generation_config, + sampling=False, + ) eval_df.to_csv(f"runs/{run_name}/table.csv") if accelerator.is_main_process and args.track: - wandb.log( - {"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update - ) + wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update) # save model if args.output_dir and args.num_train_epochs > 0: diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 1a3054b..9b46189 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -229,12 +229,14 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): with torch.no_grad(): items = defaultdict(list) for data in tqdm(dataloader): - mb_query = data["query_token"] - mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) + query_responses = torch.cat( + [data["query_response0_token"].unsqueeze(1), data["query_response1_token"].unsqueeze(1)], dim=1 + ).flatten(0, 1) mb_best = data["choice"] - mb_query_tiled = 
mb_query.unsqueeze(1).repeat(1, args.label.num_labels, 1) - query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) - # query_responses = left_padding_to_right_padding(query_responses, tokenizer.pad_token_id) + # mb_query = data["query_token"] + # mb_responses = torch.cat([data[f"response0_token"].unsqueeze(1), data[f"response1_token"].unsqueeze(1)], dim=1) + # mb_query_tiled = mb_query.unsqueeze(1).repeat(1, args.label.num_labels, 1) + # query_responses = torch.cat([mb_query_tiled, mb_responses], dim=2).flatten(0, 1) predicted_reward = get_reward(model, query_responses, tokenizer) predicted_reward = predicted_reward.view(-1, args.label.num_labels) accuracy = (predicted_reward.argmax(1) == mb_best).float() @@ -292,7 +294,9 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): "query_token", "choice", "response0_token", + "query_response0_token", "response1_token", + "query_response1_token", "batch", "split", "extra.confidence", diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index 420955a..3c20704 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -52,7 +52,8 @@ class TaskHParams: response_length: int = 53 # Truncate response after the first occurrence of this token at or after index after when sampling. - truncate_token: int = 50256 # EOS token + truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None truncate_after: int = 16 penalty_reward_value: int = -1 @@ -169,6 +170,26 @@ def generate(lm_backbone, queries, tokenizer, generation_config): return torch.cat((queries, output.sequences[:, context_length:]), dim=1) +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
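    For example, worked out directly from the implementation below:

        >>> bools = torch.tensor([[False, True, False], [False, False, False]])
        >>> first_true_indices(bools)
        tensor([1, 3])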
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + def forward(model, query_responses, tokenizer): attention_mask = query_responses != tokenizer.pad_token_id input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) @@ -189,8 +210,9 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c unwrapped = accelerator.unwrap_model(model) for _, data in tqdm(enumerate(dataloader)): with torch.no_grad(): - reference_responses = data["reference_response_token"] queries = data["query_token"] + reference_responses = data["reference_response_token"] + context_length = queries.shape[1] query_reference_responses = torch.cat((queries, reference_responses), dim=1) output = forward(model, query_reference_responses, tokenizer) labels = query_reference_responses.masked_fill(query_reference_responses == tokenizer.pad_token_id, -1) @@ -213,13 +235,15 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c tokenizer, generation_config, ) + responses = generated_responses[:, context_length:] + postprocessed_responses = truncate_response(args, tokenizer, responses) decode_queries = tokenizer.batch_decode(queries) decode_reference_responses = tokenizer.batch_decode( reference_responses, skip_special_tokens=True, ) decode_responses = tokenizer.batch_decode( - generated_responses[:, -args.task.response_length :], + postprocessed_responses, skip_special_tokens=True, ) rouge_score = rouge.compute(predictions=decode_responses, references=decode_reference_responses) @@ -268,6 +292,16 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c args.total_episodes = len(dataset) args.num_updates = args.total_episodes // args.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + console = Console(force_terminal=True) run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" writer = SimpleNamespace() # dummy writer @@ -296,13 +330,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c np.random.seed(local_seed) torch.manual_seed(local_seed) torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + model_config = AutoConfig.from_pretrained(args.base_model) configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( diff --git a/oai.py b/oai.py index 0a197ec..b0c29cd 100644 --- a/oai.py +++ 
b/oai.py @@ -1,21 +1,24 @@ # you can download the CSV from https://wandb.ai/costa-huang/tldr_summarize/runs/gb2dian5 -from dataclasses import dataclass +import asyncio import random -from openai import AsyncOpenAI +from dataclasses import dataclass + import pandas as pd -import asyncio import tyro -from tqdm.asyncio import tqdm_asyncio from aiolimiter import AsyncLimiter +from openai import AsyncOpenAI +from tqdm.asyncio import tqdm_asyncio limiter = AsyncLimiter(1000, 60) + @dataclass class Args: csv_path: str = "trained_response.csv" max_samples: int = 64 + # client = OpenAI() async_client = AsyncOpenAI() @@ -35,10 +38,11 @@ class Args: Preferred: <"A" or "B"> """ + async def process_text(post, summary_a, summary_b, i): - text = template.replace('{{post}}', post) - text = text.replace('{{summarya}}', summary_a) - text = text.replace('{{summaryb}}', summary_b) # Ensure this split logic is correct for your data + text = template.replace("{{post}}", post) + text = text.replace("{{summarya}}", summary_a) + text = text.replace("{{summaryb}}", summary_b) # Ensure this split logic is correct for your data async with limiter: response = await async_client.chat.completions.create( @@ -46,17 +50,18 @@ async def process_text(post, summary_a, summary_b, i): messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}, - ] + ], ) r = response.choices[0].message.content try: - comparison = r.split('Comparison:')[1].split('Preferred:')[0].strip() - preferred = r.split('Preferred:')[1].strip() + comparison = r.split("Comparison:")[1].split("Preferred:")[0].strip() + preferred = r.split("Preferred:")[1].strip() return comparison, preferred, i except: print(f"error in {i}") return "", random.choice(["A", "B"]), i + async def main(args: Args): num_trails = 2 for j in range(num_trails): @@ -64,19 +69,19 @@ async def main(args: Args): tasks = [] df = pd.read_csv(args.csv_path) df["explanation"] = [None for _ in range(len(df))] - df["prefered"] = [None for _ in range(len(df))] + df["preferred"] = [None for _ in range(len(df))] df["shuffled_index"] = [None for _ in range(len(df))] r = range(min(args.max_samples, len(df))) if args.max_samples == -1: r = range(len(df)) for i in r: - post = df['query'].iloc[i].strip() + post = df["query"].iloc[i].strip() # shuffled the index to avoid GPT4's preference bias in the content's order shuffled_index = random.randint(0, 1) - df.at[i, 'shuffled_index'] = shuffled_index + df.at[i, "shuffled_index"] = shuffled_index summaries = [ - df['postprocessed_response'].iloc[i].strip(), - df['reference_responses'].iloc[i].split('<|endoftext|>')[0].strip(), + df["postprocessed_response"].iloc[i].strip(), + df["reference_responses"].iloc[i].split("<|endoftext|>")[0].strip(), ] summary_a = summaries[shuffled_index] summary_b = summaries[1 - shuffled_index] @@ -84,18 +89,22 @@ async def main(args: Args): tasks.append(task) results = await tqdm_asyncio.gather(*tasks) - - for _, (comparison, preferred, i) in enumerate(results): - df.at[i, 'explanation'] = comparison - preferred_label = "ours" if (df.at[i, 'shuffled_index'] == 0 and preferred == "A") or \ - (df.at[i, 'shuffled_index'] == 1 and preferred == "B") else "reference" - df.at[i, 'prefered'] = preferred_label + for _, (comparison, preferred, i) in enumerate(results): + df.at[i, "explanation"] = comparison + preferred_label = ( + "ours" + if (df.at[i, "shuffled_index"] == 0 and preferred == "A") + or (df.at[i, "shuffled_index"] == 1 and preferred == "B") + else "reference" + ) + 
df.at[i, "preferred"] = preferred_label - print(df['prefered'].value_counts()) - df.to_csv(f'{args.csv_path}_judged.csv') + print(df["preferred"].value_counts()) + df.to_csv(f"{args.csv_path}_judged.csv") # return df - + + if __name__ == "__main__": args = tyro.cli(Args) - asyncio.run(main(args)) \ No newline at end of file + asyncio.run(main(args)) From 0efacc40320cb29a45467c8d804c57d2a4c3d830 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 1 Jan 2024 15:01:40 +0000 Subject: [PATCH 52/62] push changes --- .../summarize/ppo_left_padding.py | 899 ++++++++++++++++++ .../summarize/reward.py | 50 +- poetry.lock | 534 ++++++++++- pyproject.toml | 8 +- 4 files changed, 1451 insertions(+), 40 deletions(-) create mode 100644 lm_human_preference_details/summarize/ppo_left_padding.py diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py new file mode 100644 index 0000000..12ddc2e --- /dev/null +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -0,0 +1,899 @@ +import os +import random +import time +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import gather_object +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import optim +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, + PretrainedConfig, + PreTrainedModel, +) + +INVALID_LOGPROB = 1.0 + + +@dataclass +class AdaptiveKLParams: + target: float = 6.0 + horizon: int = 10000 # in episodes + + +@dataclass +class RewardHParams: + use_adaptive_kl: bool = False + adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) + dataset_std: float = 1.0 + kl_coef: float = 0.05 + + +@dataclass +class PpoHParams: + num_updates: tyro.conf.Suppress[int] = None + nminibatches: int = 1 + noptepochs: int = 4 + vf_coef: float = 0.1 + cliprange: float = 0.2 + cliprange_value: float = 0.2 + gamma: float = 1 + lam: float = 0.95 + whiten_rewards: bool = True + + +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_pythia-160m_53" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.7 + + +# a patch +@dataclass +class TaskQueryHParams: + length: int = None + dataset: str = None + format_str: Optional[str] = None # if underlying dataset yields dicts, can format arbitrarily + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[str] = None # defaults to repeated spaces + pad_side: Optional[str] = None + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: Optional[str] = None + """a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + push_to_hub: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 220 + """How often to print sample output""" + run_eval: bool = False + """Whether to run evaluation""" + + # optimizer args + eps: float = 1e-5 + """the epsilon value for the optimizer""" + lr: float = 0.00001 + """the learning rate""" + optimizer: Literal["adam", "adamw"] = "adamw" + """Which optimizer to use""" + scheduler: str = "cosine" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_train_epochs: int = 1 + """Number of epochs to train""" + num_updates: Optional[int] = None + """The number of updates to train""" + gradient_accumulation_steps: int = 64 + """The number of gradient accumulation steps""" + local_micro_batch_size: Optional[int] = 1 + """The micro batch size per GPU (HF's `per_device_train_batch_size`)""" + total_episodes: Optional[int] = 1000000 + """The total number of episodes in the dataset""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + nminibatches: int = 1 + """Number of minibatches to split a batch into""" + local_mini_batch_size: Optional[int] = None + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + local_eval_batch_size: int = 8 + """per rank eval batch size""" + + # other args + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + offload: bool = False + """Whether to offload ref policy and reward model to CPU""" + reward_model_path: str = "" + """the name of the pretrained model to use""" + sft_model_path: str = 
"EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) + """Which layers to apply dropout to""" + output_dir: str = "models/ppo_model" + """Where to save the model""" + task: TaskHParams = field(default_factory=TaskHParams) + reward: RewardHParams = field(default_factory=RewardHParams) + ppo: PpoHParams = field(default_factory=PpoHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.normal_(layer.weight, std=std) + torch.nn.init.constant_(layer.bias, val=bias_const) + return layer + + +class AdaptiveKLController: + def __init__(self, init_kl_coef: float, hparams: AdaptiveKLParams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + + +def whiten(values, shift_mean=True): + # `unbiased=False` matches TF `tf.nn.moments`'s setting + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class ScalarModelConfig(PretrainedConfig): + def __init__( + self, + base_model: str = "EleutherAI/pythia-160m", + base_config: PretrainedConfig = AutoConfig.from_pretrained("EleutherAI/pythia-160m"), + hidden_size: int = 768, + bias: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.base_model = base_model + self.base_config = base_config + self.hidden_size = hidden_size + self.bias = bias + + +class ScalarModel(PreTrainedModel): + config_class = ScalarModelConfig + + def __init__(self, config: ScalarModelConfig): + super().__init__(config) + self.config = config + self.lm_backbone = AutoModel.from_pretrained( + config.base_model, + config=self.config.base_config, + trust_remote_code=True, + ) + self.scalar_head = layer_init( + nn.Linear(self.config.hidden_size, 1), + std=1 / np.sqrt(self.config.hidden_size + 1), + ) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + reward = self.scalar_head(output.hidden_states[-1]) - self.config.bias + return reward + + +def get_reward(model, query_responses, tokenizer, context_length): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + 
sequence_lengths = first_true_indices(query_responses[:, context_length:] == tokenizer.pad_token_id) - 1 + context_length + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) + + +def get_value(model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + reward_logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + output_hidden_states=True, + ) + sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, critic) -> None: + super().__init__() + self.policy = policy + self.critic = critic + + def forward(self, **kwargs): + return self.policy(**kwargs), self.critic(**kwargs) + + +def exact_div(a, b): + q = a // b + if a != q * b: + raise ValueError(f"Inexact division: {a} / {b} = {a / b}") + return q + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def forward(model, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + # position_ids = attention_mask.cumsum(1) - attention_mask.long() + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return model( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +@dataclass +class EvalStorage: + query_token: List[str] = field(default_factory=list) + postprocessed_response_token: List[str] = field(default_factory=list) + reference_response_token: List[str] = field(default_factory=list) + score: List[float] = field(default_factory=list) + reference_score: List[float] = field(default_factory=list) + + query: List[str] = field(default_factory=list) + postprocessed_response: List[str] = field(default_factory=list) + reference_response: List[str] = field(default_factory=list) + + +def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation_config, sampling=True): + eval_storage = EvalStorage() + with torch.no_grad(): + for data in tqdm(dataloader): + queries = data["query_token"] + reference_response_token = data["reference_response_token"] + context_length = queries.shape[1] + query_reference_responses = torch.cat((data["query_token"], data["reference_response_token"]), dim=1) + _, reference_score, _ = get_reward(reward_model, query_reference_responses, tokenizer, queries.shape[1]) + + query_responses = generate( + policy, + queries, + tokenizer, + generation_config, + ) + responses = query_responses[:, context_length:] + postprocessed_responses = truncate_response(args, tokenizer, responses) + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + _, score, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer, queries.shape[1]) + + eval_storage.query_token.extend(queries) + eval_storage.reference_response_token.extend(reference_response_token) + eval_storage.reference_score.append(reference_score) + eval_storage.postprocessed_response_token.extend(postprocessed_responses) + eval_storage.score.append(score) + if sampling: + break + + eval_storage.query = tokenizer.batch_decode(eval_storage.query_token, skip_special_tokens=True) + eval_storage.reference_response = tokenizer.batch_decode(eval_storage.reference_response_token) + eval_storage.postprocessed_response = tokenizer.batch_decode( + eval_storage.postprocessed_response_token, skip_special_tokens=True + ) + eval_score = torch.cat(eval_storage.score).float().cpu().numpy().tolist() + eval_reference_score = torch.cat(eval_storage.reference_score).float().cpu().numpy().tolist() + eval_df = pd.DataFrame( + { + "query": gather_object(eval_storage.query), + "postprocessed_response": gather_object(eval_storage.postprocessed_response), + "reference_responses": gather_object(eval_storage.reference_response), + 
"scores": gather_object(eval_score), + "reference_scores": gather_object(eval_reference_score), + } + ) + return eval_storage, eval_df + + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + local_seed = args.seed + accelerator.process_index * 100003 # Prime + args.world_size = accelerator.num_processes + args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div(args.batch_size, args.nminibatches) + args.local_mini_batch_size = exact_div(args.local_batch_size, args.nminibatches) + if args.ppo.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.ppo.num_updates = args.total_episodes // args.batch_size + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + writer.add_histogram = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + wandb.run.log_code(".") + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + model_config = AutoConfig.from_pretrained(args.base_model) + configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + scalar_model_config = ScalarModelConfig( + base_model=args.base_model, + base_config=model_config, + hidden_size=model_config.hidden_size, + ) + if not args.reward_model_path: + critic: PreTrainedModel = ScalarModel(scalar_model_config) + reward_model: PreTrainedModel = ScalarModel(scalar_model_config) + else: + critic: PreTrainedModel = ScalarModel.from_pretrained( + args.reward_model_path, + trust_remote_code=True, + ) + reward_model: PreTrainedModel = ScalarModel.from_pretrained( + args.reward_model_path, + trust_remote_code=True, + ) + if accelerator.is_main_process: + pprint(model_config) + pprint(reward_model.config) + # each class should have a separate pretrained model that do not share weights + ref_policy = AutoModelForCausalLM.from_pretrained(args.sft_model_path, config=model_config, trust_remote_code=True) + policy = AutoModelForCausalLM.from_pretrained(args.sft_model_path, config=model_config, trust_remote_code=True) + # critic.lm_backbone.gradient_checkpointing_enable() + # 
policy.gradient_checkpointing_enable() + accelerator.print(policy) + accelerator.print(critic) + policy.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + model = PolicyAndValueWrapper(policy, critic) + if args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.lr, eps=args.eps) + + dataset = load_dataset(args.task.query_dataset, split="train") + validation_dataset = load_dataset(args.task.query_dataset, split="validation") + dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + dataloader = DataLoader(dataset, batch_size=args.local_batch_size, shuffle=True) + validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + torch.manual_seed(local_seed) # reset the local seed again + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + validation_dataloader = accelerator.prepare(validation_dataloader) + if args.deepspeed: + import deepspeed + + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size + + eval_ds_config = { + "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], + "bf16": {"enabled": True}, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if args.offload: + deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} + eval_ds_config["zero_optimization"] = { + "stage": 3, + "stage3_param_persistence_threshold": 1e4, + "offload_param": {"device": "cpu"}, + } + accelerator.print(f"{eval_ds_config=}") + reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) + reward_model.eval() + ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config) + ref_policy.eval() + else: + ref_policy = ref_policy.to(device) + reward_model = reward_model.to(device) + + kl_ctl = AdaptiveKLController(args.reward.kl_coef, hparams=args.reward.adaptive_kl) + # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated + # may not be the same. 
TODO: investigate further, we just want to generate a fixed number of tokens + generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(args.task.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 + validation_generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + global_step = 0 + stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + for update in range(1, args.ppo.num_updates + 1): + global_step += 1 * args.batch_size + frac = 1.0 - (update - 1.0) / args.ppo.num_updates + lrnow = frac * args.lr + optimizer.param_groups[0]["lr"] = lrnow + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["query_token"].to(device) + query_responses = generate( + accelerator.unwrap_model(model).policy, + queries, + tokenizer, + generation_config, + ) + context_length = queries.shape[1] + responses = query_responses[:, context_length:] + + eval_storage, eval_df = evaluate( + args, + reward_model, + accelerator.unwrap_model(model).policy, + tokenizer, + validation_dataloader, + validation_generation_config, + ) + validation_score = eval_storage.score[0] + if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: + if accelerator.is_main_process: + eval_df.to_csv(f"runs/{run_name}/table_{global_step}.csv") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + else: + try: + print_rich_table(f"Sample Output at Step {update}", eval_df[:1], console) + except Exception as e: + print(e) + del eval_storage, eval_df + torch.cuda.empty_cache() + + output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + 1e-7 + all_logprobs = F.log_softmax(logits, dim=-1) + logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprobs + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_responses, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + 1e-7 + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs + torch.cuda.empty_cache() + + # **Response Processing** + postprocessed_responses = truncate_response(args, tokenizer, responses) + torch.cuda.empty_cache() + + # 2. 
run reward model on the truncated responses + postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) + sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 + full_values, _, _ = get_value(accelerator.unwrap_model(model).critic, query_responses, tokenizer) + values = full_values[:, context_length - 1 : -1].squeeze(-1) + scores_logits, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer, queries.shape[1]) + + # 3. filter response. Ensure that the sample contains truncate_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) + scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) + accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") + # torch.cuda.empty_cache() + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -kl_ctl.value * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = sequence_lengths + rewards[[actual_start, actual_end]] += scores + + # 5. whiten rewards + if args.ppo.whiten_rewards: + rewards = whiten(rewards, shift_mean=False) + + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = args.task.response_length + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.ppo.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.ppo.gamma * args.ppo.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = whiten(advantages) + return_mean, return_var = returns.mean(), returns.var() + value_mean, value_var = values.mean(), values.var() + writer.add_histogram("rewards", rewards[0].float(), global_step) + writer.add_histogram("advantages", advantages[0].float(), global_step) + accelerator.print("rewards====", rewards[0]) + accelerator.print("advantages====", advantages[0]) + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.ppo.noptepochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.local_micro_batch_size): + with accelerator.accumulate(policy): + micro_batch_end = micro_batch_start + args.local_micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_return = returns[micro_batch_inds] + mb_advantage = advantages[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + 1e-7 + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, 
mb_responses.unsqueeze(-1)).squeeze(-1) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpredclipped = torch.clamp( + vpred, + mb_values - args.ppo.cliprange_value, + mb_values + args.ppo.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean() + vf_clipfrac = (vf_losses2 > vf_losses1).float().mean() + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) + pg_loss = torch.max(pg_losses, pg_losses2).mean() + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + loss = pg_loss + args.ppo.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + # if ppo_epoch_idx == 0 and micro_batch_start == 0: + # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) + # if ppo_epoch_idx == 0: + # pprint({ + # # "responses": responses, + # # "values": values, + # "rewards": rewards, + # # "scores": scores, + # "advantages": advantages, + # # "ratio": ratio, + # # "pg_losses": pg_losses, + # # "approxkl": approxkl, + # # "pg_loss": pg_loss, + # # "pg_clipfrac": pg_clipfrac, + # # "ratio": ratio.mean(), + # # "vf_loss": vf_loss, + # # "vf_clipfrac": vf_clipfrac, + # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), + # }) + # breakpoint() + with torch.no_grad(): + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + if accelerator.is_main_process: + console.print( + f"ppo_epoch_idx", + ppo_epoch_idx, + "approxkl", + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), + "pg_loss", + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), + "pg_clipfrac", + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), + "ratio", + ratio_stats[: ppo_epoch_idx + 1].mean().item(), + ) + with torch.no_grad(): + if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` + writer.add_histogram("ppo/val/ratio_hist", ratio, update) + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + writer.add_scalar("objective/kl_coef", kl_ctl.value, update) + writer.add_scalar("objective/kl", accelerator.gather(mean_kl).mean().item(), update) + writer.add_scalar("objective/entropy", accelerator.gather(mean_entropy).mean().item(), update) + writer.add_scalar("objective/non_score_reward", accelerator.gather(mean_non_score_reward).mean().item(), update) + writer.add_scalar( + "objective/score_total", accelerator.gather(mean_non_score_reward + scores.mean()).mean().item(), update + ) 
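# (illustrative sketch; not a diff line) a worked example of how the GAE loop in step 6 above
# unrolls, assuming gamma=1.0 and lam=0.95 as in PpoHParams and a hypothetical 3-token response:
#   rewards = [0.0, 0.0, 1.0], values = [0.2, 0.4, 0.6]
#   t=2: delta = 1.0 + 1.0*0.0 - 0.6 = 0.4   -> lastgaelam = 0.4
#   t=1: delta = 0.0 + 1.0*0.6 - 0.4 = 0.2   -> lastgaelam = 0.2 + 0.95*0.4  = 0.58
#   t=0: delta = 0.0 + 1.0*0.4 - 0.2 = 0.2   -> lastgaelam = 0.2 + 0.95*0.58 = 0.751
#   advantages = [0.751, 0.58, 0.4]; returns = advantages + values = [0.951, 0.98, 1.0]
#   the advantages are then whitened before the PPO epochs begin.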
+ writer.add_scalar("objective/scores", accelerator.gather(scores.mean()).mean().item(), update) + writer.add_scalar("objective/validation_score", accelerator.gather(validation_score.mean()).mean().item(), update) + writer.add_scalar("ppo/loss/policy", accelerator.gather(pg_loss).mean().item(), update) + writer.add_scalar("ppo/loss/value", accelerator.gather(vf_loss).mean().item(), update) + writer.add_scalar("ppo/loss/total", accelerator.gather(loss).mean().item(), update) + writer.add_scalar("ppo/policy/entropy", accelerator.gather(entropy.mean()).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl", accelerator.gather(approxkl).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac", accelerator.gather(pg_clipfrac).mean().item(), update) + writer.add_scalar("ppo/policy/approxkl_avg", accelerator.gather(approxkl_stats).mean().item(), update) + writer.add_scalar("ppo/policy/clipfrac_avg", accelerator.gather(pg_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/loss/policy_avg", accelerator.gather(pg_loss_stats).mean().item(), update) + writer.add_scalar("ppo/loss/value_avg", accelerator.gather(vf_loss_stats).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac_avg", accelerator.gather(vf_clipfrac_stats).mean().item(), update) + writer.add_scalar("ppo/policy/entropy_avg", accelerator.gather(entropy_stats).mean().item(), update) + writer.add_scalar("ppo/returns/mean", accelerator.gather(return_mean).mean().item(), update) + writer.add_scalar("ppo/returns/var", accelerator.gather(return_var).mean().item(), update) + writer.add_scalar("ppo/val/vpred", accelerator.gather(vpred.mean()).mean().item(), update) + writer.add_scalar("ppo/val/error", accelerator.gather(vf_losses1.mean()).mean().item(), update) + writer.add_scalar("ppo/val/clipfrac", accelerator.gather(vf_clipfrac).mean().item(), update) + writer.add_scalar("ppo/val/mean", accelerator.gather(value_mean).mean().item(), update) + writer.add_scalar("ppo/val/var", accelerator.gather(value_var).mean().item(), update) + writer.add_scalar("ppo/val/ratio", accelerator.gather(ratio_stats).mean().item(), update) + writer.add_scalar("ppo/val/ratio_var", accelerator.gather(ratio_stats).var().item(), update) + writer.add_scalar("ppo/val/advantage", accelerator.gather(advantages.mean()).mean().item(), update) + writer.add_scalar("ppo/val/advantage_var", accelerator.gather(advantages.mean()).var().item(), update) + writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) + writer.add_scalar("ppo/lr", lrnow, update) + writer.add_scalar("ppo/episode", global_step, update) + if args.reward.use_adaptive_kl: + kl_ctl.update(mean_kl.item(), args.batch_size) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + + if args.run_eval: + eval_storage, eval_df = evaluate( + args, + reward_model, + accelerator.unwrap_model(model).policy, + tokenizer, + validation_dataloader, + validation_generation_config, + sampling=False, + ) + eval_df.to_csv(f"runs/{run_name}/table.csv") + if accelerator.is_main_process and args.track: + wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + + # save model + if args.output_dir and args.num_train_epochs > 0: + os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) + time_tensor = torch.tensor([int(time.time())], device=device) + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + repo_name = f"{args.base_model.replace('/', 
'_')}__{args.exp_name}__tldr" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir, repo_id=repo_id) + if args.push_to_hub: + tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}") + + unwrapped: PreTrainedModel = accelerator.unwrap_model(model).policy + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(unwrapped), + safe_serialization=False, + repo_id=repo_id, + ) + if args.push_to_hub: + unwrapped.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}", safe_serialization=False) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args) diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 9b46189..d99717c 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -36,10 +36,10 @@ @dataclass class LabelHParams: - type: str = None + type: Optional[str] = None num_train: int = 92832 num_labels: int = 2 - source: str = None + source: Optional[str] = None # a patch @@ -132,6 +132,7 @@ class Args: # other args base_model: str = "EleutherAI/pythia-160m" + reward_model_path: str = "" """the name of the pretrained model to use""" dropout_layer_keys: List[str] = field( default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] @@ -353,7 +354,13 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): base_config=model_config, hidden_size=model_config.hidden_size, ) - model: PreTrainedModel = ScalarModel(scalar_model_config) + if len(args.reward_model_path) == 0: + model: PreTrainedModel = ScalarModel(scalar_model_config) + else: + model: PreTrainedModel = ScalarModel.from_pretrained( + args.reward_model_path, + trust_remote_code=True, + ) if accelerator.is_main_process: pprint(model_config) if args.optimizer == "adam": @@ -428,24 +435,24 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): f"{train_accuracy=}, {scheduler.get_last_lr()=}, {optimizer.param_groups[0]['lr']=}, {update=}" ) - if args.run_eval: - evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) - for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/split/{split}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/split/{split}: {row['accuracy']}") - for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/batch/{batch}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/batch/{batch}: {row['accuracy']}") - for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/confidence/{confi}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/confidence/{confi}: {row['accuracy']}") - writer.add_scalar("eval/rm/accuracy", evaluate_df["accuracy"].mean(), global_step) - accelerator.print(f"eval/rm/accuracy: {evaluate_df['accuracy'].mean()}") - if accelerator.is_main_process: - os.makedirs(f"eval_tables/{run_name}", exist_ok=True) - evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") - if args.track: - 
wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + if args.run_eval: + evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) + for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/accuracy/split/{split}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/accuracy/split/{split}: {row['accuracy']}") + for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/accuracy/batch/{batch}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/accuracy/batch/{batch}: {row['accuracy']}") + for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/accuracy/confidence/{confi}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/accuracy/confidence/{confi}: {row['accuracy']}") + writer.add_scalar("eval/rm/accuracy", evaluate_df["accuracy"].mean(), global_step) + accelerator.print(f"eval/rm/accuracy: {evaluate_df['accuracy'].mean()}") + if accelerator.is_main_process: + os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") + if args.track: + wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) torch.cuda.empty_cache() norm_dataset = load_dataset(args.task.query_dataset, split="train") @@ -463,7 +470,6 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): predicted_reward = accelerator.gather(predicted_reward) queries = accelerator.gather(queries) reference_responses = accelerator.gather(reference_responses) - accelerator.print(predicted_reward.shape) for i in range(len(predicted_reward)): items["query"].append(tokenizer.decode(queries[i], skip_special_tokens=True)) items["reference_response"].append(tokenizer.decode(reference_responses[i])) diff --git a/poetry.lock b/poetry.lock index 503b091..22d1fcc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -174,8 +174,27 @@ files = [ {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, ] +[[package]] +name = "anyio" +version = "4.2.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"}, + {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"}, +] + [package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] [[package]] name = "appdirs" @@ -607,6 +626,84 @@ files = [ {file = "contextlib2-21.6.0.tar.gz", hash = "sha256:ab1e2bfe1d01d968e1b7e8d9023bc51ef3509bba217bb730cee3827e1ee82869"}, ] +[[package]] +name = "contourpy" +version = "1.2.0" +description 
= "Python library for calculating contours of 2D quadrilateral grids" +optional = false +python-versions = ">=3.9" +files = [ + {file = "contourpy-1.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0274c1cb63625972c0c007ab14dd9ba9e199c36ae1a231ce45d725cbcbfd10a8"}, + {file = "contourpy-1.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ab459a1cbbf18e8698399c595a01f6dcc5c138220ca3ea9e7e6126232d102bb4"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fdd887f17c2f4572ce548461e4f96396681212d858cae7bd52ba3310bc6f00f"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d16edfc3fc09968e09ddffada434b3bf989bf4911535e04eada58469873e28e"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c203f617abc0dde5792beb586f827021069fb6d403d7f4d5c2b543d87edceb9"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b69303ceb2e4d4f146bf82fda78891ef7bcd80c41bf16bfca3d0d7eb545448aa"}, + {file = "contourpy-1.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:884c3f9d42d7218304bc74a8a7693d172685c84bd7ab2bab1ee567b769696df9"}, + {file = "contourpy-1.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4a1b1208102be6e851f20066bf0e7a96b7d48a07c9b0cfe6d0d4545c2f6cadab"}, + {file = "contourpy-1.2.0-cp310-cp310-win32.whl", hash = "sha256:34b9071c040d6fe45d9826cbbe3727d20d83f1b6110d219b83eb0e2a01d79488"}, + {file = "contourpy-1.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:bd2f1ae63998da104f16a8b788f685e55d65760cd1929518fd94cd682bf03e41"}, + {file = "contourpy-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dd10c26b4eadae44783c45ad6655220426f971c61d9b239e6f7b16d5cdaaa727"}, + {file = "contourpy-1.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5c6b28956b7b232ae801406e529ad7b350d3f09a4fde958dfdf3c0520cdde0dd"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebeac59e9e1eb4b84940d076d9f9a6cec0064e241818bcb6e32124cc5c3e377a"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:139d8d2e1c1dd52d78682f505e980f592ba53c9f73bd6be102233e358b401063"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e9dc350fb4c58adc64df3e0703ab076f60aac06e67d48b3848c23647ae4310e"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18fc2b4ed8e4a8fe849d18dce4bd3c7ea637758c6343a1f2bae1e9bd4c9f4686"}, + {file = "contourpy-1.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:16a7380e943a6d52472096cb7ad5264ecee36ed60888e2a3d3814991a0107286"}, + {file = "contourpy-1.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8d8faf05be5ec8e02a4d86f616fc2a0322ff4a4ce26c0f09d9f7fb5330a35c95"}, + {file = "contourpy-1.2.0-cp311-cp311-win32.whl", hash = "sha256:67b7f17679fa62ec82b7e3e611c43a016b887bd64fb933b3ae8638583006c6d6"}, + {file = "contourpy-1.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:99ad97258985328b4f207a5e777c1b44a83bfe7cf1f87b99f9c11d4ee477c4de"}, + {file = "contourpy-1.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:575bcaf957a25d1194903a10bc9f316c136c19f24e0985a2b9b5608bdf5dbfe0"}, + {file = "contourpy-1.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9e6c93b5b2dbcedad20a2f18ec22cae47da0d705d454308063421a3b290d9ea4"}, + {file = 
"contourpy-1.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:464b423bc2a009088f19bdf1f232299e8b6917963e2b7e1d277da5041f33a779"}, + {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:68ce4788b7d93e47f84edd3f1f95acdcd142ae60bc0e5493bfd120683d2d4316"}, + {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d7d1f8871998cdff5d2ff6a087e5e1780139abe2838e85b0b46b7ae6cc25399"}, + {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e739530c662a8d6d42c37c2ed52a6f0932c2d4a3e8c1f90692ad0ce1274abe0"}, + {file = "contourpy-1.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:247b9d16535acaa766d03037d8e8fb20866d054d3c7fbf6fd1f993f11fc60ca0"}, + {file = "contourpy-1.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:461e3ae84cd90b30f8d533f07d87c00379644205b1d33a5ea03381edc4b69431"}, + {file = "contourpy-1.2.0-cp312-cp312-win32.whl", hash = "sha256:1c2559d6cffc94890b0529ea7eeecc20d6fadc1539273aa27faf503eb4656d8f"}, + {file = "contourpy-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:491b1917afdd8638a05b611a56d46587d5a632cabead889a5440f7c638bc6ed9"}, + {file = "contourpy-1.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5fd1810973a375ca0e097dee059c407913ba35723b111df75671a1976efa04bc"}, + {file = "contourpy-1.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:999c71939aad2780f003979b25ac5b8f2df651dac7b38fb8ce6c46ba5abe6ae9"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7caf9b241464c404613512d5594a6e2ff0cc9cb5615c9475cc1d9b514218ae8"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:266270c6f6608340f6c9836a0fb9b367be61dde0c9a9a18d5ece97774105ff3e"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbd50d0a0539ae2e96e537553aff6d02c10ed165ef40c65b0e27e744a0f10af8"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11f8d2554e52f459918f7b8e6aa20ec2a3bce35ce95c1f0ef4ba36fbda306df5"}, + {file = "contourpy-1.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ce96dd400486e80ac7d195b2d800b03e3e6a787e2a522bfb83755938465a819e"}, + {file = "contourpy-1.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6d3364b999c62f539cd403f8123ae426da946e142312a514162adb2addd8d808"}, + {file = "contourpy-1.2.0-cp39-cp39-win32.whl", hash = "sha256:1c88dfb9e0c77612febebb6ac69d44a8d81e3dc60f993215425b62c1161353f4"}, + {file = "contourpy-1.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:78e6ad33cf2e2e80c5dfaaa0beec3d61face0fb650557100ee36db808bfa6843"}, + {file = "contourpy-1.2.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:be16975d94c320432657ad2402f6760990cb640c161ae6da1363051805fa8108"}, + {file = "contourpy-1.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b95a225d4948b26a28c08307a60ac00fb8671b14f2047fc5476613252a129776"}, + {file = "contourpy-1.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0d7e03c0f9a4f90dc18d4e77e9ef4ec7b7bbb437f7f675be8e530d65ae6ef956"}, + {file = "contourpy-1.2.0.tar.gz", hash = "sha256:171f311cb758de7da13fc53af221ae47a5877be5a0843a9fe150818c51ed276a"}, +] + +[package.dependencies] +numpy = ">=1.20,<2.0" + +[package.extras] +bokeh = ["bokeh", "selenium"] +docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] +mypy = 
["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pillow"] +test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] +test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] + +[[package]] +name = "cycler" +version = "0.12.1" +description = "Composable style cycles" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, + {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, +] + +[package.extras] +docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] +tests = ["pytest", "pytest-cov", "pytest-xdist"] + [[package]] name = "datasets" version = "2.12.0" @@ -690,12 +787,12 @@ files = [ [[package]] name = "deepspeed" -version = "0.12.5" +version = "0.12.6" description = "DeepSpeed library" optional = false python-versions = "*" files = [ - {file = "deepspeed-0.12.5.tar.gz", hash = "sha256:7aca1e761f21792b49cbbb6b6ce6ef1cd5fb17d5738835aee3680b0a1c5a8234"}, + {file = "deepspeed-0.12.6.tar.gz", hash = "sha256:69ea07c65ef6414f9cd67746672f1c23b4b629dc14c9177de103ac0c5b2e0ce4"}, ] [package.dependencies] @@ -748,6 +845,17 @@ files = [ {file = "distlib-0.3.6.tar.gz", hash = "sha256:14bad2d9b04d3a36127ac97f30b12a19268f211063d8f8ee4f47108896e11b46"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "dm-tree" version = "0.1.8" @@ -970,6 +1078,71 @@ typing-extensions = ">=4.1.1" all = ["matplotlib"] testing = ["atari-py (==0.2.5)", "clu", "einops", "gym (==0.18.3)", "jaxlib", "jraph (>=0.0.6dev0)", "ml-collections", "mypy", "nbstripout", "opencv-python", "pytest", "pytest-cov", "pytest-custom-exit-code", "pytest-xdist (==1.34.0)", "pytype", "sentencepiece", "tensorflow", "tensorflow-datasets", "tensorflow-text (>=2.11.0)", "torch"] +[[package]] +name = "fonttools" +version = "4.47.0" +description = "Tools to manipulate font files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fonttools-4.47.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2d2404107626f97a221dc1a65b05396d2bb2ce38e435f64f26ed2369f68675d9"}, + {file = "fonttools-4.47.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c01f409be619a9a0f5590389e37ccb58b47264939f0e8d58bfa1f3ba07d22671"}, + {file = "fonttools-4.47.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d986b66ff722ef675b7ee22fbe5947a41f60a61a4da15579d5e276d897fbc7fa"}, + {file = "fonttools-4.47.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8acf6dd0434b211b3bd30d572d9e019831aae17a54016629fa8224783b22df8"}, + {file = "fonttools-4.47.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:495369c660e0c27233e3c572269cbe520f7f4978be675f990f4005937337d391"}, + {file = "fonttools-4.47.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c59227d7ba5b232281c26ae04fac2c73a79ad0e236bca5c44aae904a18f14faf"}, + {file = "fonttools-4.47.0-cp310-cp310-win32.whl", hash = "sha256:59a6c8b71a245800e923cb684a2dc0eac19c56493e2f896218fcf2571ed28984"}, + {file = "fonttools-4.47.0-cp310-cp310-win_amd64.whl", hash = 
"sha256:52c82df66201f3a90db438d9d7b337c7c98139de598d0728fb99dab9fd0495ca"}, + {file = "fonttools-4.47.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:854421e328d47d70aa5abceacbe8eef231961b162c71cbe7ff3f47e235e2e5c5"}, + {file = "fonttools-4.47.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:511482df31cfea9f697930f61520f6541185fa5eeba2fa760fe72e8eee5af88b"}, + {file = "fonttools-4.47.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0e2c88c8c985b7b9a7efcd06511fb0a1fe3ddd9a6cd2895ef1dbf9059719d7"}, + {file = "fonttools-4.47.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7a0a8848726956e9d9fb18c977a279013daadf0cbb6725d2015a6dd57527992"}, + {file = "fonttools-4.47.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e869da810ae35afb3019baa0d0306cdbab4760a54909c89ad8904fa629991812"}, + {file = "fonttools-4.47.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dd23848f877c3754f53a4903fb7a593ed100924f9b4bff7d5a4e2e8a7001ae11"}, + {file = "fonttools-4.47.0-cp311-cp311-win32.whl", hash = "sha256:bf1810635c00f7c45d93085611c995fc130009cec5abdc35b327156aa191f982"}, + {file = "fonttools-4.47.0-cp311-cp311-win_amd64.whl", hash = "sha256:61df4dee5d38ab65b26da8efd62d859a1eef7a34dcbc331299a28e24d04c59a7"}, + {file = "fonttools-4.47.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:e3f4d61f3a8195eac784f1d0c16c0a3105382c1b9a74d99ac4ba421da39a8826"}, + {file = "fonttools-4.47.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:174995f7b057e799355b393e97f4f93ef1f2197cbfa945e988d49b2a09ecbce8"}, + {file = "fonttools-4.47.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea592e6a09b71cb7a7661dd93ac0b877a6228e2d677ebacbad0a4d118494c86d"}, + {file = "fonttools-4.47.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40bdbe90b33897d9cc4a39f8e415b0fcdeae4c40a99374b8a4982f127ff5c767"}, + {file = "fonttools-4.47.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:843509ae9b93db5aaf1a6302085e30bddc1111d31e11d724584818f5b698f500"}, + {file = "fonttools-4.47.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9acfa1cdc479e0dde528b61423855913d949a7f7fe09e276228298fef4589540"}, + {file = "fonttools-4.47.0-cp312-cp312-win32.whl", hash = "sha256:66c92ec7f95fd9732550ebedefcd190a8d81beaa97e89d523a0d17198a8bda4d"}, + {file = "fonttools-4.47.0-cp312-cp312-win_amd64.whl", hash = "sha256:e8fa20748de55d0021f83754b371432dca0439e02847962fc4c42a0e444c2d78"}, + {file = "fonttools-4.47.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c75e19971209fbbce891ebfd1b10c37320a5a28e8d438861c21d35305aedb81c"}, + {file = "fonttools-4.47.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e79f1a3970d25f692bbb8c8c2637e621a66c0d60c109ab48d4a160f50856deff"}, + {file = "fonttools-4.47.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:562681188c62c024fe2c611b32e08b8de2afa00c0c4e72bed47c47c318e16d5c"}, + {file = "fonttools-4.47.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a77a60315c33393b2bd29d538d1ef026060a63d3a49a9233b779261bad9c3f71"}, + {file = "fonttools-4.47.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b4fabb8cc9422efae1a925160083fdcbab8fdc96a8483441eb7457235df625bd"}, + {file = "fonttools-4.47.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2a78dba8c2a1e9d53a0fb5382979f024200dc86adc46a56cbb668a2249862fda"}, + {file = "fonttools-4.47.0-cp38-cp38-win32.whl", hash = 
"sha256:e6b968543fde4119231c12c2a953dcf83349590ca631ba8216a8edf9cd4d36a9"}, + {file = "fonttools-4.47.0-cp38-cp38-win_amd64.whl", hash = "sha256:4a9a51745c0439516d947480d4d884fa18bd1458e05b829e482b9269afa655bc"}, + {file = "fonttools-4.47.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:62d8ddb058b8e87018e5dc26f3258e2c30daad4c87262dfeb0e2617dd84750e6"}, + {file = "fonttools-4.47.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5dde0eab40faaa5476133123f6a622a1cc3ac9b7af45d65690870620323308b4"}, + {file = "fonttools-4.47.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4da089f6dfdb822293bde576916492cd708c37c2501c3651adde39804630538"}, + {file = "fonttools-4.47.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:253bb46bab970e8aae254cebf2ae3db98a4ef6bd034707aa68a239027d2b198d"}, + {file = "fonttools-4.47.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1193fb090061efa2f9e2d8d743ae9850c77b66746a3b32792324cdce65784154"}, + {file = "fonttools-4.47.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:084511482dd265bce6dca24c509894062f0117e4e6869384d853f46c0e6d43be"}, + {file = "fonttools-4.47.0-cp39-cp39-win32.whl", hash = "sha256:97620c4af36e4c849e52661492e31dc36916df12571cb900d16960ab8e92a980"}, + {file = "fonttools-4.47.0-cp39-cp39-win_amd64.whl", hash = "sha256:e77bdf52185bdaf63d39f3e1ac3212e6cfa3ab07d509b94557a8902ce9c13c82"}, + {file = "fonttools-4.47.0-py3-none-any.whl", hash = "sha256:d6477ba902dd2d7adda7f0fd3bfaeb92885d45993c9e1928c9f28fc3961415f7"}, + {file = "fonttools-4.47.0.tar.gz", hash = "sha256:ec13a10715eef0e031858c1c23bfaee6cba02b97558e4a7bfa089dba4a8c2ebf"}, +] + +[package.extras] +all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +graphite = ["lz4 (>=1.7.4.2)"] +interpolatable = ["munkres", "pycairo", "scipy"] +lxml = ["lxml (>=4.0,<5)"] +pathops = ["skia-pathops (>=0.5.0)"] +plot = ["matplotlib"] +repacker = ["uharfbuzz (>=0.23.0)"] +symfont = ["sympy"] +type1 = ["xattr"] +ufo = ["fs (>=2.2.0,<3)"] +unicode = ["unicodedata2 (>=15.1.0)"] +woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] + [[package]] name = "frozenlist" version = "1.3.3" @@ -1233,6 +1406,17 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.54.2)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + [[package]] name = "hjson" version = "3.1.0" @@ -1244,6 +1428,51 @@ files = [ {file = "hjson-3.1.0.tar.gz", hash = "sha256:55af475a27cf83a7969c808399d7bccdec8fb836a07ddbd574587593b9cdcf75"}, ] +[[package]] +name = "httpcore" +version = "1.0.2" +description = "A minimal low-level HTTP client." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] + +[[package]] +name = "httpx" +version = "0.26.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"}, + {file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "huggingface-hub" version = "0.19.4" @@ -1568,6 +1797,119 @@ traitlets = ">=5.3" docs = ["myst-parser", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] +[[package]] +name = "kiwisolver" +version = "1.4.5" +description = "A fast implementation of the Cassowary constraint solver" +optional = false +python-versions = ">=3.7" +files = [ + {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, + {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, + {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = 
"sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"}, + {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"}, + {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"}, + {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"}, + {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"}, + {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"}, + {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"}, + {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"}, + {file = 
"kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"}, + {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"}, + {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-win32.whl", hash = 
"sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-win_amd64.whl", hash = "sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a"}, + {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71"}, + {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93"}, + {file = "kiwisolver-1.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250"}, + {file = "kiwisolver-1.4.5-cp38-cp38-win32.whl", hash = "sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e"}, + {file = "kiwisolver-1.4.5-cp38-cp38-win_amd64.whl", hash = "sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced"}, + {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d"}, + {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9"}, + {file = "kiwisolver-1.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77"}, + {file = "kiwisolver-1.4.5-cp39-cp39-win32.whl", hash = "sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f"}, + {file = "kiwisolver-1.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"}, + {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, +] + [[package]] name = "markdown" version = "3.4.3" @@ -1678,6 +2020,55 @@ files = [ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] +[[package]] +name = "matplotlib" +version = "3.8.2" +description = "Python plotting package" +optional = false +python-versions = ">=3.9" +files = [ + {file = "matplotlib-3.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:09796f89fb71a0c0e1e2f4bdaf63fb2cefc84446bb963ecdeb40dfee7dfa98c7"}, + {file = "matplotlib-3.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f9c6976748a25e8b9be51ea028df49b8e561eed7809146da7a47dbecebab367"}, + {file = "matplotlib-3.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b78e4f2cedf303869b782071b55fdde5987fda3038e9d09e58c91cc261b5ad18"}, + {file = "matplotlib-3.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e208f46cf6576a7624195aa047cb344a7f802e113bb1a06cfd4bee431de5e31"}, + {file = "matplotlib-3.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:46a569130ff53798ea5f50afce7406e91fdc471ca1e0e26ba976a8c734c9427a"}, + {file = "matplotlib-3.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:830f00640c965c5b7f6bc32f0d4ce0c36dfe0379f7dd65b07a00c801713ec40a"}, + {file = "matplotlib-3.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d86593ccf546223eb75a39b44c32788e6f6440d13cfc4750c1c15d0fcb850b63"}, + {file = "matplotlib-3.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a5430836811b7652991939012f43d2808a2db9b64ee240387e8c43e2e5578c8"}, + {file = "matplotlib-3.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9576723858a78751d5aacd2497b8aef29ffea6d1c95981505877f7ac28215c6"}, + {file = "matplotlib-3.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ba9cbd8ac6cf422f3102622b20f8552d601bf8837e49a3afed188d560152788"}, + {file = "matplotlib-3.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:03f9d160a29e0b65c0790bb07f4f45d6a181b1ac33eb1bb0dd225986450148f0"}, + {file = "matplotlib-3.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:3773002da767f0a9323ba1a9b9b5d00d6257dbd2a93107233167cfb581f64717"}, + {file = "matplotlib-3.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:4c318c1e95e2f5926fba326f68177dee364aa791d6df022ceb91b8221bd0a627"}, + {file = "matplotlib-3.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:091275d18d942cf1ee9609c830a1bc36610607d8223b1b981c37d5c9fc3e46a4"}, + {file = "matplotlib-3.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b0f3b8ea0e99e233a4bcc44590f01604840d833c280ebb8fe5554fd3e6cfe8d"}, + {file = "matplotlib-3.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b1704a530395aaf73912be741c04d181f82ca78084fbd80bc737be04848331"}, + {file = "matplotlib-3.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:533b0e3b0c6768eef8cbe4b583731ce25a91ab54a22f830db2b031e83cca9213"}, + {file = "matplotlib-3.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:0f4fc5d72b75e2c18e55eb32292659cf731d9d5b312a6eb036506304f4675630"}, + {file = "matplotlib-3.8.2-cp39-cp39-macosx_10_12_x86_64.whl", 
hash = "sha256:deaed9ad4da0b1aea77fe0aa0cebb9ef611c70b3177be936a95e5d01fa05094f"}, + {file = "matplotlib-3.8.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:172f4d0fbac3383d39164c6caafd3255ce6fa58f08fc392513a0b1d3b89c4f89"}, + {file = "matplotlib-3.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7d36c2209d9136cd8e02fab1c0ddc185ce79bc914c45054a9f514e44c787917"}, + {file = "matplotlib-3.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5864bdd7da445e4e5e011b199bb67168cdad10b501750367c496420f2ad00843"}, + {file = "matplotlib-3.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ef8345b48e95cee45ff25192ed1f4857273117917a4dcd48e3905619bcd9c9b8"}, + {file = "matplotlib-3.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:7c48d9e221b637c017232e3760ed30b4e8d5dfd081daf327e829bf2a72c731b4"}, + {file = "matplotlib-3.8.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:aa11b3c6928a1e496c1a79917d51d4cd5d04f8a2e75f21df4949eeefdf697f4b"}, + {file = "matplotlib-3.8.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1095fecf99eeb7384dabad4bf44b965f929a5f6079654b681193edf7169ec20"}, + {file = "matplotlib-3.8.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:bddfb1db89bfaa855912261c805bd0e10218923cc262b9159a49c29a7a1c1afa"}, + {file = "matplotlib-3.8.2.tar.gz", hash = "sha256:01a978b871b881ee76017152f1f1a0cbf6bd5f7b8ff8c96df0df1bd57d8755a1"}, +] + +[package.dependencies] +contourpy = ">=1.0.1" +cycler = ">=0.10" +fonttools = ">=4.22.0" +importlib-resources = {version = ">=3.2.0", markers = "python_version < \"3.10\""} +kiwisolver = ">=1.3.1" +numpy = ">=1.21,<2" +packaging = ">=20.0" +pillow = ">=8" +pyparsing = ">=2.3.1" +python-dateutil = ">=2.7" + [[package]] name = "matplotlib-inline" version = "0.1.6" @@ -2297,6 +2688,29 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "openai" +version = "1.6.1" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.6.1-py3-none-any.whl", hash = "sha256:bc9f774838d67ac29fb24cdeb2d58faf57de8b311085dcd1348f7aa02a96c7ee"}, + {file = "openai-1.6.1.tar.gz", hash = "sha256:d553ca9dbf9486b08e75b09e8671e4f638462aaadccfced632bf490fc3d75fa2"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "opt-einsum" version = "3.3.0" @@ -2500,6 +2914,73 @@ files = [ {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, ] +[[package]] +name = "pillow" +version = "10.1.0" +description = "Python Imaging Library (Fork)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "Pillow-10.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1ab05f3db77e98f93964697c8efc49c7954b08dd61cff526b7f2531a22410106"}, + {file = "Pillow-10.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6932a7652464746fcb484f7fc3618e6503d2066d853f68a4bd97193a3996e273"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f63b5a68daedc54c7c3464508d8c12075e56dcfbd42f8c1bf40169061ae666"}, + {file = 
"Pillow-10.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0949b55eb607898e28eaccb525ab104b2d86542a85c74baf3a6dc24002edec2"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ae88931f93214777c7a3aa0a8f92a683f83ecde27f65a45f95f22d289a69e593"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b0eb01ca85b2361b09480784a7931fc648ed8b7836f01fb9241141b968feb1db"}, + {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d27b5997bdd2eb9fb199982bb7eb6164db0426904020dc38c10203187ae2ff2f"}, + {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7df5608bc38bd37ef585ae9c38c9cd46d7c81498f086915b0f97255ea60c2818"}, + {file = "Pillow-10.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:41f67248d92a5e0a2076d3517d8d4b1e41a97e2df10eb8f93106c89107f38b57"}, + {file = "Pillow-10.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1fb29c07478e6c06a46b867e43b0bcdb241b44cc52be9bc25ce5944eed4648e7"}, + {file = "Pillow-10.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2cdc65a46e74514ce742c2013cd4a2d12e8553e3a2563c64879f7c7e4d28bce7"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50d08cd0a2ecd2a8657bd3d82c71efd5a58edb04d9308185d66c3a5a5bed9610"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:062a1610e3bc258bff2328ec43f34244fcec972ee0717200cb1425214fe5b839"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:61f1a9d247317fa08a308daaa8ee7b3f760ab1809ca2da14ecc88ae4257d6172"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a646e48de237d860c36e0db37ecaecaa3619e6f3e9d5319e527ccbc8151df061"}, + {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:47e5bf85b80abc03be7455c95b6d6e4896a62f6541c1f2ce77a7d2bb832af262"}, + {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a92386125e9ee90381c3369f57a2a50fa9e6aa8b1cf1d9c4b200d41a7dd8e992"}, + {file = "Pillow-10.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f7c276c05a9767e877a0b4c5050c8bee6a6d960d7f0c11ebda6b99746068c2a"}, + {file = "Pillow-10.1.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:a89b8312d51715b510a4fe9fc13686283f376cfd5abca8cd1c65e4c76e21081b"}, + {file = "Pillow-10.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:00f438bb841382b15d7deb9a05cc946ee0f2c352653c7aa659e75e592f6fa17d"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d929a19f5469b3f4df33a3df2983db070ebb2088a1e145e18facbc28cae5b27"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a92109192b360634a4489c0c756364c0c3a2992906752165ecb50544c251312"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0248f86b3ea061e67817c47ecbe82c23f9dd5d5226200eb9090b3873d3ca32de"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9882a7451c680c12f232a422730f986a1fcd808da0fd428f08b671237237d651"}, + {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1c3ac5423c8c1da5928aa12c6e258921956757d976405e9467c5f39d1d577a4b"}, + {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:806abdd8249ba3953c33742506fe414880bad78ac25cc9a9b1c6ae97bedd573f"}, + {file = "Pillow-10.1.0-cp312-cp312-win_amd64.whl", hash = 
"sha256:eaed6977fa73408b7b8a24e8b14e59e1668cfc0f4c40193ea7ced8e210adf996"}, + {file = "Pillow-10.1.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:fe1e26e1ffc38be097f0ba1d0d07fcade2bcfd1d023cda5b29935ae8052bd793"}, + {file = "Pillow-10.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7e3daa202beb61821c06d2517428e8e7c1aab08943e92ec9e5755c2fc9ba5e"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24fadc71218ad2b8ffe437b54876c9382b4a29e030a05a9879f615091f42ffc2"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa1d323703cfdac2036af05191b969b910d8f115cf53093125e4058f62012c9a"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:912e3812a1dbbc834da2b32299b124b5ddcb664ed354916fd1ed6f193f0e2d01"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7dbaa3c7de82ef37e7708521be41db5565004258ca76945ad74a8e998c30af8d"}, + {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9d7bc666bd8c5a4225e7ac71f2f9d12466ec555e89092728ea0f5c0c2422ea80"}, + {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:baada14941c83079bf84c037e2d8b7506ce201e92e3d2fa0d1303507a8538212"}, + {file = "Pillow-10.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:2ef6721c97894a7aa77723740a09547197533146fba8355e86d6d9a4a1056b14"}, + {file = "Pillow-10.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0a026c188be3b443916179f5d04548092e253beb0c3e2ee0a4e2cdad72f66099"}, + {file = "Pillow-10.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:04f6f6149f266a100374ca3cc368b67fb27c4af9f1cc8cb6306d849dcdf12616"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb40c011447712d2e19cc261c82655f75f32cb724788df315ed992a4d65696bb"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a8413794b4ad9719346cd9306118450b7b00d9a15846451549314a58ac42219"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c9aeea7b63edb7884b031a35305629a7593272b54f429a9869a4f63a1bf04c34"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b4005fee46ed9be0b8fb42be0c20e79411533d1fd58edabebc0dd24626882cfd"}, + {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4d0152565c6aa6ebbfb1e5d8624140a440f2b99bf7afaafbdbf6430426497f28"}, + {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d921bc90b1defa55c9917ca6b6b71430e4286fc9e44c55ead78ca1a9f9eba5f2"}, + {file = "Pillow-10.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cfe96560c6ce2f4c07d6647af2d0f3c54cc33289894ebd88cfbb3bcd5391e256"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:937bdc5a7f5343d1c97dc98149a0be7eb9704e937fe3dc7140e229ae4fc572a7"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c25762197144e211efb5f4e8ad656f36c8d214d390585d1d21281f46d556ba"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:afc8eef765d948543a4775f00b7b8c079b3321d6b675dde0d02afa2ee23000b4"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:883f216eac8712b83a63f41b76ddfb7b2afab1b74abbb413c5df6680f071a6b9"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b920e4d028f6442bea9a75b7491c063f0b9a3972520731ed26c83e254302eb1e"}, + {file = 
"Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c41d960babf951e01a49c9746f92c5a7e0d939d1652d7ba30f6b3090f27e412"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1fafabe50a6977ac70dfe829b2d5735fd54e190ab55259ec8aea4aaea412fa0b"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3b834f4b16173e5b92ab6566f0473bfb09f939ba14b23b8da1f54fa63e4b623f"}, + {file = "Pillow-10.1.0.tar.gz", hash = "sha256:e6bf8de6c36ed96c86ea3b6e1d5273c53f46ef518a062464cd7ef5dd2cf92e38"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] + [[package]] name = "platformdirs" version = "3.5.3" @@ -2880,6 +3361,20 @@ files = [ {file = "pynvml-11.5.0.tar.gz", hash = "sha256:d027b21b95b1088b9fc278117f9f61b7c67f8e33a787e9f83f735f0f71ac32d0"}, ] +[[package]] +name = "pyparsing" +version = "3.1.1" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.1-py3-none-any.whl", hash = "sha256:32c7c0b711493c72ff18a981d24f28aaf9c1fb7ed5e9667c9e84e3db623bdbfb"}, + {file = "pyparsing-3.1.1.tar.gz", hash = "sha256:ede28a1a32462f5a9705e07aea48001a08f7cf81a021585011deba701581a0db"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pytest" version = "7.4.0" @@ -3272,7 +3767,6 @@ files = [ [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" -typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -3580,6 +4074,17 @@ files = [ {file = "smmap-5.0.0.tar.gz", hash = "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"}, ] +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + [[package]] name = "stack-data" version = "0.6.2" @@ -3953,13 +4458,13 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] [[package]] name = "transformers" -version = "4.36.1" +version = "4.36.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.36.1-py3-none-any.whl", hash = "sha256:0e309d03634885f02d46801ec4f2c3fc1d614a5b9ebde608181f3e842bac53b8"}, - {file = "transformers-4.36.1.tar.gz", hash = "sha256:28e55952d9bed68f06cf45a3d29cc480679b528afe944e68f8cf6c799e428759"}, + {file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"}, + {file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"}, ] [package.dependencies] @@ -4046,13 +4551,13 @@ tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "typing-extensions" -version = "4.6.3" -description = "Backported and Experimental Type Hints 
for Python 3.7+" +version = "4.9.0" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"}, - {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"}, + {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, + {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, ] [[package]] @@ -4142,7 +4647,6 @@ docker-pycreds = ">=0.4.0" GitPython = ">=1.0.0,<3.1.29 || >3.1.29" pathtools = "*" protobuf = [ - {version = ">=3.12.0,<4.21.0 || >4.21.0,<5", markers = "python_version < \"3.9\" and sys_platform == \"linux\""}, {version = ">=3.15.0,<4.21.0 || >4.21.0,<5", markers = "python_version == \"3.9\" and sys_platform == \"linux\""}, {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""}, ] @@ -4524,5 +5028,5 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" -python-versions = "^3.8" -content-hash = "cc56f3606c4e024c6a1ceafe6d1544c50f6793978cfe1ccd5723fcc3fe82b10c" +python-versions = ">=3.9,<4.0" +content-hash = "2c47564e41c9c56b5fb140556847ffdb983508d8620f6ee4bcd7fd0b6904e599" diff --git a/pyproject.toml b/pyproject.toml index 1b23ff3..dfaa824 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ readme = "README.md" packages = [{include = "lm_human_preference_details"}] [tool.poetry.dependencies] -python = "^3.8" +python = ">=3.9,<4.0" torch = "^2.1.2" tyro = "^0.6.0" datasets = "^2.12.0" @@ -15,7 +15,7 @@ wandb = "^0.15.4" nvitop = "^1.1.2" ftfy = "^6.1.1" rich = "^13.4.2" -transformers = "^4.36.1" +transformers = "^4.36.2" tensorboard = "^2.13.0" accelerate = "^0.25.0" jax = "0.4.8" @@ -27,11 +27,13 @@ einops = "^0.6.1" black = "^23.7.0" clu = "^0.0.9" tabulate = "^0.9.0" -deepspeed = "^0.12.5" +deepspeed = "^0.12.6" evaluate = "^0.4.1" nltk = "^3.8.1" rouge-score = "^0.1.2" huggingface-hub = "^0.19.4" +matplotlib = "^3.8.2" +openai = "^1.6.1" [tool.poetry.group.dev.dependencies] From 537abaf95f177f233df5dd13cde297feb1cef8ad Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 Jan 2024 00:24:26 +0000 Subject: [PATCH 53/62] a regression in transformers https://github.com/huggingface/transformers/issues/28316 --- poetry.lock | 212 ++++++++++++++++++------------------------------- pyproject.toml | 2 +- 2 files changed, 77 insertions(+), 137 deletions(-) diff --git a/poetry.lock b/poetry.lock index 22d1fcc..ac490fb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4213,117 +4213,56 @@ tests = ["pytest", "pytest-cov"] [[package]] name = "tokenizers" -version = "0.15.0" -description = "" +version = "0.13.3" +description = "Fast and Customizable Tokenizers" optional = false -python-versions = ">=3.7" +python-versions = "*" files = [ - {file = "tokenizers-0.15.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cd3cd0299aaa312cd2988957598f80becd04d5a07338741eca076057a2b37d6e"}, - {file = "tokenizers-0.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a922c492c721744ee175f15b91704be2d305569d25f0547c77cd6c9f210f9dc"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = 
"sha256:331dd786d02fc38698f835fff61c99480f98b73ce75a4c65bd110c9af5e4609a"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88dd0961c437d413ab027f8b115350c121d49902cfbadf08bb8f634b15fa1814"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6fdcc55339df7761cd52e1fbe8185d3b3963bc9e3f3545faa6c84f9e8818259a"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1480b0051d8ab5408e8e4db2dc832f7082ea24aa0722c427bde2418c6f3bd07"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9855e6c258918f9cf62792d4f6ddfa6c56dccd8c8118640f867f6393ecaf8bd7"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de9529fe75efcd54ba8d516aa725e1851df9199f0669b665c55e90df08f5af86"}, - {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8edcc90a36eab0705fe9121d6c77c6e42eeef25c7399864fd57dfb27173060bf"}, - {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ae17884aafb3e94f34fb7cfedc29054f5f54e142475ebf8a265a4e388fee3f8b"}, - {file = "tokenizers-0.15.0-cp310-none-win32.whl", hash = "sha256:9a3241acdc9b44cff6e95c4a55b9be943ef3658f8edb3686034d353734adba05"}, - {file = "tokenizers-0.15.0-cp310-none-win_amd64.whl", hash = "sha256:4b31807cb393d6ea31926b307911c89a1209d5e27629aa79553d1599c8ffdefe"}, - {file = "tokenizers-0.15.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:af7e9be8c05d30bb137b9fd20f9d99354816599e5fd3d58a4b1e28ba3b36171f"}, - {file = "tokenizers-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c3d7343fa562ea29661783344a2d83662db0d3d17a6fa6a403cac8e512d2d9fd"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:32371008788aeeb0309a9244809a23e4c0259625e6b74a103700f6421373f395"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca9db64c7c9954fbae698884c5bb089764edc549731e5f9b7fa1dd4e4d78d77f"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dbed5944c31195514669cf6381a0d8d47f164943000d10f93d6d02f0d45c25e0"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aab16c4a26d351d63e965b0c792f5da7227a37b69a6dc6d922ff70aa595b1b0c"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c2b60b12fdd310bf85ce5d7d3f823456b9b65eed30f5438dd7761879c495983"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0344d6602740e44054a9e5bbe9775a5e149c4dddaff15959bb07dcce95a5a859"}, - {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4525f6997d81d9b6d9140088f4f5131f6627e4c960c2c87d0695ae7304233fc3"}, - {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:65975094fef8cc68919644936764efd2ce98cf1bacbe8db2687155d2b0625bee"}, - {file = "tokenizers-0.15.0-cp311-none-win32.whl", hash = "sha256:ff5d2159c5d93015f5a4542aac6c315506df31853123aa39042672031768c301"}, - {file = "tokenizers-0.15.0-cp311-none-win_amd64.whl", hash = "sha256:2dd681b53cf615e60a31a115a3fda3980e543d25ca183797f797a6c3600788a3"}, - {file = "tokenizers-0.15.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = 
"sha256:c9cce6ee149a3d703f86877bc2a6d997e34874b2d5a2d7839e36b2273f31d3d9"}, - {file = "tokenizers-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a0a94bc3370e6f1cc8a07a8ae867ce13b7c1b4291432a773931a61f256d44ea"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:309cfcccfc7e502cb1f1de2c9c1c94680082a65bfd3a912d5a5b2c90c677eb60"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8413e994dd7d875ab13009127fc85633916c71213917daf64962bafd488f15dc"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d0ebf9430f901dbdc3dcb06b493ff24a3644c9f88c08e6a1d6d0ae2228b9b818"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10361e9c7864b22dd791ec5126327f6c9292fb1d23481d4895780688d5e298ac"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:babe42635b8a604c594bdc56d205755f73414fce17ba8479d142a963a6c25cbc"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3768829861e964c7a4556f5f23307fce6a23872c2ebf030eb9822dbbbf7e9b2a"}, - {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9c91588a630adc88065e1c03ac6831e3e2112558869b9ebcb2b8afd8a14c944d"}, - {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:77606994e793ca54ecf3a3619adc8a906a28ca223d9354b38df41cb8766a0ed6"}, - {file = "tokenizers-0.15.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:6fe143939f3b596681922b2df12a591a5b010e7dcfbee2202482cd0c1c2f2459"}, - {file = "tokenizers-0.15.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:b7bee0f1795e3e3561e9a557061b1539e5255b8221e3f928f58100282407e090"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5d37e7f4439b4c46192ab4f2ff38ab815e4420f153caa13dec9272ef14403d34"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caadf255cf7f951b38d10097836d1f3bcff4aeaaffadfdf748bab780bf5bff95"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05accb9162bf711a941b1460b743d62fec61c160daf25e53c5eea52c74d77814"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26a2ef890740127cb115ee5260878f4a677e36a12831795fd7e85887c53b430b"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e54c5f26df14913620046b33e822cb3bcd091a332a55230c0e63cc77135e2169"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669b8ed653a578bcff919566631156f5da3aab84c66f3c0b11a6281e8b4731c7"}, - {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0ea480d943297df26f06f508dab6e012b07f42bf3dffdd36e70799368a5f5229"}, - {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bc80a0a565ebfc7cd89de7dd581da8c2b3238addfca6280572d27d763f135f2f"}, - {file = "tokenizers-0.15.0-cp37-none-win32.whl", hash = "sha256:cdd945e678bbdf4517d5d8de66578a5030aeefecdb46f5320b034de9cad8d4dd"}, - {file = "tokenizers-0.15.0-cp37-none-win_amd64.whl", hash = "sha256:1ab96ab7dc706e002c32b2ea211a94c1c04b4f4de48354728c3a6e22401af322"}, - {file = "tokenizers-0.15.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = 
"sha256:f21c9eb71c9a671e2a42f18b456a3d118e50c7f0fc4dd9fa8f4eb727fea529bf"}, - {file = "tokenizers-0.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a5f4543a35889679fc3052086e69e81880b2a5a28ff2a52c5a604be94b77a3f"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f8aa81afec893e952bd39692b2d9ef60575ed8c86fce1fd876a06d2e73e82dca"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1574a5a4af22c3def93fe8fe4adcc90a39bf5797ed01686a4c46d1c3bc677d2f"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c7982fd0ec9e9122d03b209dac48cebfea3de0479335100ef379a9a959b9a5a"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d16b647032df2ce2c1f9097236e046ea9fedd969b25637b9d5d734d78aa53b"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b3cdf29e6f9653da330515dc8fa414be5a93aae79e57f8acc50d4028dd843edf"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7286f3df10de840867372e3e64b99ef58c677210e3ceb653cd0e740a5c53fe78"}, - {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aabc83028baa5a36ce7a94e7659250f0309c47fa4a639e5c2c38e6d5ea0de564"}, - {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:72f78b0e0e276b1fc14a672fa73f3acca034ba8db4e782124a2996734a9ba9cf"}, - {file = "tokenizers-0.15.0-cp38-none-win32.whl", hash = "sha256:9680b0ecc26e7e42f16680c1aa62e924d58d1c2dd992707081cc10a374896ea2"}, - {file = "tokenizers-0.15.0-cp38-none-win_amd64.whl", hash = "sha256:f17cbd88dab695911cbdd385a5a7e3709cc61dff982351f5d1b5939f074a2466"}, - {file = "tokenizers-0.15.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:3661862df7382c5eb23ac4fbf7c75e69b02dc4f5784e4c5a734db406b5b24596"}, - {file = "tokenizers-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c3045d191dad49647f5a5039738ecf1c77087945c7a295f7bcf051c37067e883"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9fcaad9ab0801f14457d7c820d9f246b5ab590c407fc6b073819b1573097aa7"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79f17027f24fe9485701c8dbb269b9c713954ec3bdc1e7075a66086c0c0cd3c"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:01a3aa332abc4bee7640563949fcfedca4de8f52691b3b70f2fc6ca71bfc0f4e"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05b83896a893cdfedad8785250daa3ba9f0504848323471524d4783d7291661e"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbbf2489fcf25d809731ba2744ff278dd07d9eb3f8b7482726bd6cae607073a4"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab806ad521a5e9de38078b7add97589c313915f6f5fec6b2f9f289d14d607bd6"}, - {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4a522612d5c88a41563e3463226af64e2fa00629f65cdcc501d1995dd25d23f5"}, - {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e58a38c4e6075810bdfb861d9c005236a72a152ebc7005941cc90d1bbf16aca9"}, - {file = "tokenizers-0.15.0-cp39-none-win32.whl", hash = 
"sha256:b8034f1041fd2bd2b84ff9f4dc4ae2e1c3b71606820a9cd5c562ebd291a396d1"}, - {file = "tokenizers-0.15.0-cp39-none-win_amd64.whl", hash = "sha256:edde9aa964145d528d0e0dbf14f244b8a85ebf276fb76869bc02e2530fa37a96"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:309445d10d442b7521b98083dc9f0b5df14eca69dbbfebeb98d781ee2cef5d30"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d3125a6499226d4d48efc54f7498886b94c418e93a205b673bc59364eecf0804"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ed56ddf0d54877bb9c6d885177db79b41576e61b5ef6defeb579dcb803c04ad5"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b22cd714706cc5b18992a232b023f736e539495f5cc61d2d28d176e55046f6c"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fac2719b1e9bc8e8e7f6599b99d0a8e24f33d023eb8ef644c0366a596f0aa926"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:85ddae17570ec7e5bfaf51ffa78d044f444a8693e1316e1087ee6150596897ee"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:76f1bed992e396bf6f83e3df97b64ff47885e45e8365f8983afed8556a0bc51f"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3bb0f4df6dce41a1c7482087b60d18c372ef4463cb99aa8195100fcd41e0fd64"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:22c27672c27a059a5f39ff4e49feed8c7f2e1525577c8a7e3978bd428eb5869d"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78104f5d035c9991f92831fc0efe9e64a05d4032194f2a69f67aaa05a4d75bbb"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40b73dc19d82c3e3ffb40abdaacca8fbc95eeb26c66b7f9f860aebc07a73998"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d801d1368188c74552cd779b1286e67cb9fd96f4c57a9f9a2a09b6def9e1ab37"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82641ffb13a4da1293fcc9f437d457647e60ed0385a9216cd135953778b3f0a1"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:160f9d1810f2c18fffa94aa98bf17632f6bd2dabc67fcb01a698ca80c37d52ee"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:8d7d6eea831ed435fdeeb9bcd26476226401d7309d115a710c65da4088841948"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f6456bec6c557d63d8ec0023758c32f589e1889ed03c055702e84ce275488bed"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eef39a502fad3bf104b9e1906b4fb0cee20e44e755e51df9a98f8922c3bf6d4"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1e4664c5b797e093c19b794bbecc19d2367e782b4a577d8b7c1821db5dc150d"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ca003fb5f3995ff5cf676db6681b8ea5d54d3b30bea36af1120e78ee1a4a4cdf"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7f17363141eb0c53752c89e10650b85ef059a52765d0802ba9613dbd2d21d425"}, - {file = 
"tokenizers-0.15.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:8a765db05581c7d7e1280170f2888cda351760d196cc059c37ea96f121125799"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2a0dd641a72604486cd7302dd8f87a12c8a9b45e1755e47d2682733f097c1af5"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0a1a3c973e4dc97797fc19e9f11546c95278ffc55c4492acb742f69e035490bc"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4fab75642aae4e604e729d6f78e0addb9d7e7d49e28c8f4d16b24da278e5263"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65f80be77f6327a86d8fd35a4467adcfe6174c159b4ab52a1a8dd4c6f2d7d9e1"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a8da7533dbe66b88afd430c56a2f2ce1fd82e2681868f857da38eeb3191d7498"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa8eb4584fc6cbe6a84d7a7864be3ed28e23e9fd2146aa8ef1814d579df91958"}, - {file = "tokenizers-0.15.0.tar.gz", hash = "sha256:10c7e6e7b4cabd757da59e93f5f8d1126291d16f8b54f28510825ef56a3e5d0e"}, + {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, + {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, + {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, + {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, + {file = 
"tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, + {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, + {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, + {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, + {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, + {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, + {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, + {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, + {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, + {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, ] -[package.dependencies] -huggingface_hub = ">=0.16.4,<1.0" - [package.extras] -dev = ["tokenizers[testing]"] -docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] +dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] [[package]] @@ -4458,71 +4397,72 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] [[package]] name = "transformers" -version = "4.36.2" +version = "4.30.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false -python-versions = ">=3.8.0" +python-versions = ">=3.7.0" files = [ - {file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"}, - {file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"}, + {file = "transformers-4.30.1-py3-none-any.whl", hash = "sha256:9b12bd9d69f21b7c56cd512117fd52856b3def1c9bfc1da97ab0ee4e8bcbd797"}, + {file = "transformers-4.30.1.tar.gz", hash = "sha256:fa74fc271d0692f385d571ce83ec898e3350455f6076d21631f4eed4916e6ffd"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.19.3,<1.0" +huggingface-hub = ">=0.14.1,<1.0" numpy = ">=1.17" packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" safetensors = ">=0.3.1" -tokenizers = ">=0.14,<0.19" +tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14" tqdm = ">=4.27" [package.extras] -accelerate = ["accelerate (>=0.21.0)"] -agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +accelerate = ["accelerate (>=0.20.2)"] +agents = ["Pillow", "accelerate (>=0.20.2)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"] +all = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers 
(>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp 
(>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +deepspeed = ["accelerate (>=0.20.2)", "deepspeed (>=0.8.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.8.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.6.9)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", 
"sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "hf-doc-builder", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] -flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] +fairscale = ["fairscale (>0.3)"] +flax = ["flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +integrations = ["optuna", "ray[tune]", "sigopt"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] modelcreation = ["cookiecutter (==1.7.3)"] natten = ["natten (>=0.14.6)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] -ray = ["ray[tune] (>=2.7.0)"] +quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] +ray = ["ray[tune]"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] sagemaker = ["sagemaker (>=2.31.0)"] -sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] -serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] +sentencepiece = ["protobuf (<=3.20.3)", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", 
"pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] -tokenizers = ["tokenizers (>=0.14,<0.19)"] -torch = ["accelerate (>=0.21.0)", "torch (>=1.10,!=1.12.0)"] +tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"] +torch = ["accelerate (>=0.20.2)", "torch (>=1.9,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +torch-vision = ["Pillow", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.14.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.3)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] -vision = ["Pillow (>=10.0.1,<=15.0)"] +vision = ["Pillow"] [[package]] name = "triton" @@ -5029,4 +4969,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "2c47564e41c9c56b5fb140556847ffdb983508d8620f6ee4bcd7fd0b6904e599" +content-hash = "037542b57b2d4791c9a38bcf12fd846ba270265d38c4e12d4ac0696deee44e00" diff --git a/pyproject.toml b/pyproject.toml index dfaa824..42a9471 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ wandb = "^0.15.4" nvitop = "^1.1.2" ftfy = "^6.1.1" rich = "^13.4.2" -transformers = "^4.36.2" +transformers = "4.30.1" tensorboard = "^2.13.0" accelerate = "^0.25.0" jax = "0.4.8" From 55e3a50db020233b24d1a49b8bf7fa673c1591f5 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 Jan 2024 15:28:45 +0000 Subject: 
[PATCH 54/62] various adjustments --- lm_human_preference_details/summarize/ppo.py | 102 ++++++++------ .../summarize/ppo_left_padding.py | 126 +++++++++--------- lm_human_preference_details/tldr_dataset.py | 121 ++++++++++------- 3 files changed, 198 insertions(+), 151 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index 2ba9f0b..a7bf159 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -166,7 +166,7 @@ class Args: """the mini batch size per GPU""" mini_batch_size: Optional[int] = None """the mini batch size across GPUs""" - local_eval_batch_size: int = 8 + local_eval_batch_size: int = 2 """per rank eval batch size""" # other args @@ -598,16 +598,6 @@ def repeat_generator(): optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) with torch.no_grad(): - queries = data["query_token"].to(device) - query_responses = generate( - accelerator.unwrap_model(model).policy, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - eval_storage, eval_df = evaluate( args, reward_model, @@ -630,40 +620,71 @@ def repeat_generator(): del eval_storage, eval_df torch.cuda.empty_cache() - output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_responses, tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature + 1e-7 - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - postprocessed_responses = truncate_response(args, tokenizer, responses) - torch.cuda.empty_cache() - - # 2. run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 - full_values, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - _, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer) - - # 3. filter response. 
Ensure that the sample contains truncate_token_id + queries = data["query_token"].to(device) + context_length = queries.shape[1] + query_responses = [] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + values = [] + scores = [] + sequence_lengths = [] + for i in range(0, queries.shape[0], args.local_eval_batch_size): + query = queries[i : i + args.local_eval_batch_size] + query_response = generate( + accelerator.unwrap_model(model).policy, + query, + tokenizer, + generation_config, + ) + response = query_response[:, context_length:] + + output = forward(accelerator.unwrap_model(model).policy, query_response, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + + ref_output = forward(ref_policy, query_response, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + + # Response Processing 1. truncate response after the first occurrence of `truncate_token_id` + postprocessed_response = truncate_response(args, tokenizer, response) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + full_value, _, _ = get_reward(accelerator.unwrap_model(model).critic, query_response, tokenizer) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward(reward_model, postprocessed_query_response, tokenizer) + + query_responses.append(query_response) + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + values.append(value) + sequence_lengths.append(sequence_length) + scores.append(score) + query_responses = torch.cat(query_responses, 0) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + values = torch.cat(values, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + + # Response Processing 3. filter response. Ensure that the sample contains truncate_token_id # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") - # torch.cuda.empty_cache() # 4. 
compute rewards kl = logprobs - ref_logprobs @@ -695,6 +716,7 @@ def repeat_generator(): writer.add_histogram("advantages", advantages[0].float(), global_step) accelerator.print("rewards====", rewards[0]) accelerator.print("advantages====", advantages[0]) + torch.cuda.empty_cache() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for ppo_epoch_idx in range(args.ppo.noptepochs): diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py index 12ddc2e..3575190 100644 --- a/lm_human_preference_details/summarize/ppo_left_padding.py +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -166,7 +166,7 @@ class Args: """the mini batch size per GPU""" mini_batch_size: Optional[int] = None """the mini batch size across GPUs""" - local_eval_batch_size: int = 8 + local_eval_batch_size: int = 2 """per rank eval batch size""" # other args @@ -275,12 +275,12 @@ def forward(self, **kwargs): def get_reward(model, query_responses, tokenizer, context_length): attention_mask = query_responses != tokenizer.pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + # position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) reward_logits = model( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, + # position_ids=position_ids, return_dict=True, output_hidden_states=True, ) @@ -293,24 +293,6 @@ def get_reward(model, query_responses, tokenizer, context_length): ) -def get_value(model, query_responses, tokenizer): - attention_mask = query_responses != tokenizer.pad_token_id - input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) - reward_logits = model( - input_ids=input_ids, - attention_mask=attention_mask, - return_dict=True, - output_hidden_states=True, - ) - sequence_lengths = first_true_indices(query_responses == tokenizer.pad_token_id) - 1 - # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return ( - reward_logits, - reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), - sequence_lengths, - ) - - # taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 # we did this we can do a single `model = accelerator.prepare(model)` class PolicyAndValueWrapper(nn.Module): @@ -620,16 +602,6 @@ def repeat_generator(): optimizer.param_groups[0]["lr"] = lrnow data = next(iter_dataloader) with torch.no_grad(): - queries = data["query_token"].to(device) - query_responses = generate( - accelerator.unwrap_model(model).policy, - queries, - tokenizer, - generation_config, - ) - context_length = queries.shape[1] - responses = query_responses[:, context_length:] - eval_storage, eval_df = evaluate( args, reward_model, @@ -652,40 +624,73 @@ def repeat_generator(): del eval_storage, eval_df torch.cuda.empty_cache() - output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.task.temperature + 1e-7 - all_logprobs = F.log_softmax(logits, dim=-1) - logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del output, logits, all_logprobs - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_responses, 
tokenizer) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.task.temperature + 1e-7 - ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) - ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprobs - torch.cuda.empty_cache() - - # **Response Processing** - postprocessed_responses = truncate_response(args, tokenizer, responses) - torch.cuda.empty_cache() - - # 2. run reward model on the truncated responses - postprocessed_query_responses = torch.cat((queries, postprocessed_responses), 1) - sequence_lengths = first_true_indices(postprocessed_responses == tokenizer.pad_token_id) - 1 - full_values, _, _ = get_value(accelerator.unwrap_model(model).critic, query_responses, tokenizer) - values = full_values[:, context_length - 1 : -1].squeeze(-1) - scores_logits, scores, _ = get_reward(reward_model, postprocessed_query_responses, tokenizer, queries.shape[1]) - - # 3. filter response. Ensure that the sample contains truncate_token_id + queries = data["query_token"].to(device) + context_length = queries.shape[1] + query_responses = [] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + values = [] + scores = [] + sequence_lengths = [] + for i in range(0, queries.shape[0], args.local_eval_batch_size): + query = queries[i : i + args.local_eval_batch_size] + query_response = generate( + accelerator.unwrap_model(model).policy, + query, + tokenizer, + generation_config, + ) + response = query_response[:, context_length:] + + output = forward(accelerator.unwrap_model(model).policy, query_response, tokenizer) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.task.temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + + ref_output = forward(ref_policy, query_response, tokenizer) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.task.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + + # Response Processing 1. truncate response after the first occurrence of `truncate_token_id` + postprocessed_response = truncate_response(args, tokenizer, response) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + full_value, _, _ = get_reward( + accelerator.unwrap_model(model).critic, query_response, tokenizer, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward(reward_model, postprocessed_query_response, tokenizer, context_length) + + query_responses.append(query_response) + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + values.append(value) + sequence_lengths.append(sequence_length) + scores.append(score) + query_responses = torch.cat(query_responses, 0) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + values = torch.cat(values, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + + # Response Processing 3. 
filter response. Ensure that the sample contains truncate_token_id # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) accelerator.print(f"{scores=}, {(contain_pad_token.sum() / len(contain_pad_token))=}") - # torch.cuda.empty_cache() # 4. compute rewards kl = logprobs - ref_logprobs @@ -717,6 +722,7 @@ def repeat_generator(): writer.add_histogram("advantages", advantages[0].float(), global_step) accelerator.print("rewards====", rewards[0]) accelerator.print("advantages====", advantages[0]) + torch.cuda.empty_cache() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for ppo_epoch_idx in range(args.ppo.noptepochs): diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index 30f47ea..08ba37e 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -1,6 +1,7 @@ import multiprocessing import os -from dataclasses import dataclass +import time +from dataclasses import dataclass, field from pprint import pformat from typing import Dict, Optional @@ -30,17 +31,25 @@ --max_sft_query_response_length=560 \ --max-rm-response-length=48 \ --max_rm_query_response_length=560 -""" - -@dataclass -class Args: - base_model: str = "gpt2" # EleutherAI/pythia-160m - max_sft_response_length: int = 48 # 53 - max_sft_query_response_length: int = 512 + 48 # 565 - max_rm_response_length: int = 153 # 169 - max_rm_query_response_length: int = 512 + 153 # 665 - hf_entity: str = None +poetry run python lm_human_preference_details/tldr_dataset.py \ + --base_model=EleutherAI/pythia-160m \ + --max_sft_response_length=53 \ + --max_sft_query_response_length=562 \ + --max-rm-response-length=169 \ + --max_rm_query_response_length=638 \ + --hf_entity=cleanrl \ + --push_to_hub \ + --oai_params.padding="" +poetry run python lm_human_preference_details/tldr_dataset.py \ + --base_model=EleutherAI/pythia-160m \ + --max_sft_response_length=48 \ + --max_sft_query_response_length=560 \ + --max-rm-response-length=48 \ + --max_rm_query_response_length=560 \ + --push_to_hub \ + --oai_params.padding="" +""" @dataclass @@ -55,6 +64,18 @@ class TaskQueryHParams: pad_side: Optional[str] = "left" +@dataclass +class Args: + base_model: str = "gpt2" # EleutherAI/pythia-160m + max_sft_response_length: int = 48 # 53 + max_sft_query_response_length: int = 512 + 48 # 565 + max_rm_response_length: int = 153 # 169 + max_rm_query_response_length: int = 512 + 153 # 665 + hf_entity: str = None + push_to_hub: bool = False + oai_params: TaskQueryHParams = field(default_factory=TaskQueryHParams) + + def _ensure_length(toks, l, pad_sequence=None, pad_side=None, truncate_side=None): assert pad_side in (None, "left", "right") assert truncate_side in (None, "left", "right") @@ -113,7 +134,7 @@ def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHPar query_tokens = encoder.encode(format_str.format(**query_info)) query_token = _ensure_length(query_tokens, hparams.length, pad_side=hparams.pad_side, pad_sequence=pad_sequence) - query = encoder.decode(query_token).lstrip() + query = encoder.decode(query_token, skip_special_tokens=True).lstrip() return dict( query_token=query_token, query=query, @@ -127,12 +148,12 @@ def process_query(query_info: 
Dict[str, str], *, encoder, hparams: TaskQueryHPar assert isinstance(args.hf_entity, str) tokenizer = AutoTokenizer.from_pretrained(args.base_model) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - oai_h = TaskQueryHParams() - if isinstance(oai_h.padding, str): - oai_h.padding = tokenizer.encode(oai_h.padding) + if len(args.oai_params.padding) > 0: + args.oai_params.padding = tokenizer.encode(args.oai_params.padding) else: - oai_h.padding = tokenizer.pad_token_id - pprint(oai_h) + args.oai_params.padding = [tokenizer.pad_token_id] + pprint(args.oai_params) + timestamp = int(time.time()) sft_ds = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered") def process_query_data(x): @@ -141,7 +162,7 @@ def process_query_data(x): # `<|endoftext|>` token reference_response = f" {x['summary']}<|endoftext|>" y = { - **process_query(x, encoder=tokenizer, hparams=oai_h), + **process_query(x, encoder=tokenizer, hparams=args.oai_params), "reference_response": reference_response, "reference_response_token": tokenizer.encode( reference_response, @@ -162,14 +183,13 @@ def process_query_data(x): return y sft_ds = sft_ds.map(process_query_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) - sft_ds.push_to_hub( - f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}" - ) - sft_card = RepoCard.load( - f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}", - repo_type="dataset", - ) - sft_card.text = f"""\ + if args.push_to_hub: + sft_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{timestamp}") + sft_card = RepoCard.load( + f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{timestamp}", + repo_type="dataset", + ) + sft_card.text = f"""\ # TL;DR SFT Dataset for OpenAI's [Summarize from Feedback](https://openai.com/blog/summarization/) task The dataset is directly taken from https://github.com/openai/summarize-from-feedback/tree/700967448d10004279f138666442bf1497d0e705#reddit-tldr-dataset @@ -197,13 +217,13 @@ def process_query_data(x): ```python {pformat(vars(args))} -{pformat(vars(oai_h))} +{pformat(vars(args.oai_params))} ``` """ - sft_card.push_to_hub( - f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}", - repo_type="dataset", - ) + sft_card.push_to_hub( + f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{timestamp}", + repo_type="dataset", + ) label_ds = load_dataset("openai/summarize_from_feedback", "comparisons") @@ -216,7 +236,7 @@ def process_response_data(x): response1_policy = x["summaries"][1]["policy"] policies = "--".join(sorted([response0_policy, response1_policy])) y = { - **process_query(x["info"], encoder=tokenizer, hparams=oai_h), + **process_query(x["info"], encoder=tokenizer, hparams=args.oai_params), "response0": response0, "response0_token": tokenizer.encode( response0, padding="max_length", max_length=args.max_rm_response_length, truncation=True @@ -244,9 +264,8 @@ def process_response_data(x): return y label_ds = label_ds.map(process_response_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) - label_ds.push_to_hub( - f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}" - ) + if args.push_to_hub: + 
label_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{timestamp}") os.makedirs("dataset_visuals", exist_ok=True) # visualize token length distribution @@ -321,19 +340,19 @@ def process_response_data(x): fig.tight_layout() fig.savefig("dataset_visuals/policy_comparisons.png") - # upload the `dataset_visuals` - - api.upload_folder( - folder_path="dataset_visuals", - path_in_repo="dataset_visuals", - repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}", - repo_type="dataset", - ) - # upload current file - print(f"{__file__=}") - api.upload_file( - path_or_fileobj=__file__, - path_in_repo="create_dataset.py", - repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}", - repo_type="dataset", - ) + if args.push_to_hub: + # upload the `dataset_visuals` + api.upload_folder( + folder_path="dataset_visuals", + path_in_repo="dataset_visuals", + repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{timestamp}", + repo_type="dataset", + ) + # upload current file + print(f"{__file__=}") + api.upload_file( + path_or_fileobj=__file__, + path_in_repo="create_dataset.py", + repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{timestamp}", + repo_type="dataset", + ) From e8105dd34dfdfa1ae05e29a83dc98b0617bdab50 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 Jan 2024 21:23:32 +0000 Subject: [PATCH 55/62] work with 6.9B --- lm_human_preference_details/summarize/ppo.py | 16 +++++++++++++++- .../summarize/ppo_left_padding.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index a7bf159..0f0e0af 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -543,7 +543,7 @@ def repeat_generator(): "prescale_gradients": False, "wall_clock_breakdown": False, } - if args.offload: + if args.offload or args.base_model == "EleutherAI/pythia-6.9b-deduped": deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} eval_ds_config["zero_optimization"] = { "stage": 3, @@ -678,6 +678,20 @@ def repeat_generator(): values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) + del ( + output, + logits, + all_logprob, + logprob, + ref_output, + ref_logits, + ref_all_logprob, + ref_logprob, + full_value, + value, + score, + ) + torch.cuda.empty_cache() # Response Processing 3. filter response. 
Ensure that the sample contains truncate_token_id # responses not passing that filter will receive a low (fixed) score diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py index 3575190..a129719 100644 --- a/lm_human_preference_details/summarize/ppo_left_padding.py +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -547,7 +547,7 @@ def repeat_generator(): "prescale_gradients": False, "wall_clock_breakdown": False, } - if args.offload: + if args.offload or args.base_model == "EleutherAI/pythia-6.9b-deduped": deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} eval_ds_config["zero_optimization"] = { "stage": 3, @@ -684,6 +684,20 @@ def repeat_generator(): values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) + del ( + output, + logits, + all_logprob, + logprob, + ref_output, + ref_logits, + ref_all_logprob, + ref_logprob, + full_value, + value, + score, + ) + torch.cuda.empty_cache() # Response Processing 3. filter response. Ensure that the sample contains truncate_token_id # responses not passing that filter will receive a low (fixed) score From 3f6d0454f24d033e10d6ba34f1aa16c90f575823 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 Jan 2024 21:53:13 +0000 Subject: [PATCH 56/62] prettier --- lm_human_preference_details/summarize/ppo.py | 15 ++------------- .../summarize/ppo_left_padding.py | 15 ++------------- 2 files changed, 4 insertions(+), 26 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index 0f0e0af..925e2de 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -678,19 +678,8 @@ def repeat_generator(): values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - del ( - output, - logits, - all_logprob, - logprob, - ref_output, - ref_logits, - ref_all_logprob, - ref_logprob, - full_value, - value, - score, - ) + del (output, logits, all_logprob, logprob, ref_output) + del (ref_logits, ref_all_logprob, ref_logprob, full_value, value, score) torch.cuda.empty_cache() # Response Processing 3. filter response. Ensure that the sample contains truncate_token_id diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py index a129719..c621821 100644 --- a/lm_human_preference_details/summarize/ppo_left_padding.py +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -684,19 +684,8 @@ def repeat_generator(): values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - del ( - output, - logits, - all_logprob, - logprob, - ref_output, - ref_logits, - ref_all_logprob, - ref_logprob, - full_value, - value, - score, - ) + del (output, logits, all_logprob, logprob, ref_output) + del (ref_logits, ref_all_logprob, ref_logprob, full_value, value, score) torch.cuda.empty_cache() # Response Processing 3. filter response. 
Ensure that the sample contains truncate_token_id From d12e44312542f8207d850f683cd42c01f8ebe41f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 6 Jan 2024 21:37:55 +0000 Subject: [PATCH 57/62] push changes --- .../summarize/ppo_left_padding.py | 41 +++++---- .../summarize/reward.py | 90 ++++++++++--------- lm_human_preference_details/summarize/sft.py | 3 +- 3 files changed, 74 insertions(+), 60 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py index c621821..9067f2e 100644 --- a/lm_human_preference_details/summarize/ppo_left_padding.py +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -168,6 +168,8 @@ class Args: """the mini batch size across GPUs""" local_eval_batch_size: int = 2 """per rank eval batch size""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" # other args base_model: str = "EleutherAI/pythia-160m" @@ -466,7 +468,8 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation name=run_name, save_code=True, ) - wandb.run.log_code(".") + file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", @@ -634,8 +637,8 @@ def repeat_generator(): values = [] scores = [] sequence_lengths = [] - for i in range(0, queries.shape[0], args.local_eval_batch_size): - query = queries[i : i + args.local_eval_batch_size] + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] query_response = generate( accelerator.unwrap_model(model).policy, query, @@ -649,12 +652,16 @@ def repeat_generator(): logits /= args.task.temperature + 1e-7 all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprob + torch.cuda.empty_cache() ref_output = forward(ref_policy, query_response, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.task.temperature + 1e-7 ref_all_logprob = F.log_softmax(ref_logits, dim=-1) ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() # Response Processing 1. truncate response after the first occurrence of `truncate_token_id` postprocessed_response = truncate_response(args, tokenizer, response) @@ -684,8 +691,7 @@ def repeat_generator(): values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - del (output, logits, all_logprob, logprob, ref_output) - del (ref_logits, ref_all_logprob, ref_logprob, full_value, value, score) + del (logprob, ref_logprob, full_value, value, score) torch.cuda.empty_cache() # Response Processing 3. filter response. 
Ensure that the sample contains truncate_token_id @@ -766,14 +772,22 @@ def repeat_generator(): pg_losses = -mb_advantage * ratio pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() loss = pg_loss + args.ppo.vf_coef * vf_loss accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) # if ppo_epoch_idx == 0: @@ -794,14 +808,6 @@ def repeat_generator(): # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), # }) # breakpoint() - with torch.no_grad(): - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 if accelerator.is_main_process: @@ -861,6 +867,7 @@ def repeat_generator(): if args.reward.use_adaptive_kl: kl_ctl.update(mean_kl.item(), args.batch_size) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + torch.cuda.empty_cache() if args.run_eval: eval_storage, eval_df = evaluate( diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index d99717c..a41bea4 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -140,7 +140,7 @@ class Args: """Which layers to apply dropout to""" output_dir: str = "models/reward_model" """Where to save the model""" - label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169" + label_dataset: str = "cleanrl/summarize_from_feedback_oai_preprocessing_1704563162" """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" logsigmoid: bool = True """Whether to use log-sigmoid loss instead of cross-entropy 
loss""" @@ -271,7 +271,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): args.batch_size = int(args.local_batch_size * args.world_size) # load dataset - dataset = load_dataset(args.label_dataset, "comparisons", split="train") + dataset = load_dataset(args.label_dataset, split="train") dataset = dataset.shuffle(seed=local_seed) dataset = dataset.select(range(args.label.num_train)) dataset = dataset.with_format( @@ -288,27 +288,31 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): ], ) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) - validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() - validation_dataset = validation_dataset.with_format( - "torch", - columns=[ - "query_token", - "choice", - "response0_token", - "query_response0_token", - "response1_token", - "query_response1_token", - "batch", - "split", - "extra.confidence", - "response0_policy", - "response1_policy", - "policies", - ], - ) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + eval_datasets = [] + eval_dataloaders = {} + for split in ["validation", "validation_cnndm"]: + validation_dataset = load_dataset(args.label_dataset, split=split).flatten() + validation_dataset = validation_dataset.with_format( + "torch", + columns=[ + "query_token", + "choice", + "response0_token", + "query_response0_token", + "response1_token", + "query_response1_token", + "batch", + "split", + "extra.confidence", + "response0_policy", + "response1_policy", + "policies", + ], + ) + eval_datasets.append(validation_dataset) + eval_dataloaders[split] = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) accelerator.print("The number of samples in dataset", len(dataset)) - accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) args.total_episodes = len(dataset) args.num_updates = args.total_episodes // args.batch_size @@ -328,7 +332,8 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): name=run_name, save_code=True, ) - wandb.run.log_code(".") + file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", @@ -379,7 +384,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - validation_dataloader = accelerator.prepare(validation_dataloader) + eval_dataloaders = {split: accelerator.prepare(eval_dataloader) for split, eval_dataloader in eval_dataloaders.items()} accelerator.print("===training model===") losses = torch.zeros((args.gradient_accumulation_steps,), device=device) @@ -436,24 +441,25 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): ) if args.run_eval: - evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) - for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/split/{split}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/split/{split}: {row['accuracy']}") - for batch, row in evaluate_df[["batch", 
"accuracy"]].groupby(["batch"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/batch/{batch}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/batch/{batch}: {row['accuracy']}") - for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/confidence/{confi}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/confidence/{confi}: {row['accuracy']}") - writer.add_scalar("eval/rm/accuracy", evaluate_df["accuracy"].mean(), global_step) - accelerator.print(f"eval/rm/accuracy: {evaluate_df['accuracy'].mean()}") - if accelerator.is_main_process: - os.makedirs(f"eval_tables/{run_name}", exist_ok=True) - evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") - if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) - torch.cuda.empty_cache() + for eval_split in eval_dataloaders: + evaluate_df = evaluate(args, accelerator, tokenizer, model, eval_dataloaders[eval_split]) + for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/split/{split}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/split/{split}: {row['accuracy']}") + for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/batch/{batch}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/batch/{batch}: {row['accuracy']}") + for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/confidence/{confi}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/confidence/{confi}: {row['accuracy']}") + writer.add_scalar(f"eval/rm/{eval_split}/accuracy", evaluate_df["accuracy"].mean(), global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy: {evaluate_df['accuracy'].mean()}") + if accelerator.is_main_process: + os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{eval_split}_{update}.csv") + if args.track: + wandb.log({f"samples/{eval_split}/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + torch.cuda.empty_cache() norm_dataset = load_dataset(args.task.query_dataset, split="train") norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index 3c20704..a37f387 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -318,7 +318,8 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c name=run_name, save_code=True, ) - wandb.run.log_code(".") + file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", From 6f6490f074b94854faeafd59b702d22dfa5e3277 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 11 Jan 2024 21:39:17 +0000 Subject: [PATCH 58/62] handles cnndm correctly --- lm_human_preference_details/tldr_dataset.py | 266 ++++++++++++++------ 1 file changed, 185 insertions(+), 81 deletions(-) diff --git 
a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index 08ba37e..5cea73c 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -3,7 +3,7 @@ import time from dataclasses import dataclass, field from pprint import pformat -from typing import Dict, Optional +from typing import Dict, Literal, Optional import matplotlib.pyplot as plt import pandas as pd @@ -20,60 +20,74 @@ """ poetry run python lm_human_preference_details/tldr_dataset.py poetry run python lm_human_preference_details/tldr_dataset.py \ - --base_model=EleutherAI/pythia-160m \ + --base_model=EleutherAI/pythia-1b-deduped \ --max_sft_response_length=53 \ --max_sft_query_response_length=562 \ - --max-rm-response-length=169 \ + --max_rm_response_length=169 \ --max_rm_query_response_length=638 -poetry run python lm_human_preference_details/tldr_dataset.py \ - --base_model=EleutherAI/pythia-160m \ - --max_sft_response_length=48 \ - --max_sft_query_response_length=560 \ - --max-rm-response-length=48 \ - --max_rm_query_response_length=560 -poetry run python lm_human_preference_details/tldr_dataset.py \ - --base_model=EleutherAI/pythia-160m \ - --max_sft_response_length=53 \ - --max_sft_query_response_length=562 \ - --max-rm-response-length=169 \ - --max_rm_query_response_length=638 \ +poetry run python -i lm_human_preference_details/tldr_dataset.py \ + --base_model=EleutherAI/pythia-1b-deduped \ + --tldr_params.max_sft_response_length=53 \ + --tldr_params.max_sft_query_response_length=562 \ + --tldr_params.max_rm_response_length=169 \ + --tldr_params.max_rm_query_response_length=638 \ + --cnndm_params.max_rm_response_length=155 \ + --cnndm_params.max_rm_query_response_length=2021 \ --hf_entity=cleanrl \ --push_to_hub \ - --oai_params.padding="" -poetry run python lm_human_preference_details/tldr_dataset.py \ - --base_model=EleutherAI/pythia-160m \ - --max_sft_response_length=48 \ - --max_sft_query_response_length=560 \ - --max-rm-response-length=48 \ - --max_rm_query_response_length=560 \ - --push_to_hub \ - --oai_params.padding="" + --tldr_params.padding="pad_token" \ + --cnndm_params.padding="pad_token" \ """ @dataclass class TaskQueryHParams: - length: int = 512 - format_str: Optional[ - str - ] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily - truncate_field: Optional[str] = "post" - truncate_text: Optional[str] = "\n" - padding: Optional[str] = " " # empty spaces - pad_side: Optional[str] = "left" + length: Optional[int] = None + format_str: Optional[str] = None + truncate_field: Optional[str] = None + truncate_text: Optional[str] = None + padding: Optional[Literal["empty_space", "pad_token"]] = None + pad_token: Optional[str] = None + pad_side: Optional[str] = None + max_sft_response_length: Optional[int] = None + max_sft_query_response_length: Optional[int] = None + max_rm_response_length: Optional[int] = None + max_rm_query_response_length: Optional[int] = None @dataclass class Args: - base_model: str = "gpt2" # EleutherAI/pythia-160m - max_sft_response_length: int = 48 # 53 - max_sft_query_response_length: int = 512 + 48 # 565 - max_rm_response_length: int = 153 # 169 - max_rm_query_response_length: int = 512 + 153 # 665 + base_model: str = "EleutherAI/pythia-1b-deduped" # "gpt2" hf_entity: str = None push_to_hub: bool = False - oai_params: TaskQueryHParams = field(default_factory=TaskQueryHParams) + check_length_correctness: bool = True + 
tldr_params: TaskQueryHParams = field( + default_factory=lambda: TaskQueryHParams( + length=512, + format_str="SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:", + truncate_field="post", + truncate_text="\n", + padding="empty_space", + pad_side="left", + max_sft_response_length=53, # 48 + max_sft_query_response_length=562, # 512 + 48 + max_rm_response_length=169, # 153 + max_rm_query_response_length=638, # 512 + 153 + ) + ) + cnndm_params: TaskQueryHParams = field( + default_factory=lambda: TaskQueryHParams( + length=2047 - 128, + format_str="Article:\n{article}\n\nTL;DR:\n", + truncate_field="article", + truncate_text="\n", + padding="empty_space", + pad_side="left", + max_rm_response_length=155, # 153 + max_rm_query_response_length=2021, # 512 + 153 + ) + ) def _ensure_length(toks, l, pad_sequence=None, pad_side=None, truncate_side=None): @@ -102,7 +116,7 @@ def _ensure_length(toks, l, pad_sequence=None, pad_side=None, truncate_side=None def _get_query_padding_for_task(encoder, hparams: TaskQueryHParams): - return hparams.padding * hparams.length + return hparams.pad_token * hparams.length def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHParams, pad_sequence=None): @@ -141,6 +155,10 @@ def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHPar ) +def ceil_div(a, b): + return (a - 1) // b + 1 + + if __name__ == "__main__": args = tyro.cli(Args) if args.hf_entity is None: @@ -148,11 +166,17 @@ def process_query(query_info: Dict[str, str], *, encoder, hparams: TaskQueryHPar assert isinstance(args.hf_entity, str) tokenizer = AutoTokenizer.from_pretrained(args.base_model) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - if len(args.oai_params.padding) > 0: - args.oai_params.padding = tokenizer.encode(args.oai_params.padding) + + # post init + if args.tldr_params.padding == "empty_space": + args.tldr_params.pad_token = tokenizer.encode(" ") else: - args.oai_params.padding = [tokenizer.pad_token_id] - pprint(args.oai_params) + args.tldr_params.pad_token = [tokenizer.pad_token_id] + if args.cnndm_params.padding == "empty_space": + args.cnndm_params.pad_token = tokenizer.encode(" ") + else: + args.cnndm_params.pad_token = [tokenizer.pad_token_id] + pprint(args) timestamp = int(time.time()) sft_ds = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered") @@ -162,23 +186,27 @@ def process_query_data(x): # `<|endoftext|>` token reference_response = f" {x['summary']}<|endoftext|>" y = { - **process_query(x, encoder=tokenizer, hparams=args.oai_params), + **process_query(x, encoder=tokenizer, hparams=args.tldr_params), "reference_response": reference_response, "reference_response_token": tokenizer.encode( reference_response, padding="max_length", - max_length=args.max_sft_response_length, + max_length=args.tldr_params.max_sft_response_length, truncation=True, ), "reference_response_token_len": len(tokenizer.encode(reference_response)), } y["query_reference_response"] = y["query"].strip() + y["reference_response"] - y["query_reference_response_token"] = tokenizer.encode( - y["query_reference_response"], - padding="max_length", - max_length=args.max_sft_query_response_length, - truncation=True, - ) + # if padding is space, then we can just concatenate the tokens + if args.tldr_params.padding == "empty_space": + y["query_reference_response_token"] = y["query_token"] + y["reference_response_token"] + else: + y["query_reference_response_token"] = tokenizer.encode( + y["query_reference_response"], + padding="max_length", + 
max_length=args.max_sft_query_response_length, + truncation=True, + ) y["query_reference_response_token_len"] = len(tokenizer.encode(y["query_reference_response"])) return y @@ -217,7 +245,6 @@ def process_query_data(x): ```python {pformat(vars(args))} -{pformat(vars(args.oai_params))} ``` """ sft_card.push_to_hub( @@ -225,7 +252,10 @@ def process_query_data(x): repo_type="dataset", ) + cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"] label_ds = load_dataset("openai/summarize_from_feedback", "comparisons") + label_ds["validation_cnndm"] = label_ds["validation"].filter(lambda x: x["batch"] in cnndm_batches) + label_ds["validation"] = label_ds["validation"].filter(lambda x: x["batch"] not in cnndm_batches) def process_response_data(x): # the `x['summaries'][0]['text']` in `openai/summarize_from_feedback` `comaprisons` @@ -235,16 +265,27 @@ def process_response_data(x): response0_policy = x["summaries"][0]["policy"] response1_policy = x["summaries"][1]["policy"] policies = "--".join(sorted([response0_policy, response1_policy])) + format_params = args.cnndm_params if x["batch"] in cnndm_batches else args.tldr_params + max_rm_response_length = ( + args.cnndm_params.max_rm_response_length + if x["batch"] in cnndm_batches + else args.tldr_params.max_rm_response_length + ) + max_rm_query_response_length = ( + args.cnndm_params.max_rm_query_response_length + if x["batch"] in cnndm_batches + else args.tldr_params.max_rm_query_response_length + ) y = { - **process_query(x["info"], encoder=tokenizer, hparams=args.oai_params), + **process_query(x["info"], encoder=tokenizer, hparams=format_params), "response0": response0, "response0_token": tokenizer.encode( - response0, padding="max_length", max_length=args.max_rm_response_length, truncation=True + response0, padding="max_length", max_length=max_rm_response_length, truncation=True ), "response0_token_len": len(tokenizer.encode(response0)), "response1": response1, "response1_token": tokenizer.encode( - response1, padding="max_length", max_length=args.max_rm_response_length, truncation=True + response1, padding="max_length", max_length=max_rm_response_length, truncation=True ), "response1_token_len": len(tokenizer.encode(response1)), "response0_policy": response0_policy, @@ -252,26 +293,50 @@ def process_response_data(x): "policies": policies, } y["query_response0"] = y["query"].strip() + y["response0"] - y["query_response0_token"] = tokenizer.encode( - y["query_response0"], padding="max_length", max_length=args.max_rm_query_response_length, truncation=True - ) + # if padding is space, then we can just concatenate the tokens + if args.tldr_params.padding == "empty_space": + y["query_response0_token"] = y["query_token"] + y["response0_token"] + else: + y["query_response0_token"] = tokenizer.encode( + y["query_response0"], padding="max_length", max_length=max_rm_query_response_length, truncation=True + ) y["query_response0_token_len"] = len(tokenizer.encode(y["query_response0"])) y["query_response1"] = y["query"].strip() + y["response1"] - y["query_response1_token"] = tokenizer.encode( - y["query_response1"], padding="max_length", max_length=args.max_rm_query_response_length, truncation=True - ) + if args.tldr_params.padding == "empty_space": + y["query_response1_token"] = y["query_token"] + y["response1_token"] + else: + y["query_response1_token"] = tokenizer.encode( + y["query_response1"], padding="max_length", max_length=max_rm_query_response_length, truncation=True + ) y["query_response1_token_len"] = len(tokenizer.encode(y["query_response1"])) + 
y["query_token_len"] = len(tokenizer.encode(y["query"])) return y label_ds = label_ds.map(process_response_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) if args.push_to_hub: label_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{timestamp}") - os.makedirs("dataset_visuals", exist_ok=True) + #################################### # visualize token length distribution - num_subplots = len(sft_ds) * 2 + len(label_ds) * 4 + #################################### + calculated_tldr_params = TaskQueryHParams( + max_sft_query_response_length=0, + max_sft_response_length=0, + max_rm_response_length=0, + max_rm_query_response_length=0, + ) + calculated_cnndm_params = TaskQueryHParams( + max_rm_query_response_length=0, + max_rm_response_length=0, + ) + + os.makedirs("dataset_visuals", exist_ok=True) + num_sft_visuals = 2 + num_label_visuals = 5 + num_subplots = len(sft_ds) * num_sft_visuals + len(label_ds) * num_label_visuals + num_cols = 3 print(f"{num_subplots=}") - fig, axs = plt.subplots(5, 3, figsize=(16, 16)) + fig, axs = plt.subplots(ceil_div(num_subplots, num_cols), num_cols, figsize=(16, 16)) axs = axs.flatten() j = 0 for _, key in enumerate(sft_ds.keys()): @@ -282,35 +347,74 @@ def process_response_data(x): axs[j + 1].set_title( f"{key} split: query.strip() + reference response token length\nmax_length={max(df['query_reference_response_token_len'])}" ) - j += 2 + calculated_tldr_params.max_sft_response_length = max( + calculated_tldr_params.max_sft_response_length, max(df["reference_response_token_len"]) + ) + calculated_tldr_params.max_sft_query_response_length = max( + calculated_tldr_params.max_sft_query_response_length, max(df["query_reference_response_token_len"]) + ) + j += num_sft_visuals offset = len(sft_ds) - for _, key in enumerate(label_ds.keys()): - df = label_ds[key].to_pandas() + for _, split in enumerate(label_ds.keys()): + df = label_ds[split].to_pandas() axs[j].hist(df["response0_token_len"], bins=100) - axs[j].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") + axs[j].set_title(f"{split} split: response0 token length\nmax_length={max(df['response0_token_len'])}") axs[j + 1].hist(df["response1_token_len"], bins=100) - axs[j + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") + axs[j + 1].set_title(f"{split} split: response1 token length\nmax_length={max(df['response1_token_len'])}") axs[j + 2].hist(df["query_response0_token_len"], bins=100) axs[j + 2].set_title( - f"{key} split: query.strip() + response0 token length\nmax_length={max(df['query_response0_token_len'])}" + f"{split} split: query.strip() + response0 token length\nmax_length={max(df['query_response0_token_len'])}" ) axs[j + 3].hist(df["query_response1_token_len"], bins=100) axs[j + 3].set_title( - f"{key} split: query.strip() + response1 token length\nmax_length={max(df['query_response1_token_len'])}" + f"{split} split: query.strip() + response1 token length\nmax_length={max(df['query_response1_token_len'])}" ) - j += 4 + axs[j + 4].hist(df["query_token_len"], bins=100) + axs[j + 4].set_title(f"{split} split: query token length\nmax_length={max(df['query_token_len'])}") + if split in ["train", "validation"]: + calculated_tldr_params.max_rm_response_length = max( + calculated_tldr_params.max_rm_response_length, max(df["response0_token_len"]), max(df["response1_token_len"]) + ) + calculated_tldr_params.max_rm_query_response_length = max( + 
calculated_tldr_params.max_rm_query_response_length, + max(df["query_response0_token_len"]), + max(df["query_response1_token_len"]), + ) + elif split == "validation_cnndm": + calculated_cnndm_params.max_rm_response_length = max( + calculated_cnndm_params.max_rm_response_length, max(df["response0_token_len"]), max(df["response1_token_len"]) + ) + calculated_cnndm_params.max_rm_query_response_length = max( + calculated_cnndm_params.max_rm_query_response_length, + max(df["query_response0_token_len"]), + max(df["query_response1_token_len"]), + ) + else: + raise ValueError(f"Unknown dataset split: {split}") + j += num_label_visuals fig.suptitle(f"{args.base_model} Tokenizer: Token length distribution") fig.tight_layout() fig.savefig("dataset_visuals/token_len.png") + pprint({"calculated_tldr_params": calculated_tldr_params}) + pprint({"calculated_cnndm_params": calculated_cnndm_params}) + if args.check_length_correctness: + assert calculated_tldr_params.max_sft_response_length == args.tldr_params.max_sft_response_length + assert calculated_tldr_params.max_sft_query_response_length == args.tldr_params.max_sft_query_response_length + assert calculated_tldr_params.max_rm_response_length == args.tldr_params.max_rm_response_length + assert calculated_tldr_params.max_rm_query_response_length == args.tldr_params.max_rm_query_response_length + assert calculated_cnndm_params.max_rm_response_length == args.cnndm_params.max_rm_response_length + assert calculated_cnndm_params.max_rm_query_response_length == args.cnndm_params.max_rm_query_response_length + print("✨ calculated lenghts are ok!") + # visualize confidence distribution fig, axs = plt.subplots(len(label_ds), 1, figsize=(8, 8)) axs = axs.flatten() label_ds = label_ds.flatten() - for i, key in enumerate(label_ds.keys()): - df = label_ds[key].to_pandas() + for i, split in enumerate(label_ds.keys()): + df = label_ds[split].to_pandas() axs[i].hist(df["extra.confidence"]) - axs[i].set_title(f"{key} split: confidence distribution") + axs[i].set_title(f"{split} split: confidence distribution") fig.suptitle("Confidence distribution") fig.tight_layout() fig.savefig("dataset_visuals/confidence.png") @@ -319,11 +423,11 @@ def process_response_data(x): fig, axs = plt.subplots(1, len(label_ds), figsize=(8, 12)) axs = axs.flatten() label_ds = label_ds.flatten() - for i, key in enumerate(label_ds.keys()): - df = label_ds[key].to_pandas() + for i, split in enumerate(label_ds.keys()): + df = label_ds[split].to_pandas() cat = pd.concat([df["response0_policy"], df["response1_policy"]], axis=0) cat.hist(ax=axs[i], xrot=90, orientation="horizontal") - axs[i].set_title(f"{key} split: policy distribution") + axs[i].set_title(f"{split} split: policy distribution") fig.suptitle("Policy distribution") fig.tight_layout() fig.savefig("dataset_visuals/policies.png") @@ -332,10 +436,10 @@ def process_response_data(x): fig, axs = plt.subplots(1, len(label_ds), figsize=(24, 30)) axs = axs.flatten() label_ds = label_ds.flatten() - for i, key in enumerate(label_ds.keys()): - df = label_ds[key].to_pandas() + for i, split in enumerate(label_ds.keys()): + df = label_ds[split].to_pandas() df["policies"].hist(ax=axs[i], xrot=90, orientation="horizontal") - axs[i].set_title(f"{key} split: policy comparison distribution") + axs[i].set_title(f"{split} split: policy comparison distribution") fig.suptitle("Policy comparison distribution") fig.tight_layout() fig.savefig("dataset_visuals/policy_comparisons.png") From 2166b4f88b25950600c5fd5dd15d90b8d88f7a6e Mon Sep 17 00:00:00 2001 From: 
Costa Huang Date: Thu, 11 Jan 2024 23:05:23 +0000 Subject: [PATCH 59/62] quick change --- lm_human_preference_details/summarize/ppo.py | 52 ++++++++++++------- .../summarize/ppo_left_padding.py | 11 ++-- .../summarize/reward.py | 3 +- 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index 925e2de..f891bbb 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -168,6 +168,8 @@ class Args: """the mini batch size across GPUs""" local_eval_batch_size: int = 2 """per rank eval batch size""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" # other args base_model: str = "EleutherAI/pythia-160m" @@ -462,7 +464,8 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation name=run_name, save_code=True, ) - wandb.run.log_code(".") + file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", @@ -582,6 +585,7 @@ def repeat_generator(): accelerator.print("===training policy===") global_step = 0 + start_time = time.time() stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) @@ -630,8 +634,8 @@ def repeat_generator(): values = [] scores = [] sequence_lengths = [] - for i in range(0, queries.shape[0], args.local_eval_batch_size): - query = queries[i : i + args.local_eval_batch_size] + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] query_response = generate( accelerator.unwrap_model(model).policy, query, @@ -645,12 +649,16 @@ def repeat_generator(): logits /= args.task.temperature + 1e-7 all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprob + torch.cuda.empty_cache() ref_output = forward(ref_policy, query_response, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.task.temperature + 1e-7 ref_all_logprob = F.log_softmax(ref_logits, dim=-1) ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() # Response Processing 1. truncate response after the first occurrence of `truncate_token_id` postprocessed_response = truncate_response(args, tokenizer, response) @@ -678,8 +686,7 @@ def repeat_generator(): values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - del (output, logits, all_logprob, logprob, ref_output) - del (ref_logits, ref_all_logprob, ref_logprob, full_value, value, score) + del (logprob, ref_logprob, full_value, value, score) torch.cuda.empty_cache() # Response Processing 3. filter response. 
Ensure that the sample contains truncate_token_id @@ -760,14 +767,22 @@ def repeat_generator(): pg_losses = -mb_advantage * ratio pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() loss = pg_loss + args.ppo.vf_coef * vf_loss accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) # if ppo_epoch_idx == 0: @@ -788,14 +803,6 @@ def repeat_generator(): # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), # }) # breakpoint() - with torch.no_grad(): - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 if accelerator.is_main_process: @@ -852,9 +859,13 @@ def repeat_generator(): writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) writer.add_scalar("ppo/lr", lrnow, update) writer.add_scalar("ppo/episode", global_step, update) + eps = int(global_step / (time.time() - start_time)) + writer.add_scalar("ppo/eps", eps, update) + accelerator.print("ppo/eps", eps, update) if args.reward.use_adaptive_kl: kl_ctl.update(mean_kl.item(), args.batch_size) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + torch.cuda.empty_cache() if args.run_eval: eval_storage, eval_df = evaluate( @@ -866,9 +877,10 @@ def repeat_generator(): validation_generation_config, sampling=False, ) - eval_df.to_csv(f"runs/{run_name}/table.csv") - if accelerator.is_main_process and args.track: - wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + if accelerator.is_main_process: + eval_df.to_csv(f"runs/{run_name}/table.csv") + if args.track: + wandb.log({"eval/query_responses": 
wandb.Table(dataframe=eval_df)}, step=update) # save model if args.output_dir and args.num_train_epochs > 0: diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py index 9067f2e..da776b9 100644 --- a/lm_human_preference_details/summarize/ppo_left_padding.py +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -589,6 +589,7 @@ def repeat_generator(): accelerator.print("===training policy===") global_step = 0 + start_time = time.time() stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) @@ -864,6 +865,9 @@ def repeat_generator(): writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update) writer.add_scalar("ppo/lr", lrnow, update) writer.add_scalar("ppo/episode", global_step, update) + eps = int(global_step / (time.time() - start_time)) + writer.add_scalar("ppo/eps", eps, update) + accelerator.print("ppo/eps", eps, update) if args.reward.use_adaptive_kl: kl_ctl.update(mean_kl.item(), args.batch_size) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores @@ -879,9 +883,10 @@ def repeat_generator(): validation_generation_config, sampling=False, ) - eval_df.to_csv(f"runs/{run_name}/table.csv") - if accelerator.is_main_process and args.track: - wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + if accelerator.is_main_process: + eval_df.to_csv(f"runs/{run_name}/table.csv") + if args.track: + wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update) # save model if args.output_dir and args.num_train_epochs > 0: diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index a41bea4..5ed3524 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -127,7 +127,7 @@ class Args: """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" batch_size: Optional[int] = None """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" - local_eval_batch_size: int = 8 + local_eval_batch_size: int = 1 """per rank eval batch size""" # other args @@ -459,6 +459,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{eval_split}_{update}.csv") if args.track: wandb.log({f"samples/{eval_split}/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + del evaluate_df torch.cuda.empty_cache() norm_dataset = load_dataset(args.task.query_dataset, split="train") From 90c75f025c88fa5ddb5a21fccbc43cb0364cd3f2 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 14 Jan 2024 16:39:31 +0000 Subject: [PATCH 60/62] small refactor; remove unused var --- lm_human_preference_details/summarize/reward.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 5ed3524..a5de811 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -269,6 +269,13 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps 
args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) args.batch_size = int(args.local_batch_size * args.world_size) + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # load dataset dataset = load_dataset(args.label_dataset, split="train") @@ -288,7 +295,6 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): ], ) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) - eval_datasets = [] eval_dataloaders = {} for split in ["validation", "validation_cnndm"]: validation_dataset = load_dataset(args.label_dataset, split=split).flatten() @@ -309,7 +315,6 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): "policies", ], ) - eval_datasets.append(validation_dataset) eval_dataloaders[split] = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) accelerator.print("The number of samples in dataset", len(dataset)) @@ -345,13 +350,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): np.random.seed(local_seed) torch.manual_seed(local_seed) torch.backends.cudnn.deterministic = True - tokenizer = AutoTokenizer.from_pretrained( - args.base_model, - padding_side="right", - trust_remote_code=True, - ) - # we use the padding token manually but do not resize the token embedding of the model - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + model_config = AutoConfig.from_pretrained(args.base_model) configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout scalar_model_config = ScalarModelConfig( From 1d5b7dbc351f0ec27aa18d39e3e9acbdf3e41fa3 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 14 Jan 2024 16:40:05 +0000 Subject: [PATCH 61/62] evaluate test splits as well --- lm_human_preference_details/summarize/ppo.py | 45 ++++++++++--------- .../summarize/ppo_left_padding.py | 45 ++++++++++--------- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo.py b/lm_human_preference_details/summarize/ppo.py index f891bbb..c34c5c5 100644 --- a/lm_human_preference_details/summarize/ppo.py +++ b/lm_human_preference_details/summarize/ppo.py @@ -485,7 +485,7 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation base_config=model_config, hidden_size=model_config.hidden_size, ) - if not args.reward_model_path: + if len(args.reward_model_path) == 0: critic: PreTrainedModel = ScalarModel(scalar_model_config) reward_model: PreTrainedModel = ScalarModel(scalar_model_config) else: @@ -519,21 +519,24 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation validation_dataset = load_dataset(args.task.query_dataset, split="validation") dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) dataloader = DataLoader(dataset, batch_size=args.local_batch_size, shuffle=True) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + eval_dataloaders = {} + for split in ["validation", "test"]: + eval_dataset = load_dataset(args.task.query_dataset, split=split) + eval_dataset = 
eval_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + eval_dataloaders[split] = DataLoader(eval_dataset, batch_size=args.local_eval_batch_size) # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c torch.manual_seed(args.seed) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) torch.manual_seed(local_seed) # reset the local seed again + eval_dataloaders = {split: accelerator.prepare(eval_dataloader) for split, eval_dataloader in eval_dataloaders.items()} def repeat_generator(): while True: yield from dataloader iter_dataloader = iter(repeat_generator()) - validation_dataloader = accelerator.prepare(validation_dataloader) if args.deepspeed: import deepspeed @@ -586,6 +589,7 @@ def repeat_generator(): accelerator.print("===training policy===") global_step = 0 start_time = time.time() + eval_split = "validation" stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) @@ -607,15 +611,15 @@ def repeat_generator(): reward_model, accelerator.unwrap_model(model).policy, tokenizer, - validation_dataloader, + eval_dataloaders[eval_split], validation_generation_config, ) validation_score = eval_storage.score[0] if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: if accelerator.is_main_process: - eval_df.to_csv(f"runs/{run_name}/table_{global_step}.csv") + eval_df.to_csv(f"runs/{run_name}/{eval_split}_table_{global_step}.csv") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + wandb.log({f"samples/{eval_split}_query_responses": wandb.Table(dataframe=eval_df)}, step=update) else: try: print_rich_table(f"Sample Output at Step {update}", eval_df[:1], console) @@ -868,19 +872,20 @@ def repeat_generator(): torch.cuda.empty_cache() if args.run_eval: - eval_storage, eval_df = evaluate( - args, - reward_model, - accelerator.unwrap_model(model).policy, - tokenizer, - validation_dataloader, - validation_generation_config, - sampling=False, - ) - if accelerator.is_main_process: - eval_df.to_csv(f"runs/{run_name}/table.csv") - if args.track: - wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + for eval_split in eval_dataloaders: + eval_storage, eval_df = evaluate( + args, + reward_model, + accelerator.unwrap_model(model).policy, + tokenizer, + eval_dataloaders[eval_split], + validation_generation_config, + sampling=False, + ) + if accelerator.is_main_process: + eval_df.to_csv(f"runs/{run_name}/{eval_split}_table.csv") + if args.track: + wandb.log({f"eval/{eval_split}_query_responses": wandb.Table(dataframe=eval_df)}, step=update) # save model if args.output_dir and args.num_train_epochs > 0: diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py index da776b9..389ab22 100644 --- a/lm_human_preference_details/summarize/ppo_left_padding.py +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -489,7 +489,7 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation base_config=model_config, hidden_size=model_config.hidden_size, ) - if not args.reward_model_path: + if len(args.reward_model_path) == 0: critic: PreTrainedModel = ScalarModel(scalar_model_config) reward_model: 
PreTrainedModel = ScalarModel(scalar_model_config) else: @@ -523,21 +523,24 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation validation_dataset = load_dataset(args.task.query_dataset, split="validation") dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"]) dataloader = DataLoader(dataset, batch_size=args.local_batch_size, shuffle=True) - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + eval_dataloaders = {} + for split in ["validation", "test"]: + eval_dataset = load_dataset(args.task.query_dataset, split=split) + eval_dataset = eval_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) + eval_dataloaders[split] = DataLoader(eval_dataset, batch_size=args.local_eval_batch_size) # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c torch.manual_seed(args.seed) model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) torch.manual_seed(local_seed) # reset the local seed again + eval_dataloaders = {split: accelerator.prepare(eval_dataloader) for split, eval_dataloader in eval_dataloaders.items()} def repeat_generator(): while True: yield from dataloader iter_dataloader = iter(repeat_generator()) - validation_dataloader = accelerator.prepare(validation_dataloader) if args.deepspeed: import deepspeed @@ -590,6 +593,7 @@ def repeat_generator(): accelerator.print("===training policy===") global_step = 0 start_time = time.time() + eval_split = "validation" stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) @@ -611,15 +615,15 @@ def repeat_generator(): reward_model, accelerator.unwrap_model(model).policy, tokenizer, - validation_dataloader, + eval_dataloaders[eval_split], validation_generation_config, ) validation_score = eval_storage.score[0] if args.print_sample_output_freq > 0 and (update - 1) % args.print_sample_output_freq == 0: if accelerator.is_main_process: - eval_df.to_csv(f"runs/{run_name}/table_{global_step}.csv") + eval_df.to_csv(f"runs/{run_name}/{eval_split}_table_{global_step}.csv") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + wandb.log({f"samples/{eval_split}_query_responses": wandb.Table(dataframe=eval_df)}, step=update) else: try: print_rich_table(f"Sample Output at Step {update}", eval_df[:1], console) @@ -874,19 +878,20 @@ def repeat_generator(): torch.cuda.empty_cache() if args.run_eval: - eval_storage, eval_df = evaluate( - args, - reward_model, - accelerator.unwrap_model(model).policy, - tokenizer, - validation_dataloader, - validation_generation_config, - sampling=False, - ) - if accelerator.is_main_process: - eval_df.to_csv(f"runs/{run_name}/table.csv") - if args.track: - wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update) + for eval_split in eval_dataloaders: + eval_storage, eval_df = evaluate( + args, + reward_model, + accelerator.unwrap_model(model).policy, + tokenizer, + eval_dataloaders[eval_split], + validation_generation_config, + sampling=False, + ) + if accelerator.is_main_process: + eval_df.to_csv(f"runs/{run_name}/{eval_split}_table.csv") + if 
args.track: + wandb.log({f"eval/{eval_split}_query_responses": wandb.Table(dataframe=eval_df)}, step=update) # save model if args.output_dir and args.num_train_epochs > 0: From 828db07e5774db1152a97035c24e637e0f11470d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 14 Jan 2024 16:40:18 +0000 Subject: [PATCH 62/62] add DPO --- lm_human_preference_details/summarize/dpo.py | 580 +++++++++++++++++++ 1 file changed, 580 insertions(+) create mode 100644 lm_human_preference_details/summarize/dpo.py diff --git a/lm_human_preference_details/summarize/dpo.py b/lm_human_preference_details/summarize/dpo.py new file mode 100644 index 0000000..dcc4c7f --- /dev/null +++ b/lm_human_preference_details/summarize/dpo.py @@ -0,0 +1,580 @@ +import os +import random +import time +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from types import SimpleNamespace +from typing import List, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import gather_object +from datasets import load_dataset +from rich.console import Console +from rich.pretty import pprint +from rich.table import Table +from torch import optim +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + get_scheduler, +) + + +@dataclass +class LabelHParams: + type: Optional[str] = None + num_train: int = 92832 + num_labels: int = 2 + source: Optional[str] = None + + +# a patch +@dataclass +class TaskHParams: + # Query params + query_length: int = 512 + query_dataset: str = "cleanrl/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1705009345" + + query_format_str: Optional[str] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + query_truncate_field: Optional[str] = "post" + query_truncate_text: Optional[str] = "\n" + query_padding: Optional[str] = None # defaults to repeated spaces + query_pad_side: Optional[str] = "left" + + # Response params + response_length: int = 53 + + # Truncate response after the first occurrence of this token at or after index after when sampling. 
+ truncate_token: Literal["eos"] = "eos" + truncate_token_id: Optional[int] = None + truncate_after: int = 16 + penalty_reward_value: int = -1 + + # LM params + temperature: float = 0.01 + + +@dataclass +class Args: + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "tldr_summarize" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + cuda: bool = True + """Whether to use cuda if available.""" + run_name: Optional[str] = None + """a unique name of this run""" + load_from_cache_file: bool = False + """Whether to load data from the local cache file in `dataset.map`""" + push_to_hub: bool = False + "whether to upload the saved model to huggingface" + hf_entity: str = "" + "the user or org name of the model repository from the Hugging Face Hub" + deepspeed: bool = False + """Whether to use deepspeed to train the model""" + print_sample_output_freq: int = 220 + """How often to print sample output""" + run_eval: bool = False + """Whether to run evaluation""" + + # optimizer args + eps: float = 1e-5 + """the epsilon value for the optimizer""" + lr: float = 5e-6 + """the learning rate""" + optimizer: Literal["adam", "adamw"] = "adamw" + """Which optimizer to use""" + scheduler: str = "cosine" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_train_epochs: int = 1 + """Number of epochs to train""" + num_updates: Optional[int] = None + """The number of updates to train""" + gradient_accumulation_steps: int = 8 + """The number of gradient accumulation steps""" + local_micro_batch_size: Optional[int] = 1 + """The micro batch size per GPU (HF's `per_device_train_batch_size`)""" + total_episodes: Optional[int] = None + """The total number of episodes in the dataset""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + local_eval_batch_size: int = 1 + """per rank eval batch size""" + + # other args + base_model: str = "EleutherAI/pythia-160m" + """the name of the pretrained model to use""" + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) + """Which layers to apply dropout to""" + output_dir: str = "models/dpo_policy_model" + """Where to save the model""" + label_dataset: str = "cleanrl/summarize_from_feedback_oai_preprocessing_1705009345" + """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" + ipo: bool = False + """Whether to use IPO loss https://arxiv.org/abs/2310.12036""" + label_smoothing: float = 0.0 + """Label smoothing for DPO (Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 
7 of https://arxiv.org/pdf/2305.18290.pdf))""" + beta: float = 0.05 + """The beta value for DPO""" + task: TaskHParams = field(default_factory=TaskHParams) + label: LabelHParams = field(default_factory=LabelHParams) + + +# taken from https://github.com/microsoft/DeepSpeedExamples/blob/737c6740bec38b77a24a59135b6481a53d566b38/applications/DeepSpeed-Chat/training/utils/model/model_utils.py#L20C1-L26C52 +def configure_dropout(model_config, dropout_layer_keys, dropout): + if dropout is not None: + for key in dropout_layer_keys: + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table: + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.rule(f"[bold red]{title}") + console.print(table) + + +def forward(model, query_responses, labels, mb_best, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + output = model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + labels = labels[:, 1:].clone() + logits = output.logits[:, :-1, :] + loss_mask = (labels != tokenizer.pad_token_id) + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + all_logps = (per_token_logps * loss_mask).sum(-1) + chosen_logps = all_logps.view(-1, args.label.num_labels).gather(1, mb_best.view(-1, 1)).view(-1) + rejected_logps = all_logps.view(-1, args.label.num_labels).gather(1, (1 - mb_best).view(-1, 1)).view(-1) + return chosen_logps, rejected_logps + + +def generate(lm_backbone, queries, tokenizer, generation_config): + """generate in a way that does not affect padding tokens""" + context_length = queries.shape[1] + attention_mask = queries != tokenizer.pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. TODO: why does generation collapse with this? + generation_config=generation_config, + return_dict_in_generate=True, + ) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1) + + +def first_true_indices(bools, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. 
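+    Example (illustrative, not part of the original patch): first_true_indices(torch.tensor([[False, False, True]])) returns tensor([2]); a row with no True returns the row length, here 3.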
+ """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def truncate_response(args, tokenizer, responses): + trunc_idxs = first_true_indices(responses == args.task.truncate_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [args.task.response_length] + idxs = torch.arange(args.task.response_length, device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, tokenizer.pad_token_id) + return postprocessed_responses + + +def evaluate_rm(args: Args, accelerator, tokenizer, model, ref_model, dataloader): + model.eval() + with torch.no_grad(): + items = defaultdict(list) + for data in tqdm(dataloader): + query_responses = torch.cat( + [data["query_response0_token"].unsqueeze(1), data["query_response1_token"].unsqueeze(1)], dim=1 + ).flatten(0, 1) + labels = torch.cat( + [data["query_response0_token_response_label"].unsqueeze(1), data["query_response1_token_response_label"].unsqueeze(1)], + dim=1, + ).flatten(0, 1) + mb_best = data["choice"] + chosen_logps, rejected_logps = forward(model, query_responses, labels, mb_best, tokenizer) + ref_chosen_logps, ref_rejected_logps = forward(ref_model, query_responses, labels, mb_best, tokenizer) + reward_preferred = args.beta * (chosen_logps - ref_chosen_logps).detach() + reward_rejected = args.beta * (rejected_logps - ref_rejected_logps).detach() + accuracy = reward_preferred > reward_rejected + print(accuracy.float().mean()) + for k in data: + data[k] = gather_object(data[k]) + for i in range(len(accuracy)): + items["query"].append(tokenizer.decode(data["query_token"][i], skip_special_tokens=True)) + items["response0"].append(tokenizer.decode(data["response0_token"][i])) + items["response1"].append(tokenizer.decode(data["response1_token"][i])) + items["batch"].append(data["batch"][i]) + items["split"].append(data["split"][i]) + items["confidence"].append(data["extra.confidence"][i].item()) + items["choice"].append(data["choice"][i].item()) + items["policies"].append(data["policies"][i]) + items["response0_policy"].append(data["response0_policy"][i]) + items["response1_policy"].append(data["response1_policy"][i]) + items["accuracy"].append(accuracy[i].item()) + model.train() + return pd.DataFrame(items) + + + +@dataclass +class EvalStorage: + query_token: List[str] = field(default_factory=list) + postprocessed_response_token: List[str] = field(default_factory=list) + reference_response_token: List[str] = field(default_factory=list) + score: List[float] = field(default_factory=list) + reference_score: List[float] = field(default_factory=list) + + query: List[str] = field(default_factory=list) + postprocessed_response: List[str] = field(default_factory=list) + reference_response: List[str] = field(default_factory=list) + + +def evaluate_policy(args: Args, model, tokenizer, dataloader, generation_config, sampling=True): + eval_storage = EvalStorage() + with torch.no_grad(): + for data in tqdm(dataloader): + queries = data["query_token"] + reference_response_token = data["reference_response_token"] + context_length = queries.shape[1] + query_responses = generate( + model, + queries, + tokenizer, + generation_config, + ) + responses = query_responses[:, context_length:] + postprocessed_responses = truncate_response(args, tokenizer, responses) + eval_storage.query_token.extend(queries) + 
eval_storage.reference_response_token.extend(reference_response_token) + eval_storage.postprocessed_response_token.extend(postprocessed_responses) + if sampling: + break + + eval_storage.query = tokenizer.batch_decode(eval_storage.query_token, skip_special_tokens=True) + eval_storage.reference_response = tokenizer.batch_decode(eval_storage.reference_response_token) + eval_storage.postprocessed_response = tokenizer.batch_decode( + eval_storage.postprocessed_response_token, skip_special_tokens=True + ) + # eval_score = torch.cat(eval_storage.score).float().cpu().numpy().tolist() + # eval_reference_score = torch.cat(eval_storage.reference_score).float().cpu().numpy().tolist() + eval_df = pd.DataFrame( + { + "query": gather_object(eval_storage.query), + "postprocessed_response": gather_object(eval_storage.postprocessed_response), + "reference_responses": gather_object(eval_storage.reference_response), + # "scores": gather_object(eval_score), + # "reference_scores": gather_object(eval_reference_score), + } + ) + return eval_storage, eval_df + +# def train(args: Args): +if __name__ == "__main__": + args = tyro.cli(Args) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + local_seed = args.seed + accelerator.process_index * 100003 # Prime + args.world_size = accelerator.num_processes + args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.local_micro_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + + tokenizer = AutoTokenizer.from_pretrained( + args.base_model, + padding_side="right", + trust_remote_code=True, + ) + # we use the padding token manually but do not resize the token embedding of the model + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if args.task.truncate_token == "eos": + args.task.truncate_token_id = tokenizer.eos_token_id + + # load dataset + dataset = load_dataset(args.label_dataset, split="train") + dataset = dataset.shuffle(seed=local_seed) + dataset = dataset.select(range(args.label.num_train)) + dataset = dataset.with_format( + "torch", + columns=[ + "query_token", + "choice", + "response0_token", + "query_response0_token", + "query_response0_token_response_label", + "response1_token", + "query_response1_token", + "query_response1_token_response_label", + "batch", + "split", + ], + ) + dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) + eval_datasets = [] + eval_dataloaders = {} + for split in ["validation", "validation_cnndm"]: + validation_dataset = load_dataset(args.label_dataset, split=split).flatten() + validation_dataset = validation_dataset.with_format( + "torch", + columns=[ + "query_token", + "choice", + "response0_token", + "query_response0_token", + "query_response0_token_response_label", + "response1_token", + "query_response1_token", + "query_response1_token_response_label", + "batch", + "split", + "extra.confidence", + "response0_policy", + "response1_policy", + "policies", + ], + ) + eval_datasets.append(validation_dataset) + eval_dataloaders[split] = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) + accelerator.print("The number of samples in dataset", len(dataset)) + + sft_validation_dataset = load_dataset(args.task.query_dataset, split="validation") + sft_validation_dataset = sft_validation_dataset.with_format("torch", columns=["query_token", 
"reference_response_token", "query_reference_response_token_response_label"]) + sft_validation_dataloader = DataLoader(sft_validation_dataset, batch_size=args.local_eval_batch_size) + + args.total_episodes = len(dataset) + args.num_updates = args.total_episodes // args.batch_size + + console = Console(force_terminal=True) + run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + writer = SimpleNamespace() # dummy writer + writer.add_scalar = lambda x, y, z: None + if accelerator.is_main_process: + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=asdict(args), + name=run_name, + save_code=True, + ) + # file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + # wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + pprint(args) + device = accelerator.device + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + # model_config = AutoConfig.from_pretrained(args.base_model) + # configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout + model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + args.base_model, + # config=model_config, + trust_remote_code=True, + ) + model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + model.generation_config.pad_token_id = None # generate tokens without truncation / padding + ref_model = AutoModelForCausalLM.from_pretrained(args.base_model) + # if accelerator.is_main_process: + # pprint(model_config) + if args.optimizer == "adam": + optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps) + elif args.optimizer == "adamw": + optimizer = optim.AdamW(model.parameters(), lr=args.lr, eps=args.eps) + scheduler = get_scheduler( + args.scheduler, + optimizer=optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_updates * args.num_train_epochs, + ) + + if args.deepspeed: + deepspeed_states = AcceleratorState().deepspeed_plugin + deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size + + ref_model = ref_model.to(device) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + eval_dataloaders = {split: accelerator.prepare(eval_dataloader) for split, eval_dataloader in eval_dataloaders.items()} + sft_validation_dataloader = accelerator.prepare(sft_validation_dataloader) + # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 + validation_generation_config = GenerationConfig( + max_new_tokens=args.task.response_length, + min_new_tokens=args.task.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training model===") + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device) + reward_margins = 
torch.zeros((args.gradient_accumulation_steps,), device=device) + model.train() + gradient_accumulation_idx = 0 + global_step = 0 + update = 0 + for epoch in range(args.num_train_epochs): + accelerator.print(f"epoch: {epoch}") + for data in dataloader: + update += 1 + global_step += args.micro_batch_size + query_responses = torch.cat( + [data["query_response0_token"].unsqueeze(1), data["query_response1_token"].unsqueeze(1)], dim=1 + ).flatten(0, 1) + labels = torch.cat( + [data["query_response0_token_response_label"].unsqueeze(1), data["query_response1_token_response_label"].unsqueeze(1)], + dim=1, + ).flatten(0, 1) + mb_best = data["choice"] + with torch.no_grad(): + ref_chosen_logps, ref_rejected_logps = forward(ref_model, query_responses, labels, mb_best, tokenizer) + with accelerator.accumulate(model): + chosen_logps, rejected_logps = forward(model, query_responses, labels, mb_best, tokenizer) + + pi_logratios = chosen_logps - rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps + logits = pi_logratios - ref_logratios # also known as h_{\pi_\theta}^{y_w,y_l} + if args.ipo: + loss = (logits - 1/(2 * args.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf + else: + loss = -F.logsigmoid(args.beta * logits) * (1 - args.label_smoothing) - F.logsigmoid(-args.beta * logits) * args.label_smoothing + loss = loss.mean() # average the per-example losses into a scalar before backward + reward_preferred = args.beta * (chosen_logps - ref_chosen_logps).detach() + reward_rejected = args.beta * (rejected_logps - ref_rejected_logps).detach() + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + losses[gradient_accumulation_idx] = loss + accuracies[gradient_accumulation_idx] = (reward_preferred > reward_rejected).float().mean() + reward_preferreds[gradient_accumulation_idx] = reward_preferred.mean() + reward_rejecteds[gradient_accumulation_idx] = reward_rejected.mean() + reward_margins[gradient_accumulation_idx] = (reward_preferred - reward_rejected).mean() + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps + if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0: + scheduler.step() + train_accuracy = accelerator.gather(accuracies).mean().item() + writer.add_scalar("train/rm/loss", accelerator.gather(losses).mean().item(), global_step) + writer.add_scalar("train/rm/accuracy", train_accuracy, global_step) + writer.add_scalar( + "train/rm/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step + ) + writer.add_scalar("train/rm/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step) + writer.add_scalar("train/rm/lr", scheduler.get_last_lr()[0], global_step) + accelerator.print( + f"{train_accuracy=}, {scheduler.get_last_lr()=}, {optimizer.param_groups[0]['lr']=}, {update=}" + ) + + if args.run_eval: + _, evaluate_df = evaluate_policy(args, model, tokenizer, sft_validation_dataloader, validation_generation_config, sampling=False) + if accelerator.is_main_process: + evaluate_df.to_csv(f"runs/{run_name}/table.csv") + if args.track: + wandb.log({"eval/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + for eval_split in eval_dataloaders: + evaluate_df = evaluate_rm(args, accelerator, tokenizer, model, ref_model, eval_dataloaders[eval_split]) + for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/split/{split}", row["accuracy"], global_step) + 
accelerator.print(f"eval/rm/{eval_split}/accuracy/split/{split}: {row['accuracy']}") + for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/batch/{batch}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/batch/{batch}: {row['accuracy']}") + for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/confidence/{confi}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/confidence/{confi}: {row['accuracy']}") + writer.add_scalar(f"eval/rm/{eval_split}/accuracy", evaluate_df["accuracy"].mean(), global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy: {evaluate_df['accuracy'].mean()}") + if accelerator.is_main_process: + os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{eval_split}_{update}.csv") + if args.track: + wandb.log({f"samples/{eval_split}/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + del evaluate_df + torch.cuda.empty_cache() + + # save model + if args.output_dir and args.num_train_epochs > 0: + os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) + time_tensor = torch.tensor([int(time.time())], device=device) + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir, repo_id=repo_id) + if args.push_to_hub: + tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}") + + unwrapped: PreTrainedModel = accelerator.unwrap_model(model) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + safe_serialization=False, + repo_id=repo_id, + ) + if args.push_to_hub: + unwrapped.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}", safe_serialization=False) + +# if __name__ == "__main__": +# args = tyro.cli(Args) +# train(args)