-
Notifications
You must be signed in to change notification settings - Fork 2
/
splade_stopwords.py
36 lines (26 loc) · 1.18 KB
/
splade_stopwords.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from collections import defaultdict
from typing import AbstractSet, Dict, List, Optional

from nltk.corpus import stopwords

from sprint.inference.methods import SpladeQueryEncoder, SpladeDocumentEncoder
# Module-level set of NLTK English stopwords consulted by ``remove_stopwords``.
# NOTE: this rebinds the name ``stopwords``, shadowing the imported
# ``nltk.corpus.stopwords`` corpus reader for the rest of the module.
# NOTE(review): presumably requires the NLTK "stopwords" corpus to be
# downloaded beforehand (e.g. ``nltk.download('stopwords')``) -- confirm.
stopwords = set(stopwords.words('english'))
def remove_stopwords(
    term_weights: Dict[str, float],
    *,
    stop_words: Optional[AbstractSet[str]] = None,
) -> None:
    """Delete every stopword key from ``term_weights`` in place.

    Args:
        term_weights: Mapping of term to weight. It is mutated in place;
            nothing is returned.
        stop_words: Terms to remove. Defaults to the module-level NLTK
            English ``stopwords`` set.
    """
    if stop_words is None:
        stop_words = stopwords
    # Collect the matches first: deleting keys while iterating the dict
    # itself would raise "dictionary changed size during iteration".
    doomed = [term for term in term_weights if term in stop_words]
    for term in doomed:
        del term_weights[term]
class SpladeStopWordsQueryEncoder(SpladeQueryEncoder):
    """``SpladeQueryEncoder`` whose output has stopword terms removed."""

    def encode(self, text, **kwargs) -> Dict[str, float]:
        """Encode ``text`` with the parent encoder, then drop stopword terms.

        The term-to-weight dict produced by ``SpladeQueryEncoder.encode`` is
        filtered in place by ``remove_stopwords``; that same dict is returned.
        """
        weights: Dict[str, float] = super().encode(text, **kwargs)
        remove_stopwords(weights)  # mutates ``weights`` in place
        return weights
class SpladeStopWordsDocumentEncoder(SpladeDocumentEncoder):
    """``SpladeDocumentEncoder`` whose outputs have stopword terms removed."""

    def encode(self, texts, **kwargs) -> List[Dict[str, float]]:
        """Encode a batch of ``texts`` and drop stopword terms from each result.

        Every term-to-weight dict returned by ``SpladeDocumentEncoder.encode``
        is filtered in place by ``remove_stopwords``. The same batch object is
        returned.
        """
        term_weights_batch = super().encode(texts, **kwargs)
        # Use an explicit loop: in Python 3 ``map`` returns a lazy iterator,
        # so ``map(remove_stopwords, batch)`` without consuming the result
        # never calls the function and leaves the stopwords in place.
        for term_weights in term_weights_batch:
            remove_stopwords(term_weights)
        return term_weights_batch
def splade_stopwords(ckpt_name, etype, device='cpu'):
    """Build a stopword-filtering SPLADE encoder.

    Args:
        ckpt_name: Checkpoint name forwarded to the encoder constructor.
        etype: ``'query'`` for a query encoder or ``'document'`` for a
            document encoder.
        device: Device forwarded to the encoder constructor
            (default ``'cpu'``).

    Returns:
        A ``SpladeStopWordsQueryEncoder`` or ``SpladeStopWordsDocumentEncoder``.

    Raises:
        ValueError: If ``etype`` is neither ``'query'`` nor ``'document'``.
    """
    if etype == 'query':
        return SpladeStopWordsQueryEncoder(ckpt_name, device=device)
    if etype == 'document':
        return SpladeStopWordsDocumentEncoder(ckpt_name, device=device)
    raise ValueError(
        f"Unknown etype {etype!r}; expected 'query' or 'document'"
    )