Merge pull request #62 from ell-hol/fix/ppo-accelerate
fix ppo accelerate
lucidrains authored Dec 24, 2024
2 parents 6b02ee3 + 1226210 commit 183b3a5
Showing 3 changed files with 78 additions and 12 deletions.
66 changes: 66 additions & 0 deletions examples.py
@@ -0,0 +1,66 @@
import torch
from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device

# load your pretrained palm

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).to(device)


# load your pretrained reward model

reward_model = RewardModel(
    palm,
    num_binned_output = 5
).to(device)

# Train your reward model on mock data:
# mock data

seq = torch.randint(0, 20000, (1, 1024)).to(device)
prompt_mask = torch.zeros(1, 1024).bool().to(device) # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).to(device)

# train
loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
accelerator.backward(loss)

# after much training
reward = reward_model(seq, prompt_mask = prompt_mask)


# ready your list of prompts for reinforcement learning

prompts = torch.randint(0, 256, (1, 512)).to(device) # 1 prompt

# pass it all to the trainer and train

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

accelerator.print("Training")
trainer.train(
    num_episodes = 1,
    max_timesteps = 1,
    update_timesteps = 1,
    max_batch_size = 256,
    max_seq_len = 2048,
    eos_token = None,
    temperature = 1.
)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one
accelerator.print("Generating answer")
answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
accelerator.print(f"answer: {answer}")
4 changes: 2 additions & 2 deletions palm_rlhf_pytorch/ppo.py
@@ -1,7 +1,7 @@
import math
from pathlib import Path
import copy
- from tqdm import tqdm
+ from accelerate.utils.tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange
@@ -611,7 +611,7 @@ def train(
          prompt_mask,
          action_logits,
          value
-     ) = self.actor_critic.generate(
+     ) = self.actor_critic.module.generate(
          rearrange(state, 'n -> 1 n'),
          max_seq_len = max_seq_len,
          eos_token = eos_token,
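The one-line ppo.py change is the core of the fix: once accelerator.prepare wraps the actor-critic in DistributedDataParallel, only forward is proxied, so a custom method such as generate has to be reached through the wrapper's .module attribute. Below is a small self-contained sketch of that pattern; the TinyActorCritic class is an illustrative stand-in rather than code from this repo, and accelerator.unwrap_model is shown as an alternative that also works in single-process runs where no wrapper is added.

import torch
import torch.nn as nn
from accelerate import Accelerator

class TinyActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

    @torch.no_grad()
    def generate(self, x):
        # stand-in for the real autoregressive sampling loop
        return self.forward(x).argmax(dim = -1)

accelerator = Accelerator()
model = accelerator.prepare(TinyActorCritic())

x = torch.randn(1, 8, device = accelerator.device)

# what this commit does: reach through the DDP wrapper explicitly
# (assumes a distributed launch, where the wrapper is present)
# out = model.module.generate(x)

# alternative that works with or without the wrapper:
out = accelerator.unwrap_model(model).generate(x)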
20 changes: 10 additions & 10 deletions train.py
@@ -1,6 +1,6 @@
import gzip
import random
- import tqdm
+ from accelerate.utils.tqdm import tqdm
import numpy as np

import torch
@@ -39,7 +39,7 @@ def decode_tokens(tokens):

# accelerator

- accelerator = Accelerator()
+ accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUMULATE_EVERY)
device = accelerator.device

# instantiate palm
@@ -87,18 +87,18 @@ def __len__(self):

# training

- for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
+ for i in tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
      model.train()

-     for _ in range(GRADIENT_ACCUMULATE_EVERY):
+     with accelerator.accumulate(model):
          loss = model(next(train_loader), return_loss = True)
          accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

-     accelerator.print(f"training loss: {loss.item()}")
-     accelerator.clip_grad_norm_(model.parameters(), 0.5)

-     optim.step()
-     optim.zero_grad()
+         accelerator.print(f"training loss: {loss.item()}")
+         accelerator.clip_grad_norm_(model.parameters(), 0.5)
+         optim.step()
+         optim.zero_grad()

      if i % VALIDATE_EVERY == 0:
          model.eval()
@@ -112,6 +112,6 @@ def __len__(self):
prime = decode_tokens(inp)
accelerator.print(f"%s \n\n %s", (prime, "*" * 100))

- sample = model.generate(GENERATE_LENGTH, inp[None, ...])
+ sample = model.module.generate(GENERATE_LENGTH, inp[None, ...])
output_str = decode_tokens(sample[0])
accelerator.print(output_str, "\n")
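The train.py changes swap the hand-rolled accumulation loop for accelerate's built-in gradient accumulation (and pull tqdm from accelerate's utilities, which is intended to draw the progress bar only on the main process). The accumulation count is declared once on the Accelerator, and inside accelerator.accumulate(model) gradient synchronization is skipped on intermediate micro-batches while the prepared optimizer's step() and zero_grad() only take effect when gradients are actually synchronized, so the loop body can call them every iteration. A generic minimal sketch of that pattern follows, with an illustrative linear model, random data, and step counts that are not from this repo.

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator

accum_steps = 4  # illustrative; train.py uses GRADIENT_ACCUMULATE_EVERY
accelerator = Accelerator(gradient_accumulation_steps = accum_steps)

model = nn.Linear(16, 1)
optim = torch.optim.SGD(model.parameters(), lr = 1e-2)
model, optim = accelerator.prepare(model, optim)

for step in range(16):
    x = torch.randn(8, 16, device = accelerator.device)
    y = torch.randn(8, 1, device = accelerator.device)

    # inside accumulate(), gradients build up across micro-batches; the prepared
    # optimizer's step() / zero_grad() only fire on the synchronizing iteration
    with accelerator.accumulate(model):
        loss = F.mse_loss(model(x), y)
        accelerator.backward(loss)
        optim.step()
        optim.zero_grad()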
