-
Notifications
You must be signed in to change notification settings - Fork 3
/
support_model.py
84 lines (66 loc) · 2.16 KB
/
support_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
from gensim.summarization import keywords
from gensim.parsing.preprocessing import remove_stopwords
from config import ModelNames, MODEL_NAME, used_models
# Import only the backends enabled in config.used_models.  Disabled backends
# are bound to the placeholder ``1`` so every name below always exists; the
# placeholder has no ``get_answer`` attribute, so selecting a disabled
# backend fails when it is called rather than at import time.
# All four checks use the ``ModelNames`` constants, matching the model_name
# comparisons in get_answer (``MODEL_NAME`` is the single selected model,
# not the set of names).
if ModelNames.ELASTIC in used_models:
    from ml_models import elastic_search_baseline
else:
    elastic_search_baseline = 1
if ModelNames.BERT in used_models:
    from ml_models import bert_model
else:
    bert_model = 1
if ModelNames.BPE in used_models:
    from ml_models import bpe_model
else:
    bpe_model = 1
if ModelNames.USE in used_models:
    from ml_models import use_model
else:
    use_model = 1
def get_keywords(query):
    """Return the output of gensim's ``keywords`` extractor for *query*.

    NOTE(review): ``gensim.summarization`` only ships with gensim < 4.0 —
    confirm the pinned gensim version.
    """
    extracted = keywords(query)
    return extracted
def remove_stop_words_func(query):
    """Return *query* with stop words stripped by gensim's ``remove_stopwords``."""
    without_stopwords = remove_stopwords(query)
    return without_stopwords
def remove_slack_commands(query):
    """Rewrite Slack broadcast mentions in *query* as plain words.

    ``<!everyone>``, ``<!channel>`` and ``<!here>`` become ``everyone``,
    ``channel`` and ``here``; every other character is left unchanged.
    Matching is exact and case-sensitive.
    """
    replacements = (
        ("<!everyone>", "everyone"),
        ("<!channel>", "channel"),
        ("<!here>", "here"),
    )
    cleaned = query
    for markup, word in replacements:
        cleaned = cleaned.replace(markup, word)
    return cleaned
def _not_found_answer():
    """Return the fallback reply used when no backend answer is available."""
    return [{'text': "not found :(\nPlease paraphrase your query",
             'channel_id': '0',
             'timestamp': '0'}]


def get_answer(query, use_lower=True, use_keywords=False, use_remove_stopwords=False, model_name=MODEL_NAME,
               use_remove_slack_commands=True
               ):
    """Preprocess *query* and return the answers from the selected model.

    Preprocessing steps run in this order, each gated by its flag:
    lower-casing, Slack-mention rewriting, keyword extraction, stop-word
    removal.  Slack mentions are rewritten before the keyword / stop-word
    steps so those steps see the plain words rather than the raw markup.

    Args:
        query: Raw question text.
        use_lower: Lower-case the query first.
        use_remove_slack_commands: Rewrite ``<!here>``-style mentions via
            ``remove_slack_commands``.
        use_keywords: Replace the query with ``get_keywords(query)``.
        use_remove_stopwords: Strip stop words via ``remove_stop_words_func``.
        model_name: Backend to query, compared with ``==`` against the
            ``ModelNames`` constants.  Defaults to ``config.MODEL_NAME``.

    Returns:
        Whatever the selected backend's ``get_answer`` returns (presumably a
        list of answer dicts shaped like ``_not_found_answer()`` — confirm
        against the backends); ``[]`` if *model_name* matches no known
        backend; ``_not_found_answer()`` if the backend is not loaded or
        raises.
    """
    if use_lower:
        query = query.lower()
    # After lower-casing (the Slack tokens are matched in lower case) but
    # before keyword extraction / stop-word removal, which would otherwise
    # mangle the markup or miss the words it expands to.
    if use_remove_slack_commands:
        query = remove_slack_commands(query)
    if use_keywords:
        query = get_keywords(query)
    if use_remove_stopwords:
        query = remove_stop_words_func(query)

    # Compared with ``==`` (not used as dict keys) so that the match
    # semantics stay exactly those of a plain equality check.
    backends = (
        (ModelNames.ELASTIC, elastic_search_baseline),
        (ModelNames.BERT, bert_model),
        (ModelNames.BPE, bpe_model),
        (ModelNames.USE, use_model),
    )
    for name, backend in backends:
        if model_name == name:
            break
    else:
        # Unknown model name: no backend to ask.
        return []

    # Backends that were not imported are bound to a placeholder that has
    # no ``get_answer``; report that instead of relying on an AttributeError.
    answer_fn = getattr(backend, 'get_answer', None)
    if answer_fn is None:
        print('exception: model %r is not loaded' % (model_name,))
        return _not_found_answer()

    try:
        return answer_fn(query)
    except Exception as ex:  # boundary: a backend failure must not reach the caller
        print('exception:', ex)
        return _not_found_answer()
def main():
    """Ask the configured model a sample question and print its answers."""
    sample_question = 'How plot gantt Chart?'
    print(get_answer(sample_question))


if __name__ == '__main__':
    main()