
Commit

0.1.5 rc2
Signed-off-by: ssbuild <[email protected]>
ssbuild committed May 12, 2023
1 parent 2aa6114 commit 3843829
Showing 4 changed files with 14 additions and 10 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -18,8 +18,9 @@
 ## update
 - <strong>2023-05-10</strong>
   - fix lora v2 modules_to_save for custom additional trainable modules
-  - 0.1.5 rc0 add reward ppo llm, full training: [llm_rlhf_training](https://github.com/ssbuild/llm_rlhf_training)
-  - 0.1.5 rc1 add reward ppo chatglm, full training: [chatglm_rlhf_training](https://github.com/ssbuild/chatglm_rlhf_training)
+  - 0.1.5 rc0 add reward ppo llm full training: [llm_rlhf_training](https://github.com/ssbuild/llm_rlhf_training)
+  - 0.1.5 rc1 add reward ppo chatglm full training: [chatglm_rlhf_training](https://github.com/ssbuild/chatglm_rlhf_training)
+  - 0.1.5 rc2 add reward ppo chatyuan full training: [chatyuan_rlhf_training](https://github.com/ssbuild/chatyuan_rlhf_training)
 - <strong>2023-05-02</strong>
   - 0.1.4 add prompt_tuning, p_tuning, prefix_tuning, adaption_prompt

2 changes: 1 addition & 1 deletion nlp/models/rl/modeling.py
@@ -51,7 +51,7 @@ def forward(self, *args, **inputs) -> Seq2SeqLMOutputWithValue:
         if not return_dict:
             inputs.update({"return_dict": True})
         inputs["output_hidden_states"] = True
-        outputs = self.forward(**inputs)
+        outputs = self.model(**inputs)
         last_hidden_state = outputs.decoder_hidden_states[-1]
         value = self.score(last_hidden_state).squeeze(-1)
         if not return_dict:
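The one-line change above fixes a self-recursion bug: inside `forward`, calling `self.forward(**inputs)` re-enters the same method instead of running the wrapped language model, so the call now goes through the held `self.model`. Below is a minimal sketch of that delegation pattern; the class name and constructor arguments are illustrative assumptions, not the repository's actual classes.

```python
import torch.nn as nn

class Seq2SeqLMWithValueHead(nn.Module):  # hypothetical wrapper, for illustration only
    """Wraps a seq2seq LM and adds a scalar value head over its decoder states."""

    def __init__(self, base_model: nn.Module, hidden_size: int):
        super().__init__()
        self.model = base_model                  # the wrapped seq2seq language model
        self.score = nn.Linear(hidden_size, 1)   # value head

    def forward(self, **inputs):
        inputs["return_dict"] = True
        inputs["output_hidden_states"] = True
        # Delegate to the wrapped model; calling self.forward here would recurse forever.
        outputs = self.model(**inputs)
        last_hidden_state = outputs.decoder_hidden_states[-1]
        value = self.score(last_hidden_state).squeeze(-1)
        return outputs, value
```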
15 changes: 9 additions & 6 deletions nlp/rl/ppo/ppo_trainer.py
@@ -705,8 +705,13 @@ def decode(
             str_prompts.append(str_prompt)
             str_outputs.append(str_output)
 
+
+
             if self.ppo_config.model_arch_type == "seq2seq":
-                sample = str_prompt + self.tokenizer.sep_token + str_output
+                if hasattr(self.tokenizer,'_sep_token') and self.tokenizer._sep_token is not None:
+                    sample = str_prompt + self.tokenizer.sep_token + str_output
+                else:
+                    sample = str_prompt + str_output
             elif self.ppo_config.model_arch_type == "prefixlm":
                 sample = str_prompt + self.tokenizer.gmask_token + self.tokenizer.bos_token + str_output
             else:
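The decode change above guards against tokenizers that define no separator token (likely the T5-style tokenizer used by ChatYuan, given this release candidate): concatenating `None` into the sample string would raise a `TypeError`, so the code now falls back to joining prompt and output directly. A standalone sketch of the same guard, with an assumed helper name:

```python
def build_sample(str_prompt: str, str_output: str, tokenizer) -> str:
    """Join prompt and output, inserting the tokenizer's separator only if one exists."""
    # Tokenizers without a separator expose _sep_token as None (or not at all);
    # str_prompt + None + str_output would raise a TypeError.
    if getattr(tokenizer, "_sep_token", None) is not None:
        return str_prompt + tokenizer.sep_token + str_output
    return str_prompt + str_output
```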
@@ -766,7 +771,6 @@ def make_experience(self, model, ref_model,**kwargs): # noqa:
             device = samples.device
 
             prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
-
             padded_samples = pad_across_processes(
                 samples,world_size, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
             )
@@ -785,13 +789,11 @@ def make_experience(self, model, ref_model,**kwargs): # noqa:
                 )
 
                 rollout_score_time = time()
-
                 all_scores = self.reward_fn(
                     samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata
                 )
                 all_scores = all_scores.clone().detach().float().to(device)
                 stats["rollout/time/score"] = time() - rollout_score_time
-
                 all_scores = list(all_scores.reshape(world_size, -1).unbind())
             else:
                 all_scores = None
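For context, the block above scores the gathered rollout strings with `reward_fn` on the main process and then splits the flat score tensor back into one chunk per worker via `reshape(world_size, -1).unbind()`. A toy, framework-free illustration of that split (values are made up):

```python
import torch

world_size = 4                                       # number of data-parallel workers
all_scores = torch.arange(12, dtype=torch.float32)   # 3 rollouts per worker, gathered into one flat tensor
per_process = list(all_scores.reshape(world_size, -1).unbind())
print([chunk.tolist() for chunk in per_process])
# [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]
```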
@@ -847,11 +849,12 @@ def make_experience(self, model, ref_model,**kwargs): # noqa:
                 decoder_attention_mask = sample_outputs.not_equal(self.tokenizer.pad_token_id)
                 decoder_attention_mask[:, 0] = 1
                 with torch.no_grad():
-                    outputs = model(
+                    outputs = model.forward_logits_values(
                         input_ids=prompt_tensors,
                         attention_mask=attention_mask,
                         decoder_input_ids=sample_outputs,
                         decoder_attention_mask=decoder_attention_mask,
+                        return_dict=True,
                     )
                     logits = outputs.logits
                     values = outputs.value
@@ -864,7 +867,7 @@ def make_experience(self, model, ref_model,**kwargs): # noqa:
                             return_dict=True,
                         ).logits
                     else:
-                        ref_logits = ref_model(
+                        ref_logits = ref_model.forward_logits_values(
                             input_ids=prompt_tensors,
                             attention_mask=attention_mask,
                             decoder_input_ids=sample_outputs,
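Both hunks above swap the bare `model(...)` / `ref_model(...)` calls for `forward_logits_values(...)`, which, judging from the later `outputs.logits` and `outputs.value` accesses, is expected to return the language-model logits together with per-token value estimates from a single forward pass. The sketch below shows one way such a helper could look; the dataclass and function are illustrative assumptions, not the repository's actual implementation.

```python
from dataclasses import dataclass
import torch

@dataclass
class LogitsValueOutput:
    logits: torch.Tensor   # (batch, seq_len, vocab_size) language-model logits
    value: torch.Tensor    # (batch, seq_len) value-head estimates

@torch.no_grad()
def forward_logits_values(model, value_head, **inputs) -> LogitsValueOutput:
    """Run one seq2seq forward pass and return both logits and value estimates."""
    inputs["return_dict"] = True
    inputs["output_hidden_states"] = True
    out = model(**inputs)
    value = value_head(out.decoder_hidden_states[-1]).squeeze(-1)
    return LogitsValueOutput(logits=out.logits, value=value)
```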
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
 ignore = ['test','tests']
 setup(
     name='deep_training',
-    version='0.1.5rc1',
+    version='0.1.5rc2',
     description='an easy training architecture',
     long_description='torch_training: https://github.com/ssbuild/deep_training.git',
     license='Apache License 2.0',
