refine compute_accuracy and proper test
qgallouedec committed Dec 10, 2024
1 parent f164711 commit 4572a21
Showing 3 changed files with 87 additions and 18 deletions.
8 changes: 1 addition & 7 deletions tests/test_reward_trainer.py
@@ -17,12 +17,11 @@

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
from trl.trainer import compute_accuracy
from trl.trainer.reward_trainer import _tokenize


@@ -37,11 +36,6 @@ def setUp(self):
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def test_accuracy_metrics(self):
        dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0]))
        accuracy = compute_accuracy(dummy_eval_predictions)
        self.assertEqual(accuracy["accuracy"], 0.5)

    def test_preprocessing_conversational(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
71 changes: 70 additions & 1 deletion tests/test_utils.py
@@ -14,13 +14,15 @@

import unittest

import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl.trainer.model_config import ModelConfig
from trl import ModelConfig
from trl.trainer import compute_accuracy
from trl.trainer.utils import (
    DataCollatorForChatML,
    batch_generation,
@@ -312,3 +314,70 @@ def test_single_batch_generation(self):
        self.assertGreater(max_length_query, context_length)
        self.assertEqual(query_responses.shape, (bs, max_length_query))
        self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size))


class TestComputeAccuracy(unittest.TestCase):
    def test_token_classification_task(self):
        eval_pred = (
            np.array(
                [
                    [[0.1, 0.9], [0.8, 0.2]],  # Batch 1
                    [[0.3, 0.7], [0.6, 0.4]],  # Batch 2
                ]
            ),
            np.array([[0, 1], [1, 0]]),
        )
        expected_accuracy = 0.5  # 2 matches, 2 mismatches
        result = compute_accuracy(eval_pred)
        self.assertAlmostEqual(result["accuracy"], expected_accuracy)

    def test_token_classification_task_with_ignored_tokens_0(self):
        eval_pred = (
            np.array(
                [
                    [[0.1, 0.9], [0.8, 0.2]],  # Batch 1
                    [[0.3, 0.7], [0.6, 0.4]],  # Batch 2
                ]
            ),
            np.array([[1, 0], [1, -100]]),
        )
        expected_accuracy = 1.0  # All non-ignored tokens match
        result = compute_accuracy(eval_pred)
        self.assertAlmostEqual(result["accuracy"], expected_accuracy)

    def test_token_classification_task_with_ignored_tokens_1(self):
        eval_pred = (
            np.array(
                [
                    [[0.1, 0.9], [0.8, 0.2]],  # Batch 1
                    [[0.3, 0.7], [0.6, 0.4]],  # Batch 2
                ]
            ),
            np.array([[1, 1], [0, -100]]),
        )
        expected_accuracy = 1 / 3  # 1 match, 2 mismatches, 1 ignored
        result = compute_accuracy(eval_pred)
        self.assertAlmostEqual(result["accuracy"], expected_accuracy)

    def test_rewards_comparison_task(self):
        eval_pred = (
            np.array(
                [
                    [0.9, 0.1],  # Batch 1
                    [0.6, 0.4],  # Batch 2
                    [0.5, 0.5],  # Batch 3 (equal)
                ]
            ),
            np.array([0, 1, 1]),
        )
        expected_accuracy = 0.5  # 1 match, 1 mismatch, 1 equal (ignored)

        with self.assertWarns(UserWarning) as cm:
            result = compute_accuracy(eval_pred)

        self.assertAlmostEqual(result["accuracy"], expected_accuracy)
        expected_warning = (
            "There are 1 out of 3 instances where the predictions for both options are equal. "
            "These instances are ignored in the accuracy computation."
        )
        self.assertEqual(str(cm.warning), expected_warning)
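
For readers who want to poke at the tie handling outside the test suite, here is a minimal standalone sketch (not part of this commit) that mirrors test_rewards_comparison_task above; the reward values and labels are arbitrary illustrations.

import warnings

import numpy as np

from trl.trainer import compute_accuracy

# Three preference pairs: one row of (reward_chosen, reward_rejected) per example.
predictions = np.array([[0.9, 0.1], [0.6, 0.4], [0.5, 0.5]])
labels = np.array([0, 1, 1])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = compute_accuracy((predictions, labels))

print(result)                  # {'accuracy': 0.5}: the tied third row is ignored
print(str(caught[0].message))  # warns that 1 out of 3 instances had equal predictions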
26 changes: 16 additions & 10 deletions trl/trainer/utils.py
@@ -37,6 +37,7 @@
from transformers import (
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizerBase,
    TrainerState,
@@ -756,32 +757,37 @@ def get_global_statistics(
    return global_mean.to(device), global_var.to(device), count.item()


def compute_accuracy(eval_pred) -> dict[str, float]:
def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]:
    predictions, labels = eval_pred
    if predictions.ndim == 3:
        # Token classification task.
        # Token classification task. Shapes are (batch_size, seq_len, num_labels) and (batch_size, seq_len)
        # Used to compute the accuracy in the stepwise_reward_trainer.
        predictions = np.argmax(predictions, axis=2)

        # Flatten the predictions and labels to remove the ignored tokens.
        predictions = np.array(
            [p for prediction, label in zip(predictions, labels) for (p, lbl) in zip(prediction, label) if lbl != -100]
        )
        labels = np.array([lbl for label in labels for lbl in label if lbl != -100])

    else:
        # Here, predictions is rewards_chosen and rewards_rejected.
        # Here, predictions is rewards_chosen and rewards_rejected. Shapes are (batch_size, 2) and (batch_size,)
        # We want to see how much of the time rewards_chosen > rewards_rejected.
        equal_predictions_count = np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum()
        equal_mask = predictions[:, 0] == predictions[:, 1]
        equal_predictions_count = int(equal_mask.sum())

        if equal_predictions_count > 0:
            warnings.warn(
                f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions for "
                "both options are equal. As a consequence the accuracy can be misleading.",
                f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions "
                "for both options are equal. These instances are ignored in the accuracy computation.",
                UserWarning,
            )
        if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
            warnings.warn(
                f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
            )

        # Filter out equal predictions
        predictions = predictions[~equal_mask]
        labels = labels[~equal_mask]

        # Use the remaining predictions for accuracy calculation
        predictions = np.argmax(predictions, axis=1)

    accuracy = np.array(predictions == labels, dtype=float).mean().item()
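
A companion sketch (likewise not part of the diff) exercises the token-classification branch that the new shape comments describe: per-token logits of shape (batch_size, seq_len, num_labels), labels of shape (batch_size, seq_len), and -100 marking positions to drop, mirroring test_token_classification_task_with_ignored_tokens_0 above.

import numpy as np

from trl.trainer import compute_accuracy

# Per-token logits for 2 sequences of 2 tokens over 2 labels: shape (2, 2, 2).
logits = np.array(
    [
        [[0.1, 0.9], [0.8, 0.2]],
        [[0.3, 0.7], [0.6, 0.4]],
    ]
)
# Labels of shape (2, 2); positions labelled -100 are removed before scoring.
labels = np.array([[1, 0], [1, -100]])

print(compute_accuracy((logits, labels)))  # {'accuracy': 1.0}: all non-ignored tokens match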
