Feature/elastic retriever #8

Open
wants to merge 2 commits into master
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ notebooks/
local.py

.vscode/
.idea/

# Created by https://www.gitignore.io/api/python,virtualenv,visualstudiocode
# Edit at https://www.gitignore.io/?templates=python,virtualenv,visualstudiocode
11 changes: 11 additions & 0 deletions graphqa/core/retriever/elastic_retriever/bert_server/Dockerfile
@@ -0,0 +1,11 @@
FROM tensorflow/tensorflow:1.12.0-py3

RUN pip install --no-cache-dir bert-serving-server

COPY . /bert

WORKDIR /bert

RUN chmod +x bert-start.sh

ENTRYPOINT ["./bert-start.sh"]
@@ -0,0 +1,2 @@
#!/bin/bash
bert-serving-start -num_worker=1 -model_dir model/cased_L-12_H-768_A-12/
@@ -0,0 +1,13 @@
from bert_serving.server import BertServer
from bert_serving.server.helper import get_run_args


def main(args):
    server = BertServer(args=args)
    server.start()
    server.join()


if __name__ == '__main__':
    arguments = get_run_args()
    main(arguments)
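
For reference, the server started above can be queried with the matching bert-serving-client package. A minimal sketch, assuming the server is running on its default ports (5555/5556) with the cased_L-12_H-768_A-12 model:

from bert_serving.client import BertClient

# Connect to the running bert-serving-server; output_fmt='list' returns plain Python lists.
bc = BertClient(ip='localhost', output_fmt='list')
vectors = bc.encode(["How many moons does Mars have?"])
print(len(vectors[0]))  # 768 dimensions for the BERT-base model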
@@ -0,0 +1,112 @@
import json
import os
import logging

from typing import Dict, List
from argparse import ArgumentParser
from bert_serving.client import BertClient

logging.basicConfig(level=logging.INFO)


def create_document(paragraph: Dict,
                    embedding: List[float],
                    index_name: str) -> Dict:
    return {
        "_op_type": "index",
        "_index": index_name,
        "id": paragraph["id"],
        "topic_name": paragraph["topic_name"],
        "topic_text": paragraph["topic_text"],
        "topic_text_vector": embedding
    }


def bulk_predict(paragraphs: List[Dict],
                 bert_client: BertClient,
                 batch_size: int = 256):
    # Encode paragraphs in batches and yield one embedding per paragraph.
    for i in range(0, len(paragraphs), batch_size):
        batch = paragraphs[i: i + batch_size]
        embeddings = bert_client.encode(
            [paragraph["topic_text"] for paragraph in batch]
        )
        for embedding in embeddings:
            yield embedding


def get_parsed_paragraphs(load_path: str) -> Dict:
    with open(load_path, 'r') as fp:
        data = json.load(fp)
    return data


def create_list_of_paragraphs(topic: Dict) -> List[Dict]:
    list_of_paragraphs = []
    topic_name = topic["topic_name"]
    paragraphs = topic["paragraphs"]
    for paragraph_id in paragraphs:
        paragraph = paragraphs[paragraph_id]
        paragraph_text = paragraph["text"]
        item = dict()
        item["id"] = paragraph_id
        item["topic_name"] = topic_name
        item["topic_text"] = paragraph_text
        list_of_paragraphs.append(item)
    return list_of_paragraphs


def create_list_of_sentences(topic: Dict) -> List[Dict]:
    list_of_sentences = []
    topic_name = topic["topic_name"]
    paragraphs = topic["paragraphs"]
    for paragraph_id in paragraphs:
        paragraph = paragraphs[paragraph_id]
        paragraph_sentences = paragraph["sentences"]
        sentence_counter = 0
        for sentence in paragraph_sentences:
            sentence_counter += 1
            sentence_id = paragraph_id + "-" + str(sentence_counter)
            item = dict()
            item["id"] = sentence_id
            item["topic_name"] = topic_name
            item["topic_text"] = sentence
            list_of_sentences.append(item)
    return list_of_sentences


def main(arguments):
    bc = BertClient(output_fmt='list', check_length=False)
    logging.info("start")
    index_name = arguments.index
    json_path = arguments.json
    save_path = arguments.output
    for topic_name in os.listdir(json_path):
        load_path = os.path.join(json_path, topic_name)
        data = get_parsed_paragraphs(load_path)
        logging.info("done parsing paragraphs. [1/2]")
        list_of_sentences = create_list_of_sentences(data)
        logging.info("done creating list of sentences. [2/2]")
        with open(save_path, 'a') as f:
            counter = 0
            for sentence, embedding in zip(list_of_sentences,
                                           bulk_predict(list_of_sentences, bc)):
                counter += 1
                logging.info("counter value is: %d", counter)
                logging.info("sentence id: %s", sentence["id"])
                d = create_document(sentence, embedding, index_name)
                f.write(json.dumps(d) + '\n')


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--index', required=True, help='name of the ES index (es_sentences)')
    parser.add_argument('--json', required=True, help='path to the directory with input json files')
    parser.add_argument('--output', required=True, help='name of the output file (output_sentences.json1)')
    args = parser.parse_args()

    main(args)
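
For reference, each line appended to the output file is a self-contained JSON document shaped for the Elasticsearch bulk helper. A minimal sketch of reading one record back, assuming the script has already produced output_sentences.json1:

import json

with open("output_sentences.json1") as f:
    first_record = json.loads(next(f))

# Keys come from create_document above; the vector length matches the dense_vector mapping.
assert first_record["_op_type"] == "index"
assert len(first_record["topic_text_vector"]) == 768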
@@ -0,0 +1,44 @@
import logging

from json import load
from argparse import ArgumentParser
from elasticsearch import Elasticsearch

logging.basicConfig(level=logging.INFO)


INDEX_NAME = "es_sentences"
CONFIG_PATH = "index_config.json"


def create_index(es: Elasticsearch,
                 index_name: str,
                 config_path: str) -> None:
    try:
        with open(config_path) as file:
            config = load(file)

        es.indices.create(index=index_name, body=config)
        logging.info("index %s has been created!", index_name)
    except Exception as error:
        logging.warning("failed to create index %s: %s", index_name, error)


def main(arguments):
    es = Elasticsearch('localhost:9200')

    index_name = arguments.index
    config_path = arguments.config

    create_index(es,
                 index_name,
                 config_path)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--index', required=True, help='name of the ES index (es_sentences)')
    parser.add_argument('--config', required=True, help='path to the config file (index_config.json)')
    args = parser.parse_args()

    main(args)
98 changes: 98 additions & 0 deletions graphqa/core/retriever/elastic_retriever/elastic_search/elastic.py
@@ -0,0 +1,98 @@
import logging
from typing import Dict, List

from bert_serving.client import BertClient
from elasticsearch import ConnectionError, Elasticsearch, NotFoundError

from graphqa.core import AbstractRetriever

logging.basicConfig(level=logging.INFO)


class ElasticRetriever(AbstractRetriever):

    def __init__(self,
                 total_number=14,
                 index_name='es_sentences',
                 ip_address='localhost:9200'):
        super().__init__()
        self.total_number = total_number
        self.index_name = index_name
        self.paragraph_ids = []
        self.ip_address = ip_address

    def load(self, path):
        self.ip_address = path

    def retrieve(self, question) -> List:
        # establish connections to the BERT encoding server and Elasticsearch
        bc = BertClient(ip='localhost', output_fmt='list', check_length=False)
        client = Elasticsearch(self.ip_address)

        query_vector = bc.encode([question])[0]

        script_query = {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "cosineSimilarity(params.query_vector, doc['topic_text_vector']) + 1.0",
                    "params": {"query_vector": query_vector}
                }
            }
        }
        results = []
        try:
            response = client.search(
                index=self.index_name,
                body={
                    "size": self.total_number,
                    "query": script_query,
                    "_source": {"includes": ["id", "topic_name", "topic_text"]}
                }
            )
            logging.info(response)
            results = self.post_process_response(response)
        except ConnectionError:
            logging.warning("Elasticsearch is not up and running!")
        except NotFoundError:
            logging.warning("no such index: %s", self.index_name)
        return results

    def post_process_response(self,
                              response: Dict) -> List:
        scored_responses = response["hits"]["hits"]
        processed_response = dict()
        target_sentences = []
        for score_object in scored_responses:
            score = score_object["_score"]
            source = score_object["_source"]
            sentence_id = source["id"]
            tokenized_sentence_id = sentence_id.split("-")
            topic_id = tokenized_sentence_id[0]
            topic_name = source["topic_name"]
            sentence = source["topic_text"]
            target_sentences.append(sentence)
            if topic_id not in processed_response:
                processed_response[topic_id] = dict()
                processed_response[topic_id]["count"] = 0
                processed_response[topic_id]["topic_name"] = topic_name
                processed_response[topic_id]["sum_score"] = 0
                processed_response[topic_id]["sentence_ids"] = []
            processed_response[topic_id]["count"] += 1
            processed_response[topic_id]["sum_score"] += score
            processed_response[topic_id]["sentence_ids"].append(sentence_id)
        logging.info(processed_response)
        # aggregate per-topic statistics (logged for inspection only)
        ranking_dictionary = dict()
        for topic_id in processed_response:
            topic = processed_response[topic_id]
            count = topic["count"]
            sum_score = topic["sum_score"]
            topic_name = topic["topic_name"]
            if count not in ranking_dictionary:
                ranking_dictionary[count] = dict()
            average_score = sum_score / count
            ranking_dictionary[count][topic_name] = average_score

        logging.info(ranking_dictionary)

        return target_sentences
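
A minimal usage sketch of the retriever, assuming the BERT server and Elasticsearch are both running, the es_sentences index has been populated, and the module is importable from its location in the tree:

from graphqa.core.retriever.elastic_retriever.elastic_search.elastic import ElasticRetriever

# Retrieve the top 14 candidate sentences for a question.
retriever = ElasticRetriever(total_number=14, index_name='es_sentences')
sentences = retriever.retrieve("Who wrote The Old Man and the Sea?")
for sentence in sentences:
    print(sentence)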
@@ -0,0 +1,14 @@
{
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 1
    },

    "mappings": {
        "properties": {
            "id": {"type": "text"},
            "topic_name": {"type": "text"},
            "topic_text": {"type": "text"},
            "topic_text_vector": {"type": "dense_vector", "dims": 768}
        }
    }
}
@@ -0,0 +1,24 @@
import json
from typing import List
from argparse import ArgumentParser

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk


def load_data_set(path: str) -> List:
    with open(path) as f:
        return [json.loads(line) for line in f]


def main(arguments):
    client = Elasticsearch('localhost:9200')
    docs = load_data_set(arguments.data)
    bulk(client, docs)


if __name__ == '__main__':
    parser = ArgumentParser(description='indexing ES documents.')
    parser.add_argument('--data', help='ES documents (output_sentences.json1)')
    args = parser.parse_args()
    main(args)
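
After bulk indexing, the document count can be checked with the standard client. A minimal sketch, assuming Elasticsearch is reachable on localhost:9200 and the index is named es_sentences:

from elasticsearch import Elasticsearch

client = Elasticsearch('localhost:9200')
# Refresh so freshly indexed documents become searchable, then count them.
client.indices.refresh(index='es_sentences')
print(client.count(index='es_sentences')['count'])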