-
Notifications
You must be signed in to change notification settings - Fork 0
/
SearchEngine.py
129 lines (109 loc) · 4.88 KB
/
SearchEngine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from BM25 import BM25
from EngineStatus import Status
from Ranker import Ranker
from Retriever import Retriever
from utils import show_scores
from utils import tokenize
import gensim
import gensim.downloader as api
from gensim.models.keyedvectors import KeyedVectors
MODEL = 'SO_vectors_200.bin'
# logging.basicConfig(level=logging.DEBUG)
class SearchEngine:
    """Two-stage document search engine.

    Optionally performs a fast BM25 first-pass retrieval (``dual=True``),
    then ranks the surviving documents with a pre-trained word2vec
    embedding model (SO_word2vec, see ``MODEL``).

    Documents are dicts with at least a ``'text'`` key; ``__rank`` also
    reads ``'metadata'``, so callers should supply it too.
    """

    def __init__(self, logger=print):
        """Create the engine and eagerly load the embedding model.

        logger: callable accepting a single message string (default: print).
        """
        self.__logger = logger
        self.__status = Status.DOWN
        # Populated by train(); search() fails fast until then.
        self.__docs = None
        self.__tokenized_docs = None
        self.__query_embedding = self.load_query_embedding()

    def load_query_embedding(self):
        """Load the pre-trained word2vec vectors from ./models/.

        Returns the gensim KeyedVectors instance. Status goes
        PREPARING -> READY around the (slow) disk load.
        """
        self.__setStatus(Status.PREPARING)
        self.__logger(f"Loading {MODEL}...")
        # https://github.com/vefstathiou/SO_word2vec
        corpus = KeyedVectors.load_word2vec_format(
            f"./models/{MODEL}", binary=True)
        self.__setStatus(Status.READY)
        return corpus

    def __prepare_docs(self, documents):
        """Tokenize every document's 'text' field.

        Returns (documents, tokenized_corpus) so callers can keep the
        originals and the token lists aligned by index.
        """
        self.__setStatus(Status.PREPARING)
        self.__logger("Tokenizing documents...")
        corpus = [tokenize(document['text']) for document in documents]
        self.__setStatus(Status.READY)
        return (documents, corpus)

    def __prepare_query(self, query):
        """Tokenize the raw query string."""
        self.__logger("Tokenizing query...")
        return tokenize(query)

    def get_status(self):
        """Return the engine's current Status."""
        return self.__status

    def __setStatus(self, status: Status):
        """Set the engine status and log the transition."""
        self.__logger(f"MODEL STATUS: {status}")
        self.__status = status

    def train(self, docs):
        """Index `docs` (list of document dicts) for subsequent search()."""
        self.__docs, self.__tokenized_docs = self.__prepare_docs(docs)

    def search(self, raw_query, dual=False):
        """Search the indexed documents for `raw_query`.

        dual: when True, run a BM25 screening pass first and only rank
              the positively-scored survivors; when False, rank every
              indexed document with the embedding model.

        Returns a dict with 'data' (ranked results), 'model' (pipeline
        description) and 'object' (embedding model filename).

        Raises RuntimeError if train() has not been called yet (the
        original code would instead die with an obscure TypeError deep
        inside Retriever/Ranker).
        """
        if self.__docs is None:
            raise RuntimeError(
                "SearchEngine.search() called before train()")
        documents = self.__docs
        tokenized_documents = self.__tokenized_docs
        tokenized_query = self.__prepare_query(raw_query)
        # If dual, do an initial screening to quickly retrieve the most
        # relevant documents using BM25 before the (costlier) ranking.
        if dual:
            retrieval_scores, retrieved_documents, tokenzed_retrieved_documents = self.__retrieve(
                tokenized_query, dual)
            documents = retrieved_documents
            tokenized_documents = tokenzed_retrieved_documents
        results = self.__rank(tokenized_query, documents, tokenized_documents)
        return {
            'data': results,
            'model': Retriever.model + (f' + {Ranker.model}' if dual else ''),
            'object': MODEL
        }

    def __retrieve(self, tokenized_query, dual):
        """BM25 first-pass retrieval over the indexed corpus.

        Returns (scores, documents, tokenized_documents); the latter two
        are None when `dual` is falsy (only search() calls this, and only
        with dual=True, so that branch is effectively dead but kept for
        safety).
        """
        # Set up the BM25 model over the tokenized corpus.
        retriever = Retriever(self.__tokenized_docs)
        # Sorted document positions and their matching scores.
        retrieval_indexes, retrieval_scores = retriever.query(tokenized_query)
        # Keep only positively-ranked entries (scores align with indexes
        # positionally).
        positive_indexes = [doc_idx for doc_idx, score
                            in zip(retrieval_indexes, retrieval_scores)
                            if score > 0]
        # If dual, use this list to reduce the set of valid entries.
        if dual:
            # Original documents matching the filtered indexes.
            retrieved_entries = [self.__docs[idx]
                                 for idx in positive_indexes]
            # Tokenized documents matching the filtered indexes.
            tokenized_retrieved_entries = [
                self.__tokenized_docs[idx] for idx in positive_indexes]
            print("======== BM25 ========")
            show_scores(retrieved_entries, retrieval_scores,
                        len(retrieved_entries))
            return (
                retrieval_scores,
                retrieved_entries,
                tokenized_retrieved_entries,
            )
        else:
            show_scores([], retrieval_scores, 0)
            return (retrieval_scores, None, None)

    def __rank(self, tokenized_query, retrieved_documents, tokenized_retrieved_documents):
        """Rank documents against the query with the embedding model.

        Returns a list of result dicts sorted by the Ranker's ordering;
        empty input yields an empty list.
        """
        self.__setStatus(Status.RANKING)
        if not retrieved_documents:
            return []
        self.__logger("Ranking documents...")
        # NOTE(review): the same KeyedVectors object is passed as both the
        # query and the document embedding — presumably intentional (one
        # shared word2vec vocabulary); confirm against Ranker's contract.
        ranker = Ranker(query_embedding=self.__query_embedding,
                        document_embedding=self.__query_embedding)
        ranker_indexes, ranker_scores = ranker.rank(
            tokenized_query, tokenized_retrieved_documents)
        reranked_documents = [retrieved_documents[idx]
                              for idx in ranker_indexes]
        print(" [DONE]")
        print("======== Embedding ========")
        show_scores(reranked_documents, ranker_scores, len(reranked_documents))
        self.__setStatus(Status.READY)
        # 'document' is the rank position (as float), not a document id —
        # preserved as-is for API compatibility.
        results = [{'object': MODEL,
                    'document': float(rank),
                    'score': float(ranker_scores[rank]),
                    'text': reranked_documents[rank]['text'],
                    'metadata': reranked_documents[rank]['metadata'],
                    } for rank in range(len(ranker_indexes))]
        return results