diff --git a/README.md b/README.md
index dfe0119..e681fff 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ It is essential to understand that identifying whether a candidate is a spelling error is a big task.
 >
 > -- [Monojit Choudhury et. al. (2007)][1]
 
-This package currently focuses on Out of Vocabulary (OOV) word or non-word error (NWE) correction using BERT model. The idea of using BERT was to use the context when correcting OOV. To improve this package, I would like to extend the functionality to identify RWE, optimising the package, and improving the documentation.
+This package currently focuses on Out of Vocabulary (OOV) word or non-word error (NWE) correction using the BERT model. The idea of using BERT was to use the context when correcting OOV. It also supports Damerau-Levenshtein distance, Bayesian candidate selection, and GPU inference. To improve this package, I would like to extend the functionality to identify RWE, optimise the package, and improve the documentation.
 
 ## Install
 
diff --git a/contextualSpellCheck/contextualSpellCheck.py b/contextualSpellCheck/contextualSpellCheck.py
index 62e2a91..f71edbc 100755
--- a/contextualSpellCheck/contextualSpellCheck.py
+++ b/contextualSpellCheck/contextualSpellCheck.py
@@ -6,6 +6,11 @@
 import unicodedata
 
 import editdistance
+from fastDamerauLevenshtein import damerauLevenshtein
+from collections import defaultdict
+from functools import partial
+import numpy as np
+
 import spacy
 import torch
 from spacy.tokens import Doc, Token, Span
@@ -31,6 +36,11 @@ def __init__(
         max_edit_dist: int = 10,
         debug: bool = False,
         performance: bool = False,
+        top_n: int = 10,
+        lowercased_distance: bool = True,
+        damerau_distance: bool = True,
+        bayes_selection: bool = True,
+        ranked_bert_probs: bool = True,
     ):
         """To create an object for this class. It does not require any special
@@ -46,6 +56,18 @@
             performance (bool, optional): This is used to print the time
                 taken by individual steps in spell check. Defaults to
                 False.
+            top_n (int, optional): number of suggestions from the underlying
+                BERT model to consider. Defaults to 10.
+            lowercased_distance (bool, optional): lowercase both the misspelt
+                token and the candidate before computing edit distance.
+                Defaults to True.
+            damerau_distance (bool, optional): additionally count adjacent
+                character transpositions as single edits (Damerau-Levenshtein)
+                when calculating the distance. Defaults to True.
+            bayes_selection (bool, optional): use Bayesian reasoning when
+                selecting the best candidate. BERT probabilities act as the
+                prior P(A); the textual similarity of a candidate A to the
+                input B is treated as the likelihood P(B|A) that B was typed
+                when A was intended. Defaults to True.
+            ranked_bert_probs (bool, optional): use rank-normalised
+                probabilities instead of the absolute probability values
+                coming from BERT. Defaults to True.
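+
+            For illustration, suppose the input is "milion" and BERT assigns
+            prior 0.8 to "million" and 0.1 to "billion". "million" is one
+            edit away: similarity 1 - 1/7 = 0.857, score 0.857 * 0.8 = 0.686;
+            "billion" is two edits away: similarity 1 - 2/7 = 0.714, score
+            0.714 * 0.1 = 0.071, so "million" is selected.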
""" if vocab_path != "": @@ -89,9 +111,7 @@ def __init__( elif len(extra_token) == 1: words.append(extra_token) if debug: - debug_file_path = os.path.join( - current_path, "tests", "debugFile.txt" - ) + debug_file_path = os.path.join(current_path, "tests", "debugFile.txt") with open(debug_file_path, "w+") as new_file: new_file.write("\n".join(words)) print("Final vocab at " + debug_file_path) @@ -114,6 +134,13 @@ def __init__( self.mask = self.BertTokenizer.mask_token self.debug = debug self.performance = performance + + self.top_n = int(float(top_n)) + self.lowercased_distance = lowercased_distance + self.damerau_distance = damerau_distance + self.bayes_selection = bayes_selection + self.ranked_bert_probs = ranked_bert_probs + if not Doc.has_extension("contextual_spellCheck"): Doc.set_extension("contextual_spellCheck", default=True) Doc.set_extension("performed_spellCheck", default=False) @@ -121,24 +148,18 @@ def __init__( Doc.set_extension("suggestions_spellCheck", default={}) Doc.set_extension("outcome_spellCheck", default="") Doc.set_extension("score_spellCheck", default=None) + Doc.set_extension("bayes_probs", default=None) + Doc.set_extension("bayes_details", default=None) - Span.set_extension( - "get_has_spellCheck", getter=self.span_require_spell_check - ) - Span.set_extension( - "score_spellCheck", getter=self.span_score_spell_check - ) + Span.set_extension("get_has_spellCheck", getter=self.span_require_spell_check) + Span.set_extension("score_spellCheck", getter=self.span_score_spell_check) - Token.set_extension( - "get_require_spellCheck", getter=self.token_require_spell_check - ) + Token.set_extension("get_require_spellCheck", getter=self.token_require_spell_check) Token.set_extension( "get_suggestion_spellCheck", getter=self.token_suggestion_spell_check, ) - Token.set_extension( - "score_spellCheck", getter=self.token_score_spell_check - ) + Token.set_extension("score_spellCheck", getter=self.token_score_spell_check) def __call__(self, doc): """ @@ -156,25 +177,21 @@ def __call__(self, doc): self.time_log("Misspell Identification took: ", model_loaded) if len(misspell_tokens) > 0: model_loaded = datetime.now() - doc, candidate = self.candidate_generator(doc, misspell_tokens) + doc, candidate, scores = self.candidate_generator(doc, misspell_tokens) self.time_log("Candidate Generator took: ", model_loaded) model_loaded = datetime.now() - self.candidate_ranking(doc, candidate) + self.candidate_ranking(doc, candidate, scores) self.time_log("Candidate Ranking took: ", model_loaded) raw_sentence = doc._.outcome_spellCheck.split(" ") - cleaned_sentence = self.BertTokenizer.convert_tokens_to_string( - raw_sentence - ) + cleaned_sentence = self.BertTokenizer.convert_tokens_to_string(raw_sentence) doc._.set("outcome_spellCheck", cleaned_sentence) else: misspell_tokens, doc = self.misspell_identify(doc) if len(misspell_tokens) > 0: - doc, candidate = self.candidate_generator(doc, misspell_tokens) - self.candidate_ranking(doc, candidate) + doc, candidate, scores = self.candidate_generator(doc, misspell_tokens) + self.candidate_ranking(doc, candidate, scores) raw_sentence = doc._.outcome_spellCheck.split(" ") - cleaned_sentence = self.BertTokenizer.convert_tokens_to_string( - raw_sentence - ) + cleaned_sentence = self.BertTokenizer.convert_tokens_to_string(raw_sentence) doc._.set("outcome_spellCheck", cleaned_sentence) return doc @@ -201,8 +218,8 @@ def check(self, query="", spacy_model="en_core_web_sm"): self.time_log("Misspell identification: ", model_loaded) update_query = "" if 
len(misspell_tokens) > 0: - candidate = self.candidate_generator(doc, misspell_tokens) - answer = self.candidate_ranking(candidate) + candidate, scores = self.candidate_generator(doc, misspell_tokens) + answer = self.candidate_ranking(candidate, scores) for i in doc: if i in misspell_tokens: update_query += answer[i] + i.whitespace_ @@ -243,7 +260,7 @@ def misspell_identify(self, doc, query=""): # deep copy is required to preserve individual token info # from objects in pipeline which can modify token info # like merge_entities - docCopy = copy.deepcopy(doc) + docCopy = copy.deepcopy(doc) # TODO: find ways to only do deepcopy when absolutely neccessary, as it is a hotspot in the profiling misspell = [] for token in docCopy: @@ -265,7 +282,7 @@ def misspell_identify(self, doc, query=""): print("misspell identified: ", misspell) return misspell, doc - def candidate_generator(self, doc, misspellings, top_n=10): + def candidate_generator(self, doc, misspellings): """Returns Candidates for misspell words This function is responsible for generating candidate list for misspell @@ -281,15 +298,13 @@ def candidate_generator(self, doc, misspellings, top_n=10): spacy to preserve meta information of the token - Keyword Args: - top_n {int}: # suggestions to be considered (default: {10}) - Returns: Dict{`Token`:List[{str}]}: Eg of return type {misspell-1: ['candidate-1','candidate-2', ...], misspell-2:['candidate-1','candidate-2' . ...]} """ + top_n = self.top_n response = {} score = {} @@ -312,20 +327,13 @@ def candidate_generator(self, doc, misspellings, top_n=10): update_query, ) - model_input = self.BertTokenizer.encode( - update_query, return_tensors="pt" - ) - mask_token_index = torch.where( - model_input == self.BertTokenizer.mask_token_id - )[1] - token_logits = self.BertModel(model_input)[0] + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + model_input = self.BertTokenizer.encode(update_query, return_tensors="pt").to(torch_device) + mask_token_index = torch.where(model_input == self.BertTokenizer.mask_token_id)[1] + token_logits = self.BertModel.to(torch_device)(model_input)[0] mask_token_logits = token_logits[0, mask_token_index, :] - token_probability = torch.nn.functional.softmax( - mask_token_logits, dim=1 - ) - top_n_score, top_n_tokens = torch.topk( - token_probability, top_n, dim=1 - ) + token_probability = torch.nn.functional.softmax(mask_token_logits, dim=1) + top_n_score, top_n_tokens = torch.topk(token_probability, top_n, dim=1) top_n_tokens = top_n_tokens[0].tolist() top_n_score = top_n_score[0].tolist() if self.debug: @@ -333,10 +341,7 @@ def candidate_generator(self, doc, misspellings, top_n=10): print("token_score: ", top_n_score) if token not in response: - response[token] = [ - self.BertTokenizer.decode([candidateWord]) - for candidateWord in top_n_tokens - ] + response[token] = [self.BertTokenizer.decode([candidateWord]) for candidateWord in top_n_tokens] score[token] = [ ( self.BertTokenizer.decode([top_n_tokens[i]]), @@ -357,9 +362,9 @@ def candidate_generator(self, doc, misspellings, top_n=10): doc._.set("performed_spellCheck", True) doc._.set("score_spellCheck", score) - return doc, response + return doc, response, score - def candidate_ranking(self, doc, misspellings_dict): + def candidate_ranking(self, doc, misspellings_dict, candidate_probs): """Ranking the candidates based on edit Distance At present using a library to calculate edit distance @@ -378,18 +383,50 @@ def candidate_ranking(self, doc, misspellings_dict): response = {} # doc = self.nlp(query) 
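+
+        # Pick the string-distance function once per call: Damerau-Levenshtein
+        # additionally counts adjacent transpositions ("hte" -> "the") as a
+        # single edit, whereas plain Levenshtein counts them as two.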
+        if self.damerau_distance:
+            distance_func = partial(damerauLevenshtein, similarity=False)
+        else:
+            distance_func = editdistance.eval
+
+        bayes_probs = defaultdict(dict)
+        bayes_details = defaultdict(dict)
+
         for misspell in misspellings_dict:
             # Init least_edit distance
             least_edit_dist = self.max_edit_dist
+            if self.bayes_selection:
+                # collapse the (candidate, prob) pairs into a lookup table
+                candidate_probs_dct = {}
+                for cand_pair in candidate_probs[misspell]:
+                    candidate_probs_dct[cand_pair[0]] = cand_pair[1]
+
+                if self.ranked_bert_probs:
+                    # replace absolute BERT probabilities with their
+                    # normalised ranks in [0, 1)
+                    keys = candidate_probs_dct.keys()
+                    values = np.array(list(candidate_probs_dct.values()))
+                    order = values.argsort()
+                    ranks = order.argsort() / len(values)
+                    candidate_probs_dct = dict(zip(keys, ranks.tolist()))
+                candidate_probs[misspell] = candidate_probs_dct
             if self.debug:
-                print(
-                    "misspellings_dict[misspell]", misspellings_dict[misspell]
-                )
+                print("misspellings_dict[misspell]", misspellings_dict[misspell])
             for candidate in misspellings_dict[misspell]:
-                edit_dist = editdistance.eval(misspell.text, candidate)
-                if edit_dist < least_edit_dist:
-                    least_edit_dist = edit_dist
+                if self.lowercased_distance:
+                    edit_dist = distance_func(misspell.text.lower(), candidate.lower())
+                else:
+                    edit_dist = distance_func(misspell.text, candidate)
+
+                if self.bayes_selection:
+                    # posterior ~ likelihood (length-normalised similarity)
+                    # x prior (BERT probability); negated so that the existing
+                    # "lower is better" comparison below picks the maximum
+                    similarity = 1 - edit_dist / max(len(misspell.text), len(candidate))
+                    bert_prob = candidate_probs[misspell][candidate]
+                    amount = -similarity * bert_prob
+                    bayes_probs[misspell.text][candidate] = -amount
+                    bayes_details[misspell.text][candidate] = dict(bayes_prob=-amount, similarity=similarity, bert_prob=bert_prob)
+                else:
+                    amount = edit_dist
+
+                if amount < least_edit_dist:
+                    least_edit_dist = amount
                     response[misspell] = candidate
 
             if self.debug:
@@ -399,10 +436,7 @@
                     response[misspell],
                 )
             else:
-                print(
-                    "No candidate selected for max_edit_dist="
-                    + str(self.max_edit_dist)
-                )
+                print("No candidate selected for max_edit_dist=" + str(self.max_edit_dist))
 
         if len(response) > 0:
             doc._.set("suggestions_spellCheck", response)
@@ -421,6 +455,9 @@
         if self.debug:
             print("Final suggestions", doc._.suggestions_spellCheck)
 
+        doc._.set("bayes_probs", bayes_probs)
+        doc._.set("bayes_details", bayes_details)
+
         return response
 
     @staticmethod
@@ -447,12 +484,7 @@ def token_require_spell_check(token):
         Returns:
             List: If no suggestions: False else: True
         """
-        return any(
-            [
-                token.i == suggestion.i and token.text == suggestion.text
-                for suggestion in token.doc._.suggestions_spellCheck.keys()
-            ]
-        )
+        return any([token.i == suggestion.i and token.text == suggestion.text for suggestion in token.doc._.suggestions_spellCheck.keys()])
 
     @staticmethod
     def token_suggestion_spell_check(token):
         """
@@ -471,10 +503,7 @@
             if token.text_with_ws == suggestion.text_with_ws:
                 return token.doc._.suggestions_spellCheck[suggestion]
             else:
-                warnings.warn(
-                    "Position of tokens modified by downstream element "
-                    "in pipeline eg. merge_entities"
-                )
+                warnings.warn("Position of tokens modified by downstream element in pipeline eg. merge_entities")
         return ""
 
     @staticmethod
@@ -496,10 +525,7 @@
         if token.text == suggestion.text:
             return token.doc._.score_spellCheck[suggestion]
         else:
-            warnings.warn(
-                "Position of tokens modified by downstream element"
-                " in pipeline eg. merge_entities"
-            )
+            warnings.warn("Position of tokens modified by downstream element in pipeline eg. merge_entities")
         return []
 
     def span_score_spell_check(self, span):
@@ -605,22 +631,13 @@
                 sub_tokens.append(text[char_position])
                 # print("pre_pos is {}, cur is {} , pre to current is {}"
                 #     .format(pre_puct_position,char_position,text[pre_puct_position+1:char_position]))
-                if (
-                    pre_puct_position >= 0
-                    and text[pre_puct_position + 1 : char_position] != ""
-                ):
+                if pre_puct_position >= 0 and text[pre_puct_position + 1 : char_position] != "":
                     # print("pre_pos is {}, cur is {} , pre to current is {}"
                     #     .format(pre_puct_position,char_position,text[pre_puct_position+1:char_position]))
-                    sub_tokens.append(
-                        text[pre_puct_position + 1 : char_position]
-                    )
+                    sub_tokens.append(text[pre_puct_position + 1 : char_position])
                     pre_puct_position = char_position
-                if (
-                    (len(sub_tokens) > 0)
-                    and (char_position + 1 == text_len)
-                    and (text[pre_puct_position + 1 :] != "")
-                ):
+                if (len(sub_tokens) > 0) and (char_position + 1 == text_len) and (text[pre_puct_position + 1 :] != ""):
                     # print("inside last token append {}"
                     #     .format(text[pre_puct_position+1:]))
                     sub_tokens.append(text[pre_puct_position + 1 :])
@@ -641,19 +658,13 @@
     # for issue #1
     # merge_ents = nlp.create_pipe("merge_entities")
    if "parser" not in nlp.pipe_names:
-        raise AttributeError(
-            "parser is required please enable it in nlp pipeline"
-        )
+        raise AttributeError("parser is required please enable it in nlp pipeline")
 
     # checker = ContextualSpellCheck(debug=True, max_edit_dist=3)
-    nlp.add_pipe(
-        "contextual spellchecker", config={"debug": True, "max_edit_dist": 3}
-    )
+    nlp.add_pipe("contextual spellchecker", config={"debug": True, "max_edit_dist": 3})
     # nlp.add_pipe(merge_ents)
 
-    doc = nlp(
-        "Income was $9.4 milion compared to the prior year of $2.7 milion."
-    )
+    doc = nlp("Income was $9.4 milion compared to the prior year of $2.7 milion.")
 
     print("=" * 20, "Doc Extension Test", "=" * 20)
     print(doc._.outcome_spellCheck)
 
diff --git a/requirements.txt b/requirements.txt
index 6d9742d..fdb4fb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,5 @@
 transformers>=4.0.0
 flake8>=3.8.3
 black==22.6
+fastDamerauLevenshtein
+numpy
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 43fc1c8..c3ad1c6 100644
--- a/setup.py
+++ b/setup.py
@@ -34,5 +34,7 @@
         "editdistance==0.6.2",
         "transformers>=4.0.0",
         "spacy>=3.0.0",
+        "fastDamerauLevenshtein",
+        "numpy",
     ],
 )