From 322382b74f8583ce5b13086634c74f249ad2febf Mon Sep 17 00:00:00 2001 From: Stephane Clinchant Date: Thu, 16 Jan 2025 16:19:53 +0100 Subject: [PATCH] fix eval scores + multigpu indexing , reranking --- models/rerankers/crossencoder.py | 2 ++ models/retrievers/dense.py | 1 - models/retrievers/splade.py | 1 - modules/rerank.py | 10 +++++++--- utils.py | 5 ++--- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/models/rerankers/crossencoder.py b/models/rerankers/crossencoder.py index 5689f16..bb5476c 100644 --- a/models/rerankers/crossencoder.py +++ b/models/rerankers/crossencoder.py @@ -17,6 +17,8 @@ def __init__(self, model_name=None,max_len=512): self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, max_length=self.max_len) self.model.eval() + if torch.cuda.device_count() > 1 and torch.cuda.is_available(): + self.model = torch.nn.DataParallel(self.model) def collate_fn(self, examples): question = [e['query'] for e in examples] diff --git a/models/retrievers/dense.py b/models/retrievers/dense.py index 64482ca..59e40d9 100644 --- a/models/retrievers/dense.py +++ b/models/retrievers/dense.py @@ -20,7 +20,6 @@ def __init__(self, model_name, max_len, pooler, similarity, prompt_q=None, promp self.query_encoder = self.model # otherwise symmetric self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = self.model.to(self.device) self.model.eval() if query_encoder_name: self.query_encoder = self.query_encoder.to(self.device) diff --git a/models/retrievers/splade.py b/models/retrievers/splade.py index c4b91b9..2827b13 100644 --- a/models/retrievers/splade.py +++ b/models/retrievers/splade.py @@ -22,7 +22,6 @@ def __init__(self, model_name, max_len=512, query_encoder_name=None): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, max_length=self.max_len) self.reverse_vocab = {v: k for k, v in self.tokenizer.vocab.items()} self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = self.model.to(self.device) self.model.eval() if query_encoder_name: self.query_encoder = self.query_encoder.to(self.device) diff --git a/modules/rerank.py b/modules/rerank.py index 312f94a..de010ad 100644 --- a/modules/rerank.py +++ b/modules/rerank.py @@ -19,10 +19,13 @@ def __init__(self, init_args=None, batch_size=1): self.batch_size = batch_size self.init_args = init_args self.model = instantiate(self.init_args) + self.model_name=self.model.model_name.replace('/', '_') + @torch.no_grad() def eval(self, dataset): # get dataloader - self.model.model.to('cuda') + #self.model.model.to('cuda') + self.model.model = self.model.model.to('cuda') dataloader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.model.collate_fn) q_ids, d_ids, scores, embs_list = list(), list(), list(), list() # run inference on the dataset @@ -30,7 +33,8 @@ def eval(self, dataset): q_ids += batch.pop('q_id') d_ids += batch.pop('d_id') outputs = self.model(batch) - score = outputs['score'] + score = outputs['score'].detach().cpu() + scores.append(score) # get flat tensor of scores @@ -64,4 +68,4 @@ def sort_by_score_indexes(self, scores, q_ids, d_ids): return q_ids_sorted, doc_ids_sorted, scores_sorted def get_clean_model_name(self): - return self.model.model_name.replace('/', '_') \ No newline at end of file + return self.model_name \ No newline at end of file diff --git a/utils.py b/utils.py index f81ef79..f08028c 100644 --- a/utils.py +++ b/utils.py @@ -22,7 +22,6 @@ from omegaconf import OmegaConf from tqdm import tqdm - def left_pad(sequence: torch.LongTensor, max_length: int, pad_value: int) -> torch.LongTensor: """ Helper function to perform left padding @@ -272,10 +271,10 @@ def eval_retrieval_kilt(experiment_folder, qrels_folder, query_dataset_name, doc for i, (doc_id, score) in enumerate(zip(doc_ids[i], scores[i])): # if we have duplicate doc ids (because different passage can map to same wiki page) only write the max scoring passage if doc_id not in run[q_id]: - run[q_id].update({doc_id: score}) + run[q_id].update({doc_id: float(score)}) # if there is a higher scoring passage from the same wiki_doc, update the score (maxP) elif score >= run[q_id][doc_id]: - run[q_id].update({doc_id: score}) + run[q_id].update({doc_id: float(score)}) if write_trec: with open(f'{experiment_folder}/eval_{split}_{reranking_str}ranking_run.trec', 'w') as trec_out: