-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnaive_bayes_db.py
189 lines (177 loc) · 9.09 KB
/
naive_bayes_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import sqlite3 as sl3
from os.path import exists
class NaiveBayesDB(object):
    """Create and maintain the SQLite database backing a NaiveBayesClassifier.

    Schema (created when the database file does not exist yet):

    * ``counters(counter, name, description)`` -- one row each for
      ``global_counter``, ``positive_counter`` and ``negative_counter``;
      each train/untrain call moves the global and the matching polarity
      counter by one.
    * ``positive_classification(token, count)`` and
      ``negative_classification(token, count)`` -- per-token occurrence
      counts; ``token`` is UNIQUE.
    """

    # Polarity name -> table holding that polarity's token counts. Table
    # names cannot be bound as SQL parameters, so they are interpolated only
    # from this fixed mapping, never from caller-supplied strings.
    _TABLES = {
        'positive': 'positive_classification',
        'negative': 'negative_classification',
    }
    # Rows of the ``counters`` table that update_counter() may modify.
    _COUNTERS = ('global_counter', 'positive_counter', 'negative_counter')

    def __init__(self,
                 database_path,
                 global_description='',
                 positive_description='',
                 negative_description=''):
        """Open ``database_path``, creating the schema if the file does not exist.

        :param database_path: path of the SQLite file. ``':memory:'`` also
            works: it never "exists", so the schema is always created.
        :param global_description: description stored with ``global_counter``.
        :param positive_description: description stored with ``positive_counter``.
        :param negative_description: description stored with ``negative_counter``.
        """
        # Must be checked before connect(), which creates the file.
        is_new = not exists(database_path)
        self.db_connection = sl3.connect(database_path)
        # Return TEXT values as ``str`` (a Python 2 holdover; it is already
        # the default on Python 3).
        self.db_connection.text_factory = str
        # Cursor of the training transaction in progress; None otherwise.
        self.cursor = None
        if is_new:
            self._create_schema(global_description,
                                positive_description,
                                negative_description)

    def _create_schema(self, global_description, positive_description,
                       negative_description):
        """Create the counters table (all counters at 0) and both token tables."""
        cursor = self.db_connection.cursor()
        try:
            cursor.execute("create table counters (counter INTEGER, name TEXT, description TEXT)")
            for name, description in (('global_counter', global_description),
                                      ('positive_counter', positive_description),
                                      ('negative_counter', negative_description)):
                cursor.execute("insert into counters VALUES (0, ?, ?)",
                               (name, description))
            cursor.execute("create table negative_classification (token BLOB UNIQUE, count INTEGER)")
            cursor.execute("create table positive_classification (token BLOB UNIQUE, count INTEGER)")
            self.db_connection.commit()
        finally:
            cursor.close()

    def _active_cursor(self):
        """Return the open training cursor, or the connection when none is open.

        Both objects provide ``execute()`` returning a cursor, so callers can
        use the result interchangeably.
        """
        return self.cursor if self.cursor is not None else self.db_connection

    def update_counter(self, counter='', value=0):
        """Add ``value`` to one of the document counters.

        Runs on the training transaction's cursor when one is open. When
        called outside a training method, the write goes through the
        connection and the caller is responsible for committing it.

        :param counter: one of ``global_counter``, ``positive_counter``,
            ``negative_counter``.
        :param value: amount to add (negative to decrement).
        :returns: True if the counter was updated; False if ``counter`` is
            unknown, its row is missing, or the update would make it negative.
        """
        if counter not in self._COUNTERS:
            return False
        cursor = self._active_cursor()
        row = cursor.execute("SELECT counter from counters WHERE name=?",
                             (counter,)).fetchone()
        if row is None:
            return False
        new_value = row[0] + value
        if new_value < 0:
            # Never let a counter drop below zero (e.g. untraining more
            # documents than were trained).
            return False
        cursor.execute("UPDATE counters SET counter=? WHERE name=?",
                       (new_value, counter))
        return True

    def _increment_or_insert(self, token, polarity=None):
        """Add one occurrence of ``token`` to the ``polarity`` token table.

        An existing token has its count incremented; a new token is inserted
        with a count of 1.

        :returns: True on success, False for an unknown polarity.
        """
        table = self._TABLES.get(polarity)
        if table is None:
            return False
        cursor = self._active_cursor()
        # Try the update first: rowcount == 0 means the token is new. This
        # avoids a failing INSERT (IntegrityError) for every known token.
        updated = cursor.execute(
            "UPDATE %s SET count = count + 1 WHERE token=?" % table, (token,))
        if updated.rowcount == 0:
            cursor.execute("insert into %s VALUES (?, ?)" % table, (token, 1))
        return True

    def _decrement_or_remove(self, token, polarity):
        """Remove one occurrence of ``token`` from the ``polarity`` token table.

        Tokens absent from the table are ignored. A token whose count would
        reach zero is deleted instead of being stored with a zero count.

        :returns: True on success (including "absent"), False for an unknown
            polarity.
        """
        table = self._TABLES.get(polarity)
        if table is None:
            return False
        cursor = self._active_cursor()
        row = cursor.execute("SELECT count from %s WHERE token=?" % table,
                             (token,)).fetchone()
        if row is None:  # not in database; nothing to undo
            return True
        if row[0] <= 1:
            cursor.execute("DELETE FROM %s WHERE token=?" % table, (token,))
        else:
            cursor.execute("UPDATE %s SET count=? WHERE token=?" % table,
                           (row[0] - 1, token))
        return True

    def _apply_batch(self, tokens, token_action, polarity, delta):
        """Run ``token_action`` over ``tokens`` and move the counters by ``delta``.

        Everything happens in one explicit transaction on ``self.cursor``. If
        anything fails after BEGIN, the transaction is rolled back. This
        leaves the database unchanged and lets the next batch BEGIN again.
        The cursor is always closed and ``self.cursor`` is reset to None.
        """
        self.cursor = self.db_connection.cursor()
        try:
            self.cursor.execute('BEGIN')
            try:
                for token in tokens:
                    token_action(token.token_string, polarity)
                self.update_counter('global_counter', value=delta)
                self.update_counter('%s_counter' % polarity, value=delta)
                self.db_connection.commit()
            except BaseException:
                self.db_connection.rollback()
                raise
        finally:
            self.cursor.close()
            self.cursor = None

    def train_positive(self, tokens):
        """Record one positive document in a single transaction.

        Each token is added to the positive table, and the global and
        positive counters are incremented.

        :param tokens: iterable of objects exposing a ``token_string`` attribute.
        """
        self._apply_batch(tokens, self._increment_or_insert, 'positive', 1)

    def train_negative(self, tokens):
        """Record one negative document in a single transaction.

        Each token is added to the negative table, and the global and
        negative counters are incremented.

        :param tokens: iterable of objects exposing a ``token_string`` attribute.
        """
        self._apply_batch(tokens, self._increment_or_insert, 'negative', 1)

    def untrain_positive(self, tokens):
        """Undo train_positive() for one document in a single transaction.

        Each token's positive count is decremented, and a token is removed
        once its count reaches zero. Tokens not in the table are ignored.
        The global and positive counters are decremented, but never below zero.

        :param tokens: iterable of objects exposing a ``token_string`` attribute.
        """
        self._apply_batch(tokens, self._decrement_or_remove, 'positive', -1)

    def untrain_negative(self, tokens):
        """Undo train_negative() for one document in a single transaction.

        Each token's negative count is decremented, and a token is removed
        once its count reaches zero. Tokens not in the table are ignored.
        The global and negative counters are decremented, but never below zero.

        :param tokens: iterable of objects exposing a ``token_string`` attribute.
        """
        self._apply_batch(tokens, self._decrement_or_remove, 'negative', -1)

    def counter_for_token(self, token, polarity=''):
        """Return how many times ``token`` was recorded for ``polarity``.

        :param token: token string to look up.
        :param polarity: ``'positive'`` or ``'negative'``.
        :returns: the stored count; 1 for a token not in the table (a
            placeholder default, presumably smoothing for the classifier --
            TODO: find optimal value); False for an unknown polarity.
        """
        table = self._TABLES.get(polarity)
        if table is None:
            return False
        cursor = self.db_connection.cursor()
        try:
            row = cursor.execute("SELECT count from %s WHERE token=?" % table,
                                 (token,)).fetchone()
        finally:
            cursor.close()
        if row is None:  # not in database; use the default
            return 1
        return row[0]

    def total_for_polarity(self, polarity=''):
        """Return the document counter for ``polarity``.

        :param polarity: ``'positive'`` or ``'negative'``.
        :returns: the counter value, reported as 1 when it is 0 or missing
            (presumably so callers can divide by it); False for an unknown
            polarity.
        """
        if polarity not in self._TABLES:
            return False
        cursor = self.db_connection.cursor()
        try:
            row = cursor.execute("SELECT counter from counters WHERE name=?",
                                 ("%s_counter" % polarity,)).fetchone()
        finally:
            cursor.close()
        counter_value = row[0] if row is not None else 0
        return counter_value if counter_value != 0 else 1