forked from andrewabeles/drug-labels
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
71 lines (66 loc) · 2.55 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from dash import Dash, dcc, html, Input, Output, State
import pickle
import pandas as pd
import plotly.express as px
from cleaning import prepare, get_document_features
def _load_pickle(path):
    """Return the object unpickled from the binary file at ``path``."""
    # NOTE: unpickling can execute arbitrary code, so this must only ever be
    # pointed at model artifacts produced by this project.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# The Dash application. Callback-validation exceptions are suppressed via
# ``suppress_callback_exceptions``; ``server`` re-exports the app's underlying
# server object at module level.
app = Dash(__name__, suppress_callback_exceptions=True)
server = app.server

# Load trained classifier
classifier = _load_pickle('models/classifier.pkl')

# Load classifier featureset
classifier_features = _load_pickle('models/classifier_features.pkl')
# Left half of the classifier tab: editable label text plus the button whose
# click count drives the classification callback.
_input_pane = html.Div(
    style={'width': '49%', 'display': 'inline-block'},
    children=[
        html.H3('Dosage and Administration Text'),
        dcc.Textarea(
            id='input-text',
            style={'width': '99%', 'height': 370, 'resize': 'none'},
            value='DIRECTIONS Chew tablets and let dissolve in mouth. Do not use more than directed. Do not take with food.',
        ),
        html.Button('Classify', id='classify-button', n_clicks=0),
    ],
)

# Right half of the classifier tab: empty container ('pdist') that the
# callback fills with its output.
_output_pane = html.Div(
    style={'width': '49%', 'display': 'inline-block'},
    children=[
        html.H3('Predicted Route of Administration'),
        html.Div(id='pdist'),
    ],
)

_classifier_tab = dcc.Tab(
    label='Classifier',
    children=[
        html.Div(
            style={'display': 'flex'},
            children=[_input_pane, _output_pane],
        ),
    ],
)

# Second tab: embeds the 'lda_display.html' asset in a full-size iframe.
_topic_model_tab = dcc.Tab(
    label='Topic Model',
    children=[
        html.Div([
            html.Iframe(
                src=app.get_asset_url('lda_display.html'),
                style={
                    'position': 'absolute',
                    'width': '100%',
                    'height': '100%',
                },
            ),
        ]),
    ],
)

# Page layout: a title followed by the two tabs defined above.
app.layout = html.Div([
    html.H1('Text Mining Drug Labels'),
    dcc.Tabs([_classifier_tab, _topic_model_tab]),
])
@app.callback(
    Output('pdist', 'children'),
    Input('classify-button', 'n_clicks'),
    State('input-text', 'value')
)
def classify_text(n_clicks, text):
    """Classify the submitted text and chart the per-route probabilities.

    Args:
        n_clicks: Click count of the 'classify-button' component.
        text: Current value of the 'input-text' textarea.

    Returns:
        A one-element list containing a ``dcc.Graph`` bar chart (route on the
        y axis, probability on the x axis, rows sorted by ascending
        probability), or ``None`` when the button has not been clicked yet.
    """
    # Nothing to show until the user has actually pressed the button. The
    # ``None`` check avoids a TypeError from comparing ``None > 0`` should the
    # click count ever arrive unset.
    if n_clicks is None or n_clicks <= 0:
        return None

    tokens = prepare(text)
    features = get_document_features(tokens, classifier_features)

    # Only the full probability distribution is displayed, so the separate
    # ``classifier.classify`` call (whose result was never used) is dropped.
    pdist = classifier.prob_classify(features)

    # Fetch the label list once so both columns are built from the same
    # sequence, in the same order.
    routes = classifier.labels()
    pdist_df = pd.DataFrame({
        'route': routes,
        'probability': [pdist.prob(route) for route in routes],
    }).sort_values(by='probability')

    fig = px.bar(pdist_df, y='route', x='probability')
    return [dcc.Graph(figure=fig)]
if __name__ == '__main__':
    # Development entry point: start the app's built-in server in debug mode
    # when this file is executed directly.
    # ``run_server`` is the legacy method name; newer Dash releases expose
    # ``run`` instead (and drop ``run_server``), so prefer ``run`` when it
    # exists and fall back only for older Dash versions.
    start_server = app.run if hasattr(app, 'run') else app.run_server
    start_server(debug=True)