LanguageModel.py
from Helper_Functions import pick_index_from_distribution, steal_attr
import numpy as np
from time import time
import os
import random
class LanguageModel:
def __init__(self,model):
        for attr in ["initial_state","next_state","state_probs_dist","state_char_prob","state_hash"]:
            if not (hasattr(model,attr) and callable(getattr(model,attr))):
                raise Exception("model "+str(model)[:30]+" is missing callable attribute: "+str(attr))
self.model = model
[steal_attr(self,model,attr) for attr in ["input_alphabet","end_token","internal_alphabet","name"]]
        self.int2char = self.internal_alphabet # it's a list, by definition of internal_alphabet
self.char2int = {c:i for i,c in enumerate(self.int2char)}
def initial_state(self):
return self.model.initial_state()
def next_state(self,state,char):
return self.model.next_state(state,char)
def state_hash(self,state):
return self.model.state_hash(state)
    ##########
    # in theory, of state_probs_dist, state_char_prob, and weight, only one is really needed.
    # but their efficient implementations may vary greatly:
    # e.g., for a WFA, the next-token probability is computed from 2 weight calculations,
    # one for pref and one for pref+char, with weight itself computed directly, without
    # going through conditional probs. conversely, in a DPWFA or an RNN, weight is the product
    # of conditional probs. forcing that behaviour (weight as a product of conditionals) onto
    # a WFA would be woefully inefficient, and forcing the opposite onto a DPWFA or RNN is
    # also not great: why compute 2 whole sequence weights just to get a conditional
    # probability that is already available?
    # tl;dr - let the model make its own decisions on how to implement these 3
    ##########
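    # For orientation, a minimal illustrative sketch (not used by the code below) of how the
    # three quantities relate, for a prefix pref that reaches state s:
    #   state_char_prob(s, c)  == weight(pref + [c], as_prefix=True) / weight(pref, as_prefix=True)
    #   state_probs_dist(s)[i] == state_char_prob(s, self.int2char[i])
    #   weight(seq)            == (product of conditional probs along seq) * (end-token prob at the final state)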
    def state_probs_dist(self,state): # next-token weights, one per token of the internal alphabet, in the model's internal-alphabet order
return self.model.state_probs_dist(state)
def state_char_prob(self,state,char):
return self.model.state_char_prob(state,char)
    def weight(self,sequence,as_prefix=False,state=None):
        if hasattr(self.model,"weight"): # rnns, wfas may prefer their own implementation
            return self.model.weight(sequence,as_prefix=as_prefix,state=state)
        s = self.model.initial_state() if None is state else state
        res = self.probability_of_sequence_after_state(s,sequence)
        if not as_prefix:
            # the end-token probability has to be taken at the state reached after the full sequence
            res *= self.state_char_prob(self._state_from_sequence(sequence,s=s),self.end_token)
        return res
def probability_of_sequence_after_state(self,state,seq):
p = 1
for c in seq:
p *= self.state_char_prob(state,c)
state = self.next_state(state,c)
return p
def _append_token(self,pref,c):
if isinstance(pref,list):
return pref + [c]
if isinstance(pref,tuple):
return pref + (c,)
return pref + c
def sample(self,from_seq=None,cutoff=np.inf,avg_cutoff=np.inf,empty_sequence = None):
def token_from_state(state):
return self.int2char[pick_index_from_distribution(self.state_probs_dist(state))]
def early_stop():
return (early_stopping_prob>0) and random.random() < early_stopping_prob
res = empty_sequence if not None is empty_sequence else [] # default to making list, but could also make tuple or string if requested
res = from_seq if not None is from_seq else res
early_stopping_prob = 1/avg_cutoff
s = self._state_from_sequence(res)
while True:
c = token_from_state(s)
if (c == self.end_token and avg_cutoff==np.inf) or len(res)>=cutoff or early_stop(): # reached end of sample and weren't trying to get a certain avg length OR got too long OR random cutoff for random length samples
return res
if c == self.end_token: # got end token but have to keep going b/c were looking for some kind of sample
for _ in range(100): # not trying more than that to get another one really
c = token_from_state(s)
if not c == self.end_token:
break
if c == self.end_token: # give up trying to keep going
return res
res = self._append_token(res,c)
s = self.next_state(s,c)
def _state_from_sequence(self,sequence,s=None):
if None is s:
s = self.initial_state()
for c in sequence:
s = self.next_state(s,c)
return s
def distribution_from_sequence(self,sequence):
return {c:p for c,p in zip(self.int2char,self.state_probs_dist(self._state_from_sequence(sequence)))}
def _most_likely_token_from_state(self,state,k=1):
if 1==k:
return self.int2char[np.argmax(self.state_probs_dist(state))]
        a = np.array(self.state_probs_dist(state)) # some models return np arrays, some don't
        relevant = np.argpartition(a,-k)[-k:]
        decreasing_order = relevant[np.argsort(-a[relevant])] # argsorting -a orders the indices by decreasing probability
return [self.int2char[i] for i in decreasing_order]
def errors_in_group(self,group,s=None): # definitely check against results of above function before using
if len(group)==0:
return 0
if None is s:
s = self.initial_state()
by_tokens = split_by_first(group,self.end_token,self.input_alphabet)
errors = len(group) - len(by_tokens[self._most_likely_token_from_state(s)])
for t in self.input_alphabet: # check more errors for sequences that are still going
if len(by_tokens[t])==0:
continue #don't waste time computing next state or upset dynet with even more nodes than need be
errors += self.errors_in_group(by_tokens[t],self.next_state(s,t))
return errors
def num_states_visited(self,sequences):
states_visited = set()
def add_states(base_state,sequences):
if not sequences:
return
states_visited.add(self.state_hash(base_state))
by_tokens = split_by_first(sequences,self.end_token,self.input_alphabet)
for a in self.input_alphabet:
add_states(self.next_state(base_state,a),by_tokens[a])
            # the EOS token is never processed (it would only bring one more state), so it doesn't need to be passed forward
add_states(self.initial_state(),sequences)
return len(states_visited)
def next_token_preds(self,sequences,s=None,pref=None,res=None):
        sequences = quickfix(sequences) # it's going to be tuples!
if None is res:
res = {}
if not sequences:
return res
if None is s:
s = self.initial_state()
if None is pref:
            pref = sequences[0][:0] # tuple hack: take the empty prefix from the input itself rather than constructing a new empty sequence
        if 0 in [len(seq) for seq in sequences]:
res[pref] = self._most_likely_token_from_state(s)
by_tokens = split_by_first(sequences,self.end_token,self.input_alphabet)
for t in self.input_alphabet:
if not by_tokens[t]: #empty
continue
res = self.next_token_preds(by_tokens[t],s=self.next_state(s,t),\
pref=self._append_token(pref,t),res=res)
return res
    def weights_for_sequences_with_same_prefix(self,prefix,suffixes):
        # rnns lose accuracy when forced to convert to a value at every step instead of
        # building an expression and only evaluating it at the end, so allow the model
        # to use its own implementation of this if it has one
        if hasattr(self.model,"weights_for_sequences_with_same_prefix"):
            return self.model.weights_for_sequences_with_same_prefix(prefix,suffixes)
def _compute_weight(base_state,lookaheads,seq):
if len(seq)<1:
return (initial_weight*self.weight(seq,state=base_state))
base_state,base_weight = lookaheads[seq[0]]
return (base_weight*self.weight(seq[1:],state=base_state))
s = self.initial_state()
res = 1
for c in prefix:
res *= self.state_char_prob(s,c)
s = self.next_state(s,c)
initial_weight = res
lookaheads = {t:(self.next_state(s,t),initial_weight*self.state_char_prob(s,t)) \
for t in self.input_alphabet}
return [_compute_weight(s,lookaheads,seq) for seq in suffixes]
# return [(initial_weight*self.weight(seq,base_state=s)).value() for seq in suffixes]
def all_sequence_transitions(self,sequence,including_stop=False,prefix=None):
prefix = prefix if not None is prefix else []
s = self._state_from_sequence(prefix)
res = []
for c in sequence:
res.append(self.state_char_prob(s,c))
s = self.next_state(s,c)
if including_stop:
res.append(self.state_char_prob(s,self.end_token))
return res
def last_token_probabilities_after_pref(self,pref,suffixes):
base_state = self._state_from_sequence(pref)
res = []
for sequence in suffixes:
if len(sequence)==0:
res.append(1) # default value for empty string continuation
continue
s = self._state_from_sequence(sequence[:-1],s=base_state)
res.append(self.state_char_prob(s,sequence[-1]))
return res
def last_token_ranks_after_pref(self,pref,suffixes,last_token_probs_too=False):
base_state = self._state_from_sequence(pref)
res = []
probs = []
for sequence in suffixes:
if len(sequence)==0:
                res.append(0) # the empty continuation is always predicted, so by convention give it the best rank
probs.append(1)
continue
s = self._state_from_sequence(sequence[:-1],s=base_state)
state_probs = self.state_probs_dist(s)
last_token_prob = state_probs[self.char2int[sequence[-1]]]
res.append(len([p for p in state_probs if p>last_token_prob])) # gives the rank, from 0 (best) to |alphabet| (worst) (there are |alphabet|+1 possible predictions bc of the stopping token)
probs.append(last_token_prob)
if last_token_probs_too:
return res,probs
return res
def probability_of_char_after_prefix(self,prefix,char):
return self.state_char_prob(self._state_from_sequence(prefix),char)
def probability_of_ending_after_prefix(self,prefix):
return self.probability_of_char_after_prefix(prefix,self.end_token)
def loss_on_word(self,w): # just for the sake of comparing to RNNs
def update_loss(loss,p):
if p == 0:
loss = np.inf
else:
loss -= np.log(p) # beware! this is log-e loss!
return loss
s = self.initial_state()
loss = 0
for c in w:
loss = update_loss(loss,self.state_char_prob(s,c))
s = self.next_state(s,c)
loss = update_loss(loss,self.state_char_prob(s,self.end_token))
return loss
def average_loss_on_group(self,group):
# inefficient implementation for now:
return np.sum([self.loss_on_word(w) for w in group])/sum([(len(s)+1) for s in group])
def WER(self,sequences,gold=None,gold_dict=None):
if None is gold and None is gold_dict: # WER against the sequences themselves
sum_errors = self.errors_in_group(sequences)
sum_lengths = sum(len(s)+1 for s in sequences) #has to predict every token and then also the end
return sum_errors/sum_lengths
pref_counts = make_prefs_count_dict(sequences)
all_prefs = list(pref_counts.keys())
if None is gold_dict:
gold = LanguageModel(gold)
assert self.end_token == gold.end_token
gold_dict = gold.next_token_preds(all_prefs)
        # else: the end tokens can't be checked here - verify agreement with your original model outside this function
self_dict = self.next_token_preds(all_prefs)
errors = sum(pref_counts[p] for p in all_prefs if not self_dict[p]==gold_dict[p])
num_preds = sum((len(p)+1) for p in sequences)
return errors/num_preds
def least_squares(self,sequences,gold=None,gold_dict=None):
if None is gold and None is gold_dict:
print("need something to compare to")
return None
if None is gold_dict:
lm_gold = LanguageModel(gold)
gold_dict = {tuple(s):lm_gold.weight(s) for s in sequences}
self_dict = {tuple(s):self.weight(s) for s in sequences}
return np.sum([pow(gold_dict[tuple(s)]-self_dict[tuple(s)],2) for s in sequences])/len(sequences)
    def perplexity(self,sequences,cheating=False): # could probably do all this a lot faster
        # with a couple of one-liners, but it's infinite so often that it's better
        # to go this way and let it cut off early when it can
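        # for reference, the quantity computed below is perplexity = 2 ** ( -(1/N) * sum_i log2(p_i) ),
        # where the p_i are the per-token (and end-token) probabilities over all the sequences and
        # N is the total number of observations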
num_skipped = 0
        num_observations = sum(len(s)+1 for s in sequences) # builtin sum: np.sum over a generator is deprecated
total_logs = 0
for s in sequences:
w = one_sequence_logs(self,s)
if -np.inf == w:
if not cheating:
return np.inf # early stop when nothing to do, which just happens a lot
num_skipped += 1
num_observations -= (len(s)+1)
else:
total_logs += w
if total_logs == -np.inf:
return np.inf
res = np.inf if num_observations == 0 else np.power(2,-total_logs/num_observations)
if cheating:
return res, num_skipped
return res
def cheating_perplexity(self,sequences):
return self.perplexity(sequences,cheating=True)
# def make_spice_preds(self,prefixes,filename=None):
# if None is filename:
# filename = "temporary_preds_"+str(time())+".txt"
# with open(filename,"w") as f:
# for p in prefixes:
# state = self._state_from_sequence(p)
# preds = self._most_likely_token_from_state(state,k=len(self.internal_alphabet)) # just list them all in decreasing order
# preds = [str(t) if not t == self.model.end_token else "-1" for t in preds]
# f.write(" ".join(preds)+"\n")
# return filename
# def get_spice_and_wer_score(self,spice_number,set_type,wgy_name=None):
# assert set_type in ["validate","test"]
# prefixes,alpha = load_spice_prefs(spice_number,set_type,wgy_name)
# model_preds_file = self.make_spice_preds(prefixes)
# targets_file = make_spice_filename(spice_number,set_type,wgy_name,target=True)
# spice,wer = score_rankings(model_preds_file,targets_file)
# os.remove(model_preds_file)
# return spice, wer
def one_sequence_logs(model,seq):
    transitions = model.all_sequence_transitions(seq,including_stop=True)
    if any(t<=0 for t in transitions):
        return -np.inf
    return sum(np.log2(t) for t in transitions) # builtin sum: np.sum over a generator is deprecated
def quickfix(sequences):
if isinstance(sequences[0],list):
sequences = [tuple(l) for l in sequences]
return sequences
def make_prefs_count_dict(sequences):
res = {}
for s in sequences:
for i in range(len(s)+1):
pref = s[:i]
if pref in res:
res[pref] += 1
else:
res[pref] = 1
return res
# this is very slow:
# all_prefixes = [p[:i] for p in sequences for i in range(len(p)+1)]
# pref_counts = {p:all_prefixes.count(p) for p in set(all_prefixes)}
def split_by_first(group,end_token,alphabet):
    by_tokens = {end_token:[None for s in group if len(s)==0]} # seems silly but helps deal with EOS with fewer ifs
group = [s for s in group if len(s)>0]
for t in alphabet:
by_tokens[t] = [s[1:] for s in group if (s[0]==t)]
return by_tokens
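

# The block below is an illustrative sketch, not part of the original module: a hypothetical toy
# model (uniform next-token probabilities) showing the interface LanguageModel expects from `model` -
# the callables checked in __init__ plus the attributes copied via steal_attr. Real models
# (RNNs, WFAs, DPWFAs) implement the same methods.
if __name__ == "__main__":
    class _UniformToyModel: # hypothetical example model, for illustration only
        def __init__(self):
            self.input_alphabet = ["a","b"]
            self.end_token = "<EOS>"
            self.internal_alphabet = self.input_alphabet + [self.end_token]
            self.name = "uniform_toy"
        def initial_state(self):
            return () # a state is just the consumed prefix, as a tuple
        def next_state(self,state,char):
            return state + (char,)
        def state_hash(self,state):
            return state
        def state_probs_dist(self,state):
            return [1/len(self.internal_alphabet)]*len(self.internal_alphabet)
        def state_char_prob(self,state,char):
            return self.state_probs_dist(state)[self.internal_alphabet.index(char)]
    lm = LanguageModel(_UniformToyModel())
    print(lm.weight(["a","b"]))                 # (1/3)**3 for this uniform toy model
    print(lm.distribution_from_sequence(["a"])) # every token, including <EOS>, gets probability 1/3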