fix eval scores + multigpu indexing, reranking #39

Open · wants to merge 1 commit into base: main
models/rerankers/crossencoder.py (2 additions, 0 deletions)

@@ -17,6 +17,8 @@ def __init__(self, model_name=None,max_len=512):
         self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, max_length=self.max_len)
         self.model.eval()
+        if torch.cuda.device_count() > 1 and torch.cuda.is_available():
+            self.model = torch.nn.DataParallel(self.model)

     def collate_fn(self, examples):
         question = [e['query'] for e in examples]
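For context on the pattern this hunk introduces: torch.nn.DataParallel replicates the wrapped module on every visible GPU, splits dimension 0 of each input batch across the replicas, and gathers the outputs back on the primary device. A minimal sketch with a stand-in module (illustrative only, not project code):

    import torch
    from torch import nn

    model = nn.Linear(16, 1)                    # stand-in for the cross-encoder
    model.eval()
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)          # one replica per visible GPU

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)                    # DataParallel expects weights on cuda:0

    with torch.no_grad():
        batch = torch.randn(64, 16, device=device)
        scores = model(batch)                   # dim 0 is split across replicas; outputs
                                                # are gathered back on the primary device

One caveat: after wrapping, the original module sits at model.module, so any code that reads attributes off the wrapped model directly (for a Hugging Face model, e.g. config or save_pretrained) has to account for the extra layer.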
models/retrievers/dense.py (0 additions, 1 deletion)

@@ -20,7 +20,6 @@ def __init__(self, model_name, max_len, pooler, similarity, prompt_q=None, promp
         self.query_encoder = self.model # otherwise symmetric
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = self.model.to(self.device)
         self.model.eval()
         if query_encoder_name:
             self.query_encoder = self.query_encoder.to(self.device)
models/retrievers/splade.py (0 additions, 1 deletion)

@@ -22,7 +22,6 @@ def __init__(self, model_name, max_len=512, query_encoder_name=None):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, max_length=self.max_len)
         self.reverse_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = self.model.to(self.device)
         self.model.eval()
         if query_encoder_name:
             self.query_encoder = self.query_encoder.to(self.device)
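Both retriever hunks above (dense.py and splade.py) drop the eager move of self.model onto self.device at construction time. The diff itself does not state the motivation; a plausible reading, and an assumption on my part, is that device placement is now deferred to whatever module runs inference, which avoids moving the weights twice and leaves room for a DataParallel wrap before the device is chosen. A hypothetical sketch of that deferred-placement pattern, with illustrative names:

    import torch
    from torch import nn

    class Retriever:
        # hypothetical, trimmed-down stand-in for the dense/splade retrievers
        def __init__(self):
            self.model = nn.Linear(8, 8)   # stand-in for the HF encoder
            self.model.eval()              # note: no .to(device) at construction

    def evaluate(retriever):
        # the consumer picks the device once, right before inference
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        retriever.model = retriever.model.to(device)
        with torch.no_grad():
            return retriever.model(torch.randn(2, 8, device=device))

    print(evaluate(Retriever()))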
modules/rerank.py (7 additions, 3 deletions)

@@ -19,18 +19,22 @@ def __init__(self, init_args=None, batch_size=1):
         self.batch_size = batch_size
         self.init_args = init_args
         self.model = instantiate(self.init_args)
+        self.model_name = self.model.model_name.replace('/', '_')
+
     @torch.no_grad()
     def eval(self, dataset):
         # get dataloader
-        self.model.model.to('cuda')
+        #self.model.model.to('cuda')
+        self.model.model = self.model.model.to('cuda')
         dataloader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.model.collate_fn)
         q_ids, d_ids, scores, embs_list = list(), list(), list(), list()
         # run inference on the dataset
         for batch in tqdm(dataloader, desc=f'Reranking: {self.model.model_name}'):
             q_ids += batch.pop('q_id')
             d_ids += batch.pop('d_id')
             outputs = self.model(batch)
-            score = outputs['score']
+            score = outputs['score'].detach().cpu()
+
             scores.append(score)

         # get flat tensor of scores

@@ -64,4 +68,4 @@ def sort_by_score_indexes(self, scores, q_ids, d_ids):
         return q_ids_sorted, doc_ids_sorted, scores_sorted

     def get_clean_model_name(self):
-        return self.model.model_name.replace('/', '_')
+        return self.model_name
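Two details in this hunk are easy to miss. First, for an nn.Module, .to() moves parameters in place and returns the same object, so the bare call already worked; the reassignment makes the rebinding explicit and stays correct even if the attribute ever holds a plain tensor, since Tensor.to() is out-of-place. Second, detaching and moving each batch of scores to the CPU inside the loop releases GPU memory as inference proceeds and makes the later concatenation device-independent. A small illustration of both points (plain PyTorch semantics, not project code):

    import torch
    from torch import nn

    # Module.to(...) is in place and returns the same object:
    m = nn.Linear(2, 2)
    assert m.to(torch.float64) is m

    # Tensor.to(...) returns a new tensor; the original is untouched:
    t = torch.zeros(2)
    assert t.to(torch.float64) is not t
    t = t.to(torch.float64)        # rebind, as the tensor idiom requires

    # Collecting scores the way the eval loop now does:
    scores = []
    with torch.no_grad():
        for _ in range(3):
            out = torch.randn(4)   # stands in for outputs['score'] on the GPU
            scores.append(out.detach().cpu())
    flat = torch.cat(scores)       # concatenation happens safely on the CPU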
utils.py (2 additions, 3 deletions)

@@ -22,7 +22,6 @@
 from omegaconf import OmegaConf
 from tqdm import tqdm

-
 def left_pad(sequence: torch.LongTensor, max_length: int, pad_value: int) -> torch.LongTensor:
     """
     Helper function to perform left padding
@@ -272,10 +271,10 @@ def eval_retrieval_kilt(experiment_folder, qrels_folder, query_dataset_name, doc
         for i, (doc_id, score) in enumerate(zip(doc_ids[i], scores[i])):
             # if we have duplicate doc ids (because different passages can map to the same wiki page), only write the max-scoring passage
             if doc_id not in run[q_id]:
-                run[q_id].update({doc_id: score})
+                run[q_id].update({doc_id: float(score)})
             # if there is a higher-scoring passage from the same wiki doc, update the score (maxP)
             elif score >= run[q_id][doc_id]:
-                run[q_id].update({doc_id: score})
+                run[q_id].update({doc_id: float(score)})

     if write_trec:
         with open(f'{experiment_folder}/eval_{split}_{reranking_str}ranking_run.trec', 'w') as trec_out:
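Why the float(score) casts help fix the evaluation scores: if score arrives as a torch or NumPy scalar rather than a native Python float, downstream consumers of the run dictionary can choke; JSON serialization is the clearest example, and evaluation backends typically expect plain floats as well. A quick illustration of the failure mode (illustrative, assuming the scores arrive as torch scalars):

    import json
    import torch

    score = torch.tensor(0.5)

    try:
        json.dumps({"doc1": score})   # raises: Tensor is not JSON serializable
    except TypeError as e:
        print("without cast:", e)

    print("with cast:", json.dumps({"doc1": float(score)}))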