diff --git a/README.md b/README.md index a6d9985..a527444 100644 --- a/README.md +++ b/README.md @@ -61,13 +61,13 @@ To fully configure BERGEN, please read our [configuration guide](documentation/c Run the evaluation script to calculate LLMEval metrics and print the results: ```bash -python3 eval.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm vllm_SOLAR-107B +python3 evaluate.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm vllm_SOLAR-107B #parse all the experiments files into a panda dataframe python print_results.py --folder experiments/ --format=tiny ``` -For more evaluation options and details, refer to the [Evaluation section](documentation/evaluations.md) in the complete documentation. +Bergen also offers the possiblity to run pairwise comparisons using an LLM as judge. For more evaluation options and details, refer to the [Evaluation section](documentation/evaluations.md) in the complete documentation. ## RAG Baselines Bergen provides results for several models and many datasets aiming to **provide strong baselines**. On the important datasets for RAG, the match metric is given by this table (see more in our paper): diff --git a/config/evaluator/default_multi_qa.yaml b/config/evaluator/default_multi_qa.yaml index 0a9cc80..12ea3ea 100644 --- a/config/evaluator/default_multi_qa.yaml +++ b/config/evaluator/default_multi_qa.yaml @@ -7,4 +7,11 @@ output_options: prompt: system: f"You are an evaluation tool. Answer with one of \n {self.rubrik_section}." user: f"Here is a question, a golden answer and an AI-generated answer. Can you judge whether the AI-generated answer is correct according to the question and golden answer, simply answer with one of {self.rubrik_section}.\n Question:\ {question}. \nGolden answer:\ {answer} \n Generated answer:\ {prediction}" - user_without_system: f"You are an evaluation tool. Just answer as following {self.rubrik_section}. Here is a question, a golden answer and an AI-generated answer. Judge whether the AI-generated answer is correct according to the question and golden answer, answer with {self.rubrik_section}.\nQuestion:\ {question}.\nGolden answer:\ {answer}\nGenerated answer:\ {prediction}" \ No newline at end of file + user_without_system: f"You are an evaluation tool. Just answer as following {self.rubrik_section}. Here is a question, a golden answer and an AI-generated answer. Judge whether the AI-generated answer is correct according to the question and golden answer, answer with {self.rubrik_section}.\nQuestion:\ {question}.\nGolden answer:\ {answer}\nGenerated answer:\ {prediction}" +output_options_pairwise: + '1': 1. + '2': 0. + '3': 0.5 +prompt_pairwise: + system: f"You are a helpful assistant, that ranks models by the quality of their answers. Please act as an impartial judge. Do not allow the length of the responses to influence your evaluation. Be as objective as possible." + user: f"Here is a question, a ground truth answer, an AI-generated answer 1 and an AI-generated answer 2. Which answer is the most correct one ? Simply answer {{1}} if the first is better, {{2}} if the second is better and {{3}} if it's a tie. \n Question:\ {question}.\n Ground truth answer:\ {ref_answer}.\n Answer 1:\ {answer_1}.\n Answer 2:\ {answer_2}." diff --git a/config/evaluator/default_qa.yaml b/config/evaluator/default_qa.yaml index 9349526..a9acc74 100644 --- a/config/evaluator/default_qa.yaml +++ b/config/evaluator/default_qa.yaml @@ -6,5 +6,11 @@ output_options: prompt: system: f"You are an evaluation tool. Answer with one of {self.rubrik_section}." user: f"Here is a question, a golden answer and an AI-generated answer. Can you judge whether the AI-generated answer is correct according to the question and golden answer, simply answer with one of {self.rubrik_section}.\n Question:\ {question}. \nGolden answer:\ {answer} \n Generated answer:\ {prediction}" - assistant: f"Response:\ {{" user_without_system: f"You are an evaluation tool. Just answer by {self.rubrik_section}. Here is a question, a golden answer and an AI-generated answer. Judge whether the AI-generated answer is correct according to the question and golden answer, answer with {self.rubrik_section}.\nQuestion:\ {question}.\nGolden answer:\ {answer}\nGenerated answer:\ {prediction}" +output_options_pairwise: + '1': 1. + '2': 0. + '3': 0.5 +prompt_pairwise: + system: f"You are a helpful assistant, that ranks models by the quality of their answers. Please act as an impartial judge. Do not allow the length of the responses to influence your evaluation. Be as objective as possible." + user: f"Here is a question, a ground truth answer, an AI-generated answer 1 and an AI-generated answer 2. Which answer is the most correct one ? Simply answer 1 if the first is better, 2 if the second is better and 3 if it's a tie. \n Question:\ {question}.\n Ground truth answer:\ {answer}.\n Answer 1:\ {prediction_1}.\n Answer 2:\ {prediction_2}." diff --git a/documentation/evaluations.md b/documentation/evaluations.md index b863bf0..d228811 100644 --- a/documentation/evaluations.md +++ b/documentation/evaluations.md @@ -14,7 +14,7 @@ Example files generated for split `dev` using `naver_splade-cocondenser-selfdist Non-neural metrics will be calculated automatically. Neural metrics such as `BEM` and `LLM` need to be evoked seperately. -By default `eval.py` will scan all folders in `experiments/` and evaluate them sequentially. To evaluate a single folder pass the folder using `--folder`. To avoid running out of memory either run `BEM` using `--bem` or run `LLM` using `--llm` . A csv file will automatically be saved to `results/` containing the table in `csv` format. +By default `evaluate.py` will scan all folders in `experiments/` and evaluate them sequentially. To evaluate a single folder pass the folder using `--folder`. To avoid running out of memory either run `BEM` using `--bem` or run `LLM` using `--llm` . A csv file will automatically be saved to `results/` containing the table in `csv` format. When using `--llm` you have a choice on how you transform LLM predictions in the final score: - directly check in the generated answer for the expepected label occurence (default Yes/No), and assign corresponding score (default 1/0), when no expected label is found, or more than one expected label is matched, we assign score -100 to the corresponding sample, such samples are excluded from the mean score computation @@ -23,17 +23,17 @@ The choice of score interpretation is done via `use_logits` parameter specified ```bash -python3 eval.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm +python3 evaluate.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm ``` Similarly to `--generator` you can specify which LLM you are willing as first options of `--llm`, as well as short name at metrics naming (use the name of the configuration file as the name of the llm). ```bash # use llama2-7b-chat to run evaluation, output metric will be named VLLMeval_l2_7b -python3 eval.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm "vllm_llama-2-7b-chat" "l2_7b" +python3 evaluate.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm "vllm_llama-2-7b-chat" "l2_7b" # use tinyllama to run evaluation, output metric will be named LLMeval_tinyllama -python3 eval.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm "tinyllama-chat" "tinyllama" +python3 evaluate.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm "tinyllama-chat" "tinyllama" # in default settings (with no arguments specified) we use SOLAR-107B for evaluation and output metric is named LLMeval python3 eval.py --experiments_folder experiments/ --llm_batch_size 16 --split 'dev' --llm @@ -53,3 +53,17 @@ If you have local ollama server running, you can call models installed on this s python3 eval.py --experiments_folder experiments/ --llm_ollama "phi3:latest" --ollama_url "http://localhost:11434" --llm_prompt default_multi_qa ``` +### Pairwise comparisons + +Instead of computing an LLM eval score for a given run, you can compare two outputs using the same script and some additional arguments e.g. +```` +python3 evaluate.py --llm --folder mistral_preds --opponent_folder llama_preds --opponent_name llama +``` +where both `mistral_preds` and `llama_preds` are output folders of bergen inferences. +This scripts uses an LLM (can be any LLM supported in bergen or gpt-4o) to compare the two sets of predictions and compute win/tie/lose rates against the opponent. Results are stored in the metrics file of the folder. The prompt used is the pairwise prompt in `config/default_qa.yaml`. + +This approach does not use logits but rather the raw prediction of the LLMs (win, tie or lose). + +In this setup note that: + - A single experiment folder must be specified for `--folder` and `--opponent_folder` + - the `opponent_name` is required \ No newline at end of file diff --git a/eval.py b/eval.py deleted file mode 100644 index c07e5a4..0000000 --- a/eval.py +++ /dev/null @@ -1,204 +0,0 @@ -import json -import shutil -import torch -import time -import os -from hydra.utils import instantiate -import omegaconf -import yaml -import gc -import pandas as pd -pd.set_option("display.precision", 4) - -class Evaluate: - @staticmethod - def eval(experiment_folder="experiments/", split="dev", bem: bool=False, llm: list[str]=None, llm_ollama: list[str]=None, vllm: list[str]=None, gpt: bool=None, bem_batch_size: int=1, lid: bool=None, lid_advanced: bool=None, llm_batch_size: int=None, llm_prompt: str = "default_qa", ollama_url: str=None, folder: str=None, force: bool=False, samples: int=-1): - def eval_single(experiment_folder, folder, split: str, model, metric_name: str, nb_samples: int =-1): - if folder != None: - folders = [folder] - else: - folders = [ f.path for f in os.scandir(experiment_folder) if f.is_dir() and 'tmp_' not in f.path] - for experiment_folder in folders: - - print('evaluating', experiment_folder) - def load_data(input_file): - result_dict = json.load(open(input_file)) - return pd.DataFrame(result_dict) - - input_file = f'{experiment_folder}/eval_{split}_out.json' - if os.path.exists(input_file): - data = load_data(input_file) - if nb_samples >0 and nb_samples < len(data): - data = data[:nb_samples] - - metrics_file = f'{experiment_folder}/eval_{split}_metrics.json' - try: - metrics_dict = json.load(open(metrics_file)) - except: continue - - if metric_name in metrics_dict and not force: - print (f"{experiment_folder}\t{metric_name}\talready done") - continue - - predictions = data['response'].values - references = data['label'].values - questions = data['question'].values - - if gpt is not None: - # openai costs - model_score, scores, cost = model(predictions, references, questions) - costs_out_file = f'{experiment_folder}/eval_{split}_cost_{metric_name}_out.json' - with open(costs_out_file, 'w') as fout: fout.write(json.dumps(cost)) - else: - model_score, scores = model(predictions, references, questions) - data[metric_name] = scores - metrics_out_file = f'{experiment_folder}/eval_{split}_out.json' - if nb_samples >0: - metrics_out_file = f'{experiment_folder}/eval_{split}_out_{nb_samples}.json' - - # temporary print eval_out results with updated metric (to avoid loosing eval_dev_out.json if smth goes wrong) - data.to_json(metrics_out_file+"_", orient='records') - #move temprorary result into final name - shutil.move(metrics_out_file + '_', metrics_out_file) - if nb_samples >0: - metric_name = f"{metric_name}_{nb_samples}" - metrics_dict.update({metric_name: model_score}) - print(metric_name,model_score) - # save to _ tmp file - with open(metrics_file + '_', 'w') as fp: - json.dump(metrics_dict, fp, indent=2) - # when writing successful remove tmp file - shutil.move(metrics_file + '_', metrics_file) - - if bem: - from models.evaluators.bem import BEM - model = BEM(batch_size=bem_batch_size) - eval_single(experiment_folder, folder, split, model, 'BEM', nb_samples = samples) - if gpt is not None: - from models.evaluators.openai import OpenAI - model = OpenAI(gpt) - eval_single(experiment_folder, folder, split, model, gpt, nb_samples = samples) - - if llm is not None: - - if len(llm) == 0: - model_config, short_name = "SOLAR-107B", "LLMeval" - elif len(llm)==1: - model_config = llm[0] - short_name = model_config - short_name = f"LLMeval_{short_name}" - elif len(llm)==2: - model_config = llm[0] - short_name = llm[1] - short_name = f"LLMeval_{short_name}" - - model_config = omegaconf.OmegaConf.load(f"config/generator/{model_config}.yaml") - if model_config['init_args']['_target_']=='models.generators.vllm.VLLM': - from models.evaluators.vllm import VLLMeval - model = VLLMeval(model_config, batch_size=llm_batch_size, config=llm_prompt) - - else: - from models.evaluators.llm import LLMeval - model = LLMeval(model_config, batch_size=llm_batch_size, config=llm_prompt) - if model.use_logits : - short_name = f"{short_name}_logits" - - eval_single(experiment_folder, folder, split, model, short_name, nb_samples = samples) - del model - torch.cuda.empty_cache() - gc.collect() - if llm_ollama is not None: - from models.evaluators.llm_ollama import OllamaEval - - if len(llm_ollama)==1: - model_config = llm_ollama[0] - short_name = model_config - short_name = f"LLMeval_{short_name}" - elif len(llm_ollama)==2: - model_config = llm_ollama[0] - short_name = llm_ollama[1] - short_name = f"LLMeval_{short_name}" - if llm_batch_size == None: - llm_batch_size = 1 - model = OllamaEval(model_config, batch_size=llm_batch_size, config=llm_prompt, basic_url=ollama_url) - eval_single(experiment_folder, folder, split, model, short_name, nb_samples = samples) - - if lid is not None or lid_advanced is not None: - from models.evaluators.lid import LID - from models.evaluators.lid_advanced import LID_advanced - if folder == None: - folders = [ f.path for f in os.scandir(experiment_folder) if f.is_dir() and 'tmp_' not in f.path] - else: - folders = [folder] - for folder in folders: - # we need to get language from each folder config separately - config = yaml.safe_load(open(f"{folder}/config.yaml")) - if 'lng' in config['dataset'][split]['query']['init_args']: - tgt_lng = config['dataset'][split]['query']['init_args']['lng'] - elif 'lang' in config['dataset'][split]['query']['init_args']: - tgt_lng = config['dataset'][split]['query']['init_args']['lang'] - else: - #if language is not specified we set it to English by default - tgt_lng = 'en' - print(f"{folder}: didn't find lng in the config.yaml, set it to English by default") - if lid is not None: - model=LID(tgt_lng) - eval_single(experiment_folder, folder, split, model, "lid", nb_samples = samples) - if lid_advanced is not None: - model = LID_advanced(tgt_lng) - eval_single(experiment_folder, folder, split, model, "lid_advanced", nb_samples = samples) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument('--experiments_folder', type=str, default="experiments/") - parser.add_argument('--folder', type=str, default=None) - parser.add_argument('--split', type=str, default='dev') - parser.add_argument('--sample', type=int, default=-1, help="Use only subsample of the experiment folder for evaluation, useful for debug purposes (default -1: use full dataset)") - parser.add_argument('--bem', action='store_true') - parser.add_argument('--lid', action='store_true', default=None) - parser.add_argument('--lid_advanced', action='store_true', default=None) - - parser.add_argument('--llm', type=str, nargs='*', default=None, - help=""" - - full model name (corresponding to generator config name) and short name (used for naming output files and metrics): - eg. -llm SOLAR-107B solar - - if short name is missing: use full name in naming, - - if no arguments specified: falls back to default arguments: uses default values (SOLAR-107B LLMeval). - """) - - parser.add_argument('--llm_ollama', type=str, nargs='*', default=None, - help=""" - Calls ollama server to run evaluation. Requires 1 or 2 arguments: - - full model name and short name (used for naming output files and metrics): eg. -llm_ollama llama3:default llama3 - - if short name is missing: use full name in naming - """ ) - parser.add_argument('--gpt', type=str,default=None) - parser.add_argument('--bem_batch_size', type=int, default=1024) - parser.add_argument('--llm_batch_size', type=int, default=None) - parser.add_argument('--force', action='store_true') - parser.add_argument('--llm_prompt', type=str, default="default_qa", help="Provide yaml config file with updated prompt. Default prompt: config/evaluator/default_prompt.yaml") - parser.add_argument('--ollama_url', type=str, default="http://localhost:11434", help="") - - - args = parser.parse_args() - e = Evaluate.eval( - folder=args.folder, - experiment_folder=args.experiments_folder, - split=args.split, - bem=args.bem, - llm=args.llm, - llm_ollama=args.llm_ollama, - gpt=args.gpt, - lid=args.lid, - lid_advanced=args.lid_advanced, - bem_batch_size=args.bem_batch_size, - llm_batch_size=args.llm_batch_size, - llm_prompt=args.llm_prompt, - ollama_url=args.ollama_url, - force=args.force, - samples=args.sample - ) - diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..9891944 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,316 @@ +import json +import shutil +import torch +import os +import omegaconf +import yaml +import gc +import pandas as pd +pd.set_option("display.precision", 4) + + +def load_data(input_file: str, nb_samples: int) -> pd.DataFrame: + result_dict = json.load(open(input_file)) + data = pd.DataFrame(result_dict) + if nb_samples > 0 and nb_samples < len(data): + data = data[:nb_samples] + return data + + +def load_opponent_predictions(opponent_folder: str, split: str, data: dict) -> list: + """ + Loads predictions from the opponent folder + Orders them as in 'data' and checks all elements are present + """ + # We filter the other data to keep the q_ids in data + other_data = load_data(f'{opponent_folder}/eval_{split}_out.json', nb_samples=-1) + other_data = other_data[other_data.q_id.isin(data.q_id.unique())] + + assert len(other_data) == len(data), f'{len(other_data)} VS {len(data)}' + + # Reordering along data order: + other_data = other_data.set_index('q_id').reindex(data['q_id']).reset_index() + + # Sanity checks: proper joint sorting + for elt, other_elt in zip(data['q_id'].values, other_data['q_id'].values): + assert elt == other_elt, f'Unmatching q_id {elt} vs {other_elt} in json files: cannot compare' + + return other_data['response'].values + + +def eval_single(experiment_folder, + folder, + split: str, + model, + metric_name: str, + nb_samples: int = -1, + gpt: str = None, + opponent_folder: str = None, + force: bool = False, + ): + if nb_samples > 0: + metric_name = f"{metric_name}_{nb_samples}" + if folder is not None: + folders = [folder] + else: + folders = [ f.path for f in os.scandir(experiment_folder) if f.is_dir() and 'tmp_' not in f.path] + for experiment_folder in folders: + print('evaluating', experiment_folder) + + input_file = f'{experiment_folder}/eval_{split}_out.json' + if os.path.exists(input_file): + data = load_data(input_file, nb_samples=nb_samples) + + # Check whether this metric is already calculated: + metrics_file = f'{experiment_folder}/eval_{split}_metrics.json' + if os.path.exists(metrics_file): + metrics_dict = json.load(open(metrics_file)) + else: + metrics_dict = {} + + # Was the metric already calculated ? (tie tests for pairwise metrics) + if (metric_name in metrics_dict or metric_name + '_tie' in metrics_dict) and not force: + print(f"{experiment_folder}\t{metric_name}\talready done") + continue + + predictions = data['response'].values + references = data['label'].values + questions = data['question'].values + + if gpt is not None: + if opponent_folder is None: + model_score, scores, cost = model(predictions, references, questions) + else: + opponent_predictions = load_opponent_predictions(opponent_folder, split=split, data=data) + model_score, scores, cost = model.pairwise_win_rate(predictions, opponent_predictions, references, questions) + + # openai costs + costs_out_file = f'{experiment_folder}/eval_{split}_cost_{metric_name}_out.json' + with open(costs_out_file, 'w') as fout: + fout.write(json.dumps(cost)) + else: + if opponent_folder is None: + model_score, scores = model(predictions, references, questions) + else: + opponent_predictions = load_opponent_predictions(opponent_folder, split=split, data=data) + model_score, scores = model(predictions=predictions, references=references, questions=questions, opponent_predictions=opponent_predictions) + + data[metric_name] = scores + metrics_out_file = f'{experiment_folder}/eval_{split}_out.json' + if nb_samples > 0: + metrics_out_file = f'{experiment_folder}/eval_{split}_out_{nb_samples}.json' + + # temporary print eval_out results with updated metric (to avoid loosing eval_dev_out.json if smth goes wrong) + data.to_json(metrics_out_file + "_", orient='records') + shutil.move(metrics_out_file + '_', metrics_out_file) + + if isinstance(model_score, dict): # win tie lose for pairwise ! + metrics_dict.update({metric_name + '_' + k: v for k, v in model_score.items()}) + else: + metrics_dict.update({metric_name: model_score}) + + print(metric_name, model_score) + # save to _ tmp file + with open(metrics_file + '_', 'w') as fp: + json.dump(metrics_dict, fp, indent=2) + # when writing successful remove tmp file + shutil.move(metrics_file + '_', metrics_file) + + +def llm_eval(llm: list[str], experiment_folder, folder, split, batch_size, llm_prompt, opponent_folder, opponent_name, nb_samples, force): + if len(llm) == 0: + model_config, metric_name = "SOLAR-107B", "LLMeval_SOLAR-107B" + else: + model_config = llm[0] + metric_name = llm[1] if len(llm) > 1 else model_config + metric_name = f"LLMeval_{metric_name}" + + if opponent_folder is not None: + metric_name += '_VS_' + opponent_name + + model_config = omegaconf.OmegaConf.load(f"config/generator/{model_config}.yaml") + if model_config['init_args']['_target_']=='models.generators.vllm.VLLM': + from models.evaluators.vllm import VLLMeval + model = VLLMeval(model_config, batch_size=batch_size, config=llm_prompt) + + else: + from models.evaluators.llm import LLMeval + model = LLMeval(model_config, batch_size=batch_size, config=llm_prompt) + if model.use_logits: + if opponent_folder is not None: + print('WARNING: cannot use logits for pairwise comparison eval: defaulting to just text parsing.') + model.use_logits = False + else: + metric_name = f"{metric_name}_logits" + + eval_single(experiment_folder, folder, split, model, metric_name=metric_name, nb_samples=nb_samples, opponent_folder=opponent_folder, force=force) + del model + torch.cuda.empty_cache() + gc.collect() + + +def llm_ollama_eval(llm_ollama: list[str], experiment_folder, folder, split, batch_size, llm_prompt, ollama_url, nb_samples, force): + from models.evaluators.llm_ollama import OllamaEval + + if len(llm_ollama) > 0: + model_config = llm_ollama[0] + short_name = llm_ollama[1] if len(llm_ollama) > 1 else model_config + short_name = f"LLMeval_{short_name}" + + batch_size = batch_size or 1 + + model = OllamaEval(model_config, batch_size=batch_size, config=llm_prompt, basic_url=ollama_url) + eval_single(experiment_folder, folder, split, model, metric_name=short_name, nb_samples = nb_samples, force=force) + + +def lid_eval(lid, lid_advanced, experiment_folder, folder, split, nb_samples, force): + from models.evaluators.lid import LID + from models.evaluators.lid_advanced import LID_advanced + if folder is None: + folders = [ f.path for f in os.scandir(experiment_folder) if f.is_dir() and 'tmp_' not in f.path] + else: + folders = [folder] + + for folder in folders: + # we need to get language from each folder config separately + config = yaml.safe_load(open(f"{folder}/config.yaml")) + if 'lng' in config['dataset'][split]['query']['init_args']: + tgt_lng = config['dataset'][split]['query']['init_args']['lng'] + elif 'lang' in config['dataset'][split]['query']['init_args']: + tgt_lng = config['dataset'][split]['query']['init_args']['lang'] + else: + #if language is not specified we set it to English by default + tgt_lng = 'en' + print(f"{folder}: didn't find lng in the config.yaml, set it to English by default") + if lid is not None: + model = LID(tgt_lng) + eval_single(experiment_folder, folder, split, model, metric_name="lid", nb_samples = nb_samples, force=force) + if lid_advanced is not None: + model = LID_advanced(tgt_lng) + eval_single(experiment_folder, folder, split, model, metric_name="lid_advanced", nb_samples = nb_samples, force=force) + + +def gpt_eval(gpt, experiment_folder, folder, split, opponent_folder, opponent_name, nb_samples, force): + from models.evaluators.openai import OpenAI + model = OpenAI(gpt) + metric_name = gpt + if opponent_folder is not None: + metric_name += '_VS_' + opponent_name + eval_single(experiment_folder, folder, split, model, gpt=gpt, metric_name=metric_name, nb_samples=nb_samples, opponent_folder=opponent_folder, force=force) + + +def run_eval(experiment_folder=None, + split="dev", + llm: list[str]=None, + llm_ollama: list[str]=None, + gpt: bool=None, + lid: bool=None, + lid_advanced: bool=None, + llm_batch_size: int=None, + llm_prompt: str = "default_qa", + ollama_url: str=None, + folder: str=None, + force: bool=False, + nb_samples: int=-1, + opponent_folder: str = None, + opponent_name: str = None): + """ + Entry point for all LLM evaluations. + """ + if gpt is not None: + gpt_eval(gpt, + experiment_folder, + folder, + split, + opponent_folder=opponent_folder, + opponent_name=opponent_name, + nb_samples=nb_samples, + force=force) + + if llm is not None: + llm_eval(llm, + experiment_folder, + folder, + split, + llm_batch_size, + llm_prompt, + opponent_folder=opponent_folder, + opponent_name=opponent_name, + nb_samples=nb_samples, + force=force) + + if llm_ollama is not None: + llm_ollama_eval(llm_ollama, experiment_folder, folder, split, llm_batch_size, llm_prompt, ollama_url, nb_samples=nb_samples, force=force) + + if lid is not None or lid_advanced is not None: + lid_eval(lid, lid_advanced, experiment_folder, folder, split, nb_samples=nb_samples, force=force) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--experiments_folder', type=str, default="experiments/") + parser.add_argument('--folder', type=str, default=None) + + parser.add_argument('--split', type=str, default='dev') + parser.add_argument('--sample', type=int, default=-1, help="Use only subsample of the experiment folder for evaluation, useful for debug\ + purposes (default -1: use full dataset)") + parser.add_argument('--lid', action='store_true', default=None) + parser.add_argument('--lid_advanced', action='store_true', default=None) + + parser.add_argument('--llm', type=str, nargs='*', default=None, + help=""" + - full model name (corresponding to generator config name) and short name (used for naming output files and metrics): + eg. -llm SOLAR-107B solar + - if short name is missing: use full name in naming, + - if no arguments specified: falls back to default arguments: uses default values (SOLAR-107B LLMeval). + """) + + parser.add_argument('--llm_ollama', type=str, nargs='*', default=None, + help=""" + Calls ollama server to run evaluation. Requires 1 or 2 arguments: + - full model name and short name (used for naming output files and metrics): eg. -llm_ollama llama3:default llama3 + - if short name is missing: use full name in naming + """ ) + + parser.add_argument('--gpt', type=str, default=None) + + # Use these arguments to do pairwise evaluations: + parser.add_argument('--opponent_folder', type=str, default=None, help='Provide a second folder via this to run pairwise comparisons\ + (only available with gpt and when specifying a folder)') + parser.add_argument('--opponent_name', type=str, default=None, help='Provide a second folder via this to run pairwise comparisons\ + (only available with gpt and when specifying a folder)') + + parser.add_argument('--llm_batch_size', type=int, default=None) + parser.add_argument('--force', action='store_true') + parser.add_argument('--llm_prompt', type=str, default="default_qa", help="Provide yaml config file with updated prompt.\ + Default prompt: config/evaluator/default_prompt.yaml") + parser.add_argument('--ollama_url', type=str, default="http://localhost:11434", help="") + + args = parser.parse_args() + + if args.opponent_folder is not None: + assert args.gpt or args.llm is not None, f"{args.gpt} {args.llm}" + assert args.folder is not None, 'Pairwise only supported if you specify a folder' + assert os.path.isdir(args.opponent_folder), 'Pairwise_on argument should point to a directory to which compare the folder arg outputs.' + assert args.opponent_name is not None, 'Specify a name for the opponent (to name the metrics)' + print('Pairwise comparison detected, the opponent is found at:', args.opponent_folder, ' with name ', args.opponent_name) + + e = run_eval( + folder=args.folder, + experiment_folder=args.experiments_folder, + split=args.split, + llm=args.llm, + llm_ollama=args.llm_ollama, + gpt=args.gpt, + lid=args.lid, + lid_advanced=args.lid_advanced, + llm_batch_size=args.llm_batch_size, + llm_prompt=args.llm_prompt, + ollama_url=args.ollama_url, + force=args.force, + nb_samples=args.sample, + opponent_folder=args.opponent_folder, + opponent_name=args.opponent_name + ) diff --git a/models/evaluators/bem.py b/models/evaluators/bem.py deleted file mode 100644 index 2d33a2b..0000000 --- a/models/evaluators/bem.py +++ /dev/null @@ -1,90 +0,0 @@ -''' -BERGEN -Copyright (c) 2024-present NAVER Corp. -CC BY-NC-SA 4.0 license -''' - -import torch -from torch.nn import functional as F -import tensorflow_hub as hub -from transformers import BertTokenizer -import tensorflow as tf -from tqdm import tqdm - -import os - -# Suppress TensorFlow warnings -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' - -# Suppress TensorFlow warnings -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -tf.get_logger().setLevel(tf.compat.v1.logging.ERROR) - -class BEM: - def __init__(self, batch_size=2048): - self.batch_size = batch_size - self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - - - def bertify_example(self, question, reference, candidate, max_length=512): - question = self.tokenizer.tokenize(question)[:max_length] - reference = self.tokenizer.tokenize(reference)[:max_length] - candidate = self.tokenizer.tokenize(candidate)[:max_length] - - tokens = ['[CLS]'] + candidate + ['[SEP]'] + reference + ['[SEP]'] + question + ['[SEP]'] - - input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens)) - segment_ids = torch.tensor([0] * (len(candidate) + 2) + [1] * (len(reference) + 1) + [2] * (len(question) + 1)) - - input_ids = F.pad(torch.tensor(input_ids), (0, max_length - len(input_ids)), value=0) - segment_ids = F.pad(torch.tensor(segment_ids), (0, max_length - len(segment_ids)), value=0) - - return {'input_ids': input_ids, 'segment_ids': segment_ids} - - - def bertify_examples(self, examples, max_length=512): - input_ids = [] - segment_ids = [] - for example in examples: - question = example['question'] - candidate = example['candidate'] - reference = example['reference'] - - if isinstance(reference, str): - reference = [reference] - - for ref in reference: - example_inputs = self.bertify_example(question, ref, candidate, max_length=max_length) - - input_ids.append(example_inputs['input_ids']) - segment_ids.append(example_inputs['segment_ids']) - - return {'input_ids': torch.stack(input_ids), 'segment_ids': torch.stack(segment_ids)} - - def __call__(self, predictions, references, questions): - # Loading the TensorFlow Hub model - self.model = hub.load('https://tfhub.dev/google/answer_equivalence/bem/1') - assert len(predictions) == len(references) == len(questions) - examples = [{'question': questions[i], 'reference': references[i], 'candidate': predictions[i]} for i in range(len(predictions))] - - inputs = self.bertify_examples(examples, max_length=self.tokenizer.model_max_length) - # The outputs are raw logits. - scores = list() - # Perform batch inference - for i in tqdm(range(0, len(inputs['input_ids']), self.batch_size), desc='BEM evaluation...'): - # Extract batch - batch_input_ids = inputs['input_ids'][i:i+self.batch_size] - batch_segment_ids = inputs['segment_ids'][i:i+self.batch_size] - inp = {"input_ids": tf.stop_gradient(batch_input_ids), "segment_ids": tf.stop_gradient(batch_segment_ids)} - raw_outputs = self.model(inp) - raw_outputs_torch = torch.from_numpy(raw_outputs.numpy()) - scores.append(raw_outputs_torch) - # They can be transformed into a classification 'probability' like so: - del self.model - scores = torch.cat(scores) - tf.keras.backend.clear_session() - torch.cuda.empty_cache() - scores = F.softmax(scores, dim=1)[:, 1] - return scores.mean().item(), scores - - diff --git a/models/evaluators/llm.py b/models/evaluators/llm.py index f91378e..0765ce3 100644 --- a/models/evaluators/llm.py +++ b/models/evaluators/llm.py @@ -9,144 +9,202 @@ from tqdm import tqdm import torch from hydra.utils import instantiate -from models.evaluators.utils import * +from models.evaluators.utils import process_llm_outputs_assess_scores, get_mean_without_unknown, unswitch_switched_scores, set_tq_description, get_pairwise_scores_without_unknown import gc +import random -class LLMeval(): +class BaseEval: """ - - relies on default HF inference - - if use_logits is set to True (in evaluator config) - - output score is computed as interpolation between prob of label and it's associated value - (defined by options map in config): eg. p(x=yes)*1 + p(x=no)*0 - - otherwise: we just check if label is present in the answer (yes/no) and return associated value (1/0) - + Base class for evaluation logic shared by LLMeval and VLLMeval. """ - def __init__(self, model_config: dict, batch_size: int = None, config: str = "default_qa" ): + def __init__(self, model_config: dict, batch_size: int = None, config: str = "default_qa"): """ - model_config: generator config specified as yaml file in cofig/generator directory - batch_size: if none, it keeps default llm batch size from config - confg: name of evaluator config specified as yaml file at config/evaluators + Base initializer for evaluation classes. """ - eval_config = omegaconf.OmegaConf.load(f"config/evaluator/{config}.yaml") - model_config['init_args']['max_new_tokens']= eval_config['max_new_tokens'] + model_config['init_args']['max_new_tokens'] = eval_config['max_new_tokens'] - self.use_logits = eval_config.use_logits - self.llm = instantiate(model_config['init_args'], prompt=eval_config['prompt']) + self.llm = self.initialize_llm(model_config, eval_config) + self.options = eval_config.output_options - self.rubrik_section = ", ".join(["{"+opt+"}" for opt in self.options]) + self.rubrik_section = ", ".join(self.options) + + self.options_pairwise = eval_config.output_options_pairwise + + # Set up prompts self.prompt = eval_config['prompt'] - self.llm.max_new_tokens = eval_config['max_new_tokens'] - if not batch_size == None: - self.llm.batch_size = batch_size + self.prompt_pairwise = eval_config['prompt_pairwise'] self.system_prompt = eval(self.prompt.system).replace(':\ ', ': ') - #FIXME: what shall we do if label corrsponds to multiple tokens? + self.system_prompt_pairwise = eval(self.prompt_pairwise.system).replace(':\ ', ': ') + + # Set up LLM parameters + self.batch_size = batch_size or self.llm.batch_size + self.llm.max_new_tokens = eval_config['max_new_tokens'] + + # output_ids contains the token ids for the possible answers self.output_ids = [self.llm.tokenizer.encode(opt, add_special_tokens=False) for opt in sorted(self.options)] + # output_values contain the associated 'score' for each option self.output_values = torch.tensor([self.options[opt] for opt in sorted(self.options)]).float() - self.generation_config = GenerationConfig.from_model_config(self.llm.model.config) - self.generation_config.do_sample=False, - # according to documentation from https://huggingface.co/docs/transformers/v4.43.2/main_classes/text_generation this is supposed to force model to generate tokens from the list, but it doesn't seem to work in practice - # --> rollback to simple solution: just check first token logit of each predefined label - self.generation_config.force_word_ids=self.output_ids, - self.generation_config.max_new_tokens=self.llm.max_new_tokens - - - - + self.output_ids_pairwise = [self.llm.tokenizer.encode(opt, add_special_tokens=False) for opt in sorted(self.options_pairwise)] + self.output_values_pairwise = torch.tensor([self.options_pairwise[opt] for opt in sorted(self.options_pairwise)]).float() + + def initialize_llm(self, model_config, eval_config): + """ + Placeholder for LLM initialization, to be overridden by subclasses if needed. + """ + return instantiate(model_config['init_args'], prompt=eval_config['prompt']) + def __del__(self): - # print(f"Delete evaluator {self.llm.model_name}") torch.cuda.empty_cache() gc.collect() - - def create_instruction(self,sample): - answer = sample['reference'] - question=sample['question'] - prediction=sample['candidate'] - if 'response' in sample: - response = sample['response'] - else: - response = None + + def create_instruction(self, answer: str, question: str, prediction: str) -> str: prefix = [] + rubrik_section = self.rubrik_section # for the 'eval' if getattr(self.llm.tokenizer, "chat_template") is not None and 'system' in self.llm.tokenizer.chat_template: - prefix = [{'role': 'system', - 'content': self.system_prompt}] - prefix.extend([{'role': 'user', - 'content': eval(self.prompt.user).replace(':\ ', ': ')}] - ) - + prefix = [ + {'role': 'system', 'content': self.system_prompt}, + {'role': 'user', 'content': eval(self.prompt.user).replace(':\ ', ': ')} + ] else: - prefix = ([{'role': 'user', - 'content': eval(self.prompt.user_without_system).replace(':\ ', ': ')}] - ) - if 'assistant' in self.prompt: - prefix.extend([{'role': 'assistant', - 'content': eval(self.prompt.assistant).replace(':\ ', ': ')}] - ) - if not response is None: - prefix.extend([{'role': 'assistant', - 'content': response}] - ) + prefix = ([ + {'role': 'user','content': eval(self.prompt.user_without_system).replace(':\ ', ': ')} + ]) return self.llm.tokenizer.apply_chat_template(prefix, add_generation_prompt=True, tokenize=False) + + def create_pairwise_instruction(self, question: str, answer: str, prediction_1: str, prediction_2: str) -> (str, bool): + """ + To prevent positional bias, orders of answers is randomly switched + We switch the scores appropriately later on in '__call__' + so this method returns the prompt + the 'switch' boolean + Unused arguments are used in the "eval" + """ + switch = random.choice([True, False]) + if switch: + prediction_1, prediction_2 = prediction_2, prediction_1 + + assert hasattr(self.llm.tokenizer, 'chat_template'), 'Please use an LLM with a chat template' + prefix = [ + {'role': 'system', 'content': self.system_prompt_pairwise}, + {'role': 'user', 'content': eval(self.prompt_pairwise.user).replace(':\ ', ': ')} + ] + return self.llm.tokenizer.apply_chat_template(prefix, add_generation_prompt=True, tokenize=False), switch + + def create_inputs(self, predictions, references, questions, opponent_predictions=None) -> dict: + """ + Create all the prompts + For pairwise case, it also creates the 'switches' which correspond to inversions in answer order to prevent bias. + """ + assert len(predictions) == len(references) == len(questions) + pairwise = (opponent_predictions is not None) + if pairwise: + assert len(opponent_predictions) == len(predictions) + + inputs = [] + + for i in range(len(predictions)): + if pairwise: + sample_instr, sample_switch = self.create_pairwise_instruction(question=questions[i], + answer=references[i], + prediction_1=predictions[i], + prediction_2=opponent_predictions[i]) + inputs.append({'instr': sample_instr, 'switch': sample_switch}) + else: + sample_instr = self.create_instruction(question=questions[i], answer=references[i], prediction=predictions[i]) + inputs.append({'instr': sample_instr}) + + return inputs - - - def collate_fn(self, examples, max_length=512): - instr = [self.create_instruction(sample) for sample in examples] # Add prompt to each text - instr_tokenized = self.llm.tokenizer(instr, padding=True, truncation=True, return_tensors="pt") - return instr_tokenized, instr + +class LLMeval(BaseEval): + """ + Evaluation class for HF inference. + """ + def __init__(self, model_config: dict, batch_size: int = None, config: str = "default_qa"): + super().__init__(model_config, batch_size, config) + + eval_config = omegaconf.OmegaConf.load(f"config/evaluator/{config}.yaml") + self.use_logits = eval_config.use_logits + + # Set up generation config for HF + self.generation_config = GenerationConfig.from_model_config(self.llm.model.config) + self.generation_config.do_sample = False + self.generation_config.max_new_tokens = self.llm.max_new_tokens @torch.no_grad() - def __call__(self, predictions, references, questions): + def __call__(self, predictions, references, questions, opponent_predictions=None): + """ + other_preditions: opponent model prediction in pairwise comparison + """ assert len(predictions) == len(references) == len(questions) - examples = [{'question': questions[i], 'reference': references[i], 'candidate': predictions[i]} for i in range(len(predictions))] + + pairwise = (opponent_predictions is not None) + + output_ids = self.output_ids_pairwise if pairwise else self.output_ids + output_values = self.output_values_pairwise if pairwise else self.output_values + options = self.options_pairwise if pairwise else self.options + + # list of dictionaries containing each sample formatted instruction, and switch (if pairwise) + inputs = self.create_inputs(predictions=predictions, references=references, questions=questions, opponent_predictions=opponent_predictions) + # The outputs are raw logits. - scores = list() - weird = list() + scores, weirds = [], [] # Perform batch inference - full_inputs, full_instrs = self.collate_fn(examples) - for i in (tq:=tqdm(range(0, len(examples), self.llm.batch_size), desc=f'LLM evaluation with {self.llm.model_name}...')): + for i in (tq:=tqdm(range(0, len(inputs), self.batch_size), desc=f'LLM evaluation with {self.llm.model_name}...')): # Extract batch - batch_examples = examples[i:i+self.llm.batch_size] - inputs, instrs = self.collate_fn(batch_examples) - input_ids = inputs['input_ids'].to(self.llm.model.device) - attention_mask = inputs['attention_mask'].to(self.llm.model.device) + batch_examples = inputs[i:i+self.batch_size] + instrs = [elt['instr'] for elt in batch_examples] + + llm_inputs = self.llm.tokenizer(instrs, padding=True, truncation=True, return_tensors="pt") + + input_ids = llm_inputs['input_ids'].to(self.llm.model.device) + attention_mask = llm_inputs['attention_mask'].to(self.llm.model.device) - if self.use_logits: - self.generation_config.output_logits=True + if self.use_logits and not pairwise: + self.generation_config.output_logits = True self.generation_config.return_dict_in_generate=True - model_outputs = self.llm.model.generate( - input_ids, - attention_mask=attention_mask, - generation_config=self.generation_config - ) + model_outputs = self.llm.model.generate(input_ids, attention_mask=attention_mask, generation_config=self.generation_config) + #get processed logits from model outputs: expected shape (n_tokens, 1, vocab_size) model_scores = torch.stack(model_outputs.logits) #get scores corresponding to first token of predefined labels from the first generated tokens - model_scores = model_scores[0, :, [tok[0] for tok in self.output_ids]].float() + model_scores = model_scores[0, :, [tok[0] for tok in output_ids]].float() #normalizing scores - getting probablity of each of predefined labesl pos_prob = torch.softmax(model_scores, 1).detach().cpu() - #final score is computed as interpolation between prob of label and it's associated value (defined by options map in config): eg. p(x=yes)*1 + p(x=no)*0 + #final score is computed as interpolation between prob of label + # and its associated value (defined by options map in config): eg. p(x=yes)*1 + p(x=no)*0 + for i, score in enumerate(pos_prob): - scores.append(torch.dot(score,self.output_values).item()) - else: - # discrete model output - # get real answer generation - decoded = self.llm.generate(inputs) - # #model_generations = self.llm.model.generate(input_ids, - # attention_mask=attention_mask, - # generation_config=self.generation_config - # ) - # decoded = self.llm.tokenizer.batch_decode(model_generations) - # breakpoint() - batch_scores, batch_weird = process_llm_outputs_assess_scores(decoded, self.options) - weird.extend(batch_weird) - # if string value specified in options is present in the generated output: assign corresponding score, - # if multiple values are present: take maximum value - scores.extend(batch_scores) - tq.set_description(f" score: {get_mean_without_unknown(scores)* 100:4.1f}%, weird :{float(len(weird))/len(scores)*100:4.1f}%") + scores.append(torch.dot(score, output_values).item()) + + else: # case: pairwise or pointwise, non-logits. + output = self.llm.model.generate( + input_ids, + attention_mask=attention_mask, + generation_config=self.generation_config).detach().cpu().numpy() + decoded = self.llm.tokenizer.batch_decode(output[:, input_ids.shape[1]:], skip_special_tokens=True) + + batch_scores, batch_weirds = process_llm_outputs_assess_scores(decoded, options) + + if pairwise: + # We post-process the scores to take into account the switches (which deter positional bias) + switches = [elt['switch'] for elt in batch_examples] + batch_scores = unswitch_switched_scores(switched_scores=batch_scores, switches=switches) + + weirds.extend(batch_weirds) + scores.extend(batch_scores) + + set_tq_description(tq, scores, weirds, pairwise) torch.cuda.empty_cache() gc.collect() - return get_mean_without_unknown(scores), scores + + if pairwise: + avg_scores = get_pairwise_scores_without_unknown(scores) + else: + avg_scores = get_mean_without_unknown(scores) + + return avg_scores, scores + \ No newline at end of file diff --git a/models/evaluators/openai.py b/models/evaluators/openai.py index 92901b5..5911556 100644 --- a/models/evaluators/openai.py +++ b/models/evaluators/openai.py @@ -8,6 +8,8 @@ from tqdm import tqdm import numpy as np import os +import random + def openai_api_calculate_cost(usage,model="gpt-4-1106-preview"): pricing = { @@ -62,7 +64,7 @@ def run_llm(client, model_name,messages): -def create_instruction(question,answer,prediction): +def create_instruction(question: str, answer: str, prediction: str): prefix = [{'role': 'system', 'content': "You are an evaluation tool. Just answer by {Yes} or {No}."}] prefix.extend([{'role': 'user', @@ -72,20 +74,33 @@ def create_instruction(question,answer,prediction): return prefix +def create_pairwise_instruction(question, ref_answer, answer_1, answer_2): + prefix = [{ + 'role': 'system', + 'content': "You are a helpful assistant, that ranks models by the quality of their answers. Please act as an impartial judge. Do not allow the length of the responses to influence your evaluation. Be as objective as possible." + }] + prefix.extend([{ + 'role': 'user', + 'content' : f"Here is a question, a ground truth answer, an AI-generated answer 1 and an AI-generated answer 2. Which answer is the most correct one ? Simply answer {{1}} if the first is better, {{2}} if the second is better and {{3}} if it's a tie. \n Question: {question}.\n Ground truth answer: {ref_answer}.\n Answer 1: {answer_1}.\n Answer 2: {answer_2}." + }]) + return prefix + # for evaluation class OpenAI(): - def __init__(self,model): + + def __init__(self, model): self.client = openai.OpenAI(api_key = os.environ.get("OPENAI_API_KEY"),) self.model_name=model + def __call__(self, predictions, references, questions): - scores=list() - weird=list() - total_cost=0 - prompt_cost=0 - completion_cost=0 - for q,r,p in (tq:= tqdm(zip(questions,references,predictions),total=len(questions),desc=f"score: 0.0%")): + scores = list() + weird = list() + total_cost = 0 + prompt_cost = 0 + completion_cost = 0 + for q,r,p in (tq:= tqdm(zip(questions,references,predictions),total=len(questions),desc="score: 0.0%")): prompt = create_instruction(q,r[0],p) - response,costs = run_llm(self.client,self.model_name,prompt) + response, costs = run_llm(self.client,self.model_name,prompt) total_cost += costs[0] prompt_cost += costs[1] completion_cost += costs[2] @@ -95,3 +110,52 @@ def __call__(self, predictions, references, questions): tq.set_description(f"cost:{total_cost:4.1f} score: {np.mean(scores)* 100:4.1f}% weird {np.mean(weird)* 100:4.1f}%") print(total_cost,prompt_cost,completion_cost) return np.mean(scores), scores, {"total_cost":total_cost,"prompt_cost":prompt_cost,"completion_cost":completion_cost} + + def pairwise_win_rate(self, predictions, opponent_predictions, references, questions): + assert len(predictions) == len(opponent_predictions) + scores = [] + weird = [] + total_cost = 0 + prompt_cost = 0 + completion_cost = 0 + for pred_1, pred_2, ref_answer, question in (tq:= tqdm(zip(predictions, opponent_predictions, references, questions), total=len(questions),desc="score: 0.0%")): + + # Randomly switch order to prevent position bias in judge + switch_order = (random.randint(0, 1) == 1) + if switch_order: + pred_1, pred_2 = pred_2, pred_1 + + prompt = create_pairwise_instruction(question, ref_answer[0], answer_1=pred_1, answer_2=pred_2) + response, costs = run_llm(self.client,self.model_name,prompt) + total_cost += costs[0] + prompt_cost += costs[1] + completion_cost += costs[2] + score = None + if '1' in response.lower(): + score = 1 + w = 0 + elif '2' in response.lower(): + score = 0 + w = 0 + elif '3' in response.lower(): + score = 0.5 + w = 0 + else: + score = 0.5 # tie by default + w = 1 + + if switch_order: + score = 1 - score + + scores.append(score) + weird.append(w) + tq.set_description(f"cost:{total_cost:4.1f} win: {scores.count(1)*100./len(scores):4.1f}% tie {scores.count(0.5)*100./len(scores):4.1f}% lose {scores.count(0)*100./len(scores):4.1f}% weird {np.mean(weird)* 100:4.1f}%") + print(total_cost, prompt_cost, completion_cost) + avg_scores = { + 'win': scores.count(1)*100./len(scores), + 'tie': scores.count(0.5)*100./len(scores), + 'lose': scores.count(0)*100./len(scores) + } + return avg_scores, scores, {"total_cost":total_cost,"prompt_cost":prompt_cost,"completion_cost":completion_cost} + + \ No newline at end of file diff --git a/models/evaluators/utils.py b/models/evaluators/utils.py index 985c727..3df9b6c 100644 --- a/models/evaluators/utils.py +++ b/models/evaluators/utils.py @@ -2,16 +2,57 @@ def process_llm_outputs_assess_scores(outputs, options, unknown_value=-100): - possible_scores = [[options[opt] for opt in options if opt in rep ] for rep in outputs] scores = [sc[0] if len(sc)==1 else unknown_value for sc in possible_scores] weird = [rep for i,rep in enumerate(outputs) if (len(possible_scores[i])==0 or len(possible_scores[i])>1)] return scores, weird + def get_mean_without_unknown(scores, unknown_value=-100): scores_to_consider = [s for s in scores if s!=unknown_value] if len(scores_to_consider)>0: return np.mean(scores_to_consider) else: return 0 + + +def unswitch_switched_scores(switched_scores: list, switches: list): + """ + When we do pairwise comparison, we randomly switch the answer orders to prevent bias + Here we de-switch the obtained scores + """ + assert len(switched_scores) == len(switches), f"{len(switched_scores)} vs {len(switches)}" + unswitched_scores = [] + for switched_score, switch in zip(switched_scores, switches): + if not (0. <= switched_score <= 1.): # nothing we can do for weird scores + unswitched_scores.append(switched_score) + else: + if switch: + unswitched_scores.append(1 - switched_score) + else: + unswitched_scores.append(switched_score) + return unswitched_scores + +def get_pairwise_scores_without_unknown(scores, unknown_value=-100) -> dict: + """ + Computes win/tie/lose scores for pairwise evaluation + """ + valid_scores = [elt for elt in scores if 0. <= elt <= 1.] + n_valid = max(1e-6, len(valid_scores)) # to avoid zero division + return { + 'win': valid_scores.count(1)*100./n_valid, + 'tie': valid_scores.count(0.5)*100./n_valid, + 'lose': valid_scores.count(0)*100./n_valid + } + + +def set_tq_description(tq, scores, weird, pairwise): + """ + Utility to set tqdm description during evaluation, depending on pairwise vs pointwise. + """ + if pairwise: + tq.set_description(f"Win: {scores.count(1)*100./len(scores):4.1f}% tie {scores.count(0.5)*100./len(scores):4.1f}%\ + lose {scores.count(0)*100./len(scores):4.1f}% weird {float(len(weird))/len(scores)*100:4.1f}%") + else: + tq.set_description(f" score: {get_mean_without_unknown(scores)* 100:4.1f}%, weird :{float(len(weird))/len(scores)*100:4.1f}%") diff --git a/models/evaluators/vllm.py b/models/evaluators/vllm.py index 45b1a3f..18e80e3 100644 --- a/models/evaluators/vllm.py +++ b/models/evaluators/vllm.py @@ -5,97 +5,75 @@ ''' from tqdm import tqdm +from vllm import SamplingParams import torch -import numpy as np -from vllm import LLM as vllm -from vllm import SamplingParams +from models.evaluators.llm import BaseEval import omegaconf -from hydra.utils import instantiate -import random -from models.evaluators.utils import * +from models.evaluators.utils import process_llm_outputs_assess_scores, get_mean_without_unknown, unswitch_switched_scores, get_pairwise_scores_without_unknown, set_tq_description import logging logger = logging.getLogger(__name__) -import gc -class VLLMeval: +class VLLMeval(BaseEval): """ - - relies on vllm for inference, directly loads the model and runs inference (no need to initiate vllm server in advance) - - output score for each sample is 1 (when positive word is present in llm output) or 0 (otherwise) + Evaluation class for vllm inference. """ - def __init__(self, model_config: dict, batch_size: int = None, config: str = "default_qa" ): - """ - model_config: generator config specified as yaml file in cofig/generator directory - batch_size: if none, it keeps default llm batch size from config - confg: name of evaluator config specified as yaml file at config/evaluators - """ + def __init__(self, model_config: dict, batch_size: int = None, config: str = "default_qa"): + super().__init__(model_config, batch_size, config) + eval_config = omegaconf.OmegaConf.load(f"config/evaluator/{config}.yaml") - model_config['init_args']['max_new_tokens']= eval_config['max_new_tokens'] - self.llm = instantiate(model_config['init_args'], prompt=eval_config['prompt']) - self.options = eval_config.output_options - self.rubrik_section = ", ".join(["{"+opt+"}" for opt in self.options]) - self.prompt = eval_config['prompt'] - self.llm.sampling_params.max_new_token = eval_config['max_new_tokens'] - if not batch_size == None: - self.llm.batch_size = batch_size - self.llm.max_new_tokens = eval_config['max_new_tokens'] - self.system_prompt = eval(self.prompt.system).replace(':\ ', ': ') - self.output_ids = [self.llm.tokenizer.encode(opt, add_special_tokens=False)[-1] for opt in sorted(self.options)] - self.output_values = torch.tensor([self.options[opt] for opt in sorted(self.options)]).float() - - def create_instruction(self,sample): - answer = sample['reference'] - question=sample['question'] - prediction=sample['candidate'] - if 'response' in sample: - response = sample['response'] - else: - response = None - prefix = [] - if 'system' in self.llm.tokenizer.chat_template: - prefix = [{'role': 'system', - 'content': self.system_prompt}] - prefix.extend([{'role': 'user', - 'content': eval(self.prompt.user).replace(':\ ', ': ')}] + # VLLM-specific settings + self.sampling_params = SamplingParams( + best_of=1, + temperature=0.0, + top_p=1, + top_k=-1, + use_beam_search=False, + max_tokens=eval_config['max_new_tokens'], + presence_penalty=0, + frequency_penalty=0, ) - else: - prefix = ([{'role': 'user', - 'content': eval(self.prompt.user_without_system).replace(':\ ', ': ')}] - ) - if 'assistant' in self.prompt: - prefix.extend([{'role': 'assistant', - 'content': eval(self.prompt.assistant).replace(':\ ', ': ')}] - ) - if not response is None: - prefix.extend([{'role': 'assistant', - 'content': response}] - ) - return self.llm.tokenizer.apply_chat_template(prefix, add_generation_prompt=True, tokenize=False) + self.llm.sampling_params.max_new_token = eval_config['max_new_tokens'] + self.batch_size = batch_size or self.llm.batch_size + self.llm.max_new_tokens = eval_config['max_new_tokens'] - def __del__(self): - # logger.info("Deleting object") - torch.cuda.empty_cache() - gc.collect() - @torch.no_grad() - def __call__(self, predictions, references, questions): - # Loading the TensorFlow Hub model + def __call__(self, predictions, references, questions, opponent_predictions=None): assert len(predictions) == len(references) == len(questions) - examples = [{'question': questions[i], 'reference': references[i], 'candidate': predictions[i]} for i in range(len(predictions))] - instrs = [self.create_instruction(sample) for sample in examples] - scores = list() - weird = list() + + pairwise = (opponent_predictions is not None) + options = self.options_pairwise if pairwise else self.options + + inputs = self.create_inputs(predictions=predictions, references=references, questions=questions, opponent_predictions=opponent_predictions) + + scores, weirds = [], [] + # Perform batch inference - for i in (tq:=tqdm(range(0, len(instrs), self.llm.batch_size), desc=f'LLM evaluation with {self.llm.model_name}...')): - decoded = self.llm.generate(instrs[i:i+self.llm.batch_size]) - batch_scores, batch_weird = process_llm_outputs_assess_scores(decoded, self.options) + for i in (tq:=tqdm(range(0, len(inputs), self.batch_size), desc=f'LLM evaluation with {self.llm.model_name}...')): + batch_examples = inputs[i:i+self.batch_size] + + instrs = [elt['instr'] for elt in batch_examples] + + decoded = self.llm.generate(instrs) + + batch_scores, batch_weird = process_llm_outputs_assess_scores(decoded, options) + + if pairwise: # samples were randomly switched to avoid position bias: we unswitch! + switches = [elt['switch'] for elt in batch_examples] + batch_scores = unswitch_switched_scores(switched_scores=batch_scores, switches=switches) + scores.extend(batch_scores) - weird.extend(batch_weird) - tq.set_description(f" score: {get_mean_without_unknown(scores)* 100:4.1f}%, weird :{float(len(weird))/len(scores)*100:4.1f}%") - logger.info(weird) - print("Weird", len(weird)) + weirds.extend(batch_weird) + + set_tq_description(tq, scores, weirds, pairwise) + + logger.info(weirds) + + if pairwise: + avg_scores = get_pairwise_scores_without_unknown(scores) + else: + avg_scores = get_mean_without_unknown(scores) - return get_mean_without_unknown(scores), scores - + return avg_scores, scores diff --git a/tests/zeroshot_test.py b/tests/zeroshot_test.py index 06ef7ff..60c82ed 100644 --- a/tests/zeroshot_test.py +++ b/tests/zeroshot_test.py @@ -7,7 +7,7 @@ import shutil from hydra import initialize, compose from bergen import main -from eval import Evaluate +from evaluate import run_eval from omegaconf import OmegaConf import pytest import gc @@ -33,12 +33,20 @@ def init(): def rmdir(folder): if os.path.exists(folder): shutil.rmtree(folder) + + def rmfile(file): + if os.path.exists(file): + os.remove(file) def clean_dirs(): rmdir('tests/exp/') rmdir('tests/index/') rmdir('tests/run/') rmdir('tests/dataset/') + + # some eval tests generate metrics: we remove them + rmfile('tests/utdata/utexp_neg/eval_dev_metrics.json') + rmfile('tests/utdata/utexp_pos/eval_dev_metrics.json') if not torch.cuda.is_available(): @@ -209,35 +217,46 @@ def test_lid(self): with initialize(config_path="../config",version_base="1.2"): test_name = inspect.currentframe().f_code.co_name exp_folder = "tests/utdata/" - Evaluate.eval(experiment_folder=exp_folder, lid=True, force=True) - + run_eval(experiment_folder=exp_folder, lid=True, force=True) def test_llmeval_default(self): with initialize(config_path="../config",version_base="1.2"): test_name = inspect.currentframe().f_code.co_name exp_folder = "tests/utdata/" - Evaluate.eval(experiment_folder=exp_folder, llm=["tinyllama-chat", "test-llm-1"], llm_batch_size= 4, llm_prompt="default_qa", force=True, samples=4) + run_eval(experiment_folder=exp_folder, llm=["tinyllama-chat", "test-llm-1"], llm_batch_size= 4, llm_prompt="default_qa", force=True, nb_samples=4) - def test_llmeval_multi(self): with initialize(config_path="../config",version_base="1.2"): test_name = inspect.currentframe().f_code.co_name exp_folder = "tests/utdata/" - Evaluate.eval(experiment_folder=exp_folder, llm=["tinyllama-chat", "test-llm-2"], llm_batch_size= 4, llm_prompt="default_multi_qa", force=True) + run_eval(experiment_folder=exp_folder, llm=["tinyllama-chat", "test-llm-2"], llm_batch_size= 4, llm_prompt="default_multi_qa", force=True) def test_vllmeval(self): with initialize(config_path="../config",version_base="1.2"): test_name = inspect.currentframe().f_code.co_name exp_folder = "tests/utdata/" - Evaluate.eval(experiment_folder=exp_folder, vllm=["tinyllama-chat", "test-vllm-1"], llm_batch_size=4, llm_prompt="default_qa", force=True) + run_eval(experiment_folder=exp_folder, llm=["vllm_tinyllama-chat", "test-vllm-1"], llm_batch_size=4, llm_prompt="default_qa", force=True) def test_vllmeval_multi(self): with initialize(config_path="../config",version_base="1.2"): test_name = inspect.currentframe().f_code.co_name exp_folder = "tests/utdata/" - Evaluate.eval(experiment_folder=exp_folder, vllm=["tinyllama-chat", "test-vllm-2"], llm_batch_size=4, llm_prompt="default_multi_qa", force=True) + run_eval(experiment_folder=exp_folder, llm=["vllm_tinyllama-chat", "test-vllm-2"], llm_batch_size=4, llm_prompt="default_multi_qa", force=True) - - + def test_llmeval_pairwise(self): + with initialize(config_path="../config",version_base="1.2"): + test_name = inspect.currentframe().f_code.co_name + folder = "tests/utdata/utexp_neg" + opponent_folder = "tests/utdata/utexp_neg" + opponent_name = "utexp_neg" + run_eval(folder=folder, llm=["tinyllama-chat", "test-llm-pairwise"], llm_batch_size=4, llm_prompt="default_qa", force=True, + opponent_folder=opponent_folder, opponent_name=opponent_name) - \ No newline at end of file + def test_vllmeval_pairwise(self): + with initialize(config_path="../config",version_base="1.2"): + test_name = inspect.currentframe().f_code.co_name + folder = "tests/utdata/utexp_neg" + opponent_folder = "tests/utdata/utexp_neg" + opponent_name = "utexp_neg" + run_eval(folder=folder, llm=["vllm_tinyllama-chat", "test-vllm-pairwise"], llm_batch_size=4, llm_prompt="default_qa", force=True, + opponent_folder=opponent_folder, opponent_name=opponent_name) \ No newline at end of file