From 3f937fbc475bd8600cb62ca885fcf0e4a80effba Mon Sep 17 00:00:00 2001 From: David Chanin Date: Wed, 15 Mar 2023 00:08:43 +0000 Subject: [PATCH] feat: new models trained on Framenet exemplars (#18) * include exemplars in framenet training * skipping invalid trigger exemplars * skip exemplars by default during training * fixing tests * improving data augmentations * ensure wordnet download for inference * updating snapshots * adding more info when augmentations fail validation * adding more augmentations from nlpaug * fixing linting * fixing keyboard augmentation * more checks on keyboard augmentation * tweaking augmentations * fixing tests * adding safety check to uppercase augmentation * lower augmentation rate * adding more augmentations * tweaking augs * removing debugging output * reduce augmentation * tweaking augmentation probs * tweaking augmentation probs * fixing type import * adding option to delete non-optimal models as training progresses * tweaking augmentations * updating models * updating README with new model stats --- README.md | 6 +- frame_semantic_transformer/constants.py | 2 +- .../data/LoaderDataCache.py | 42 +++-- .../data/TaskSampleDataset.py | 9 +- .../data/augmentations/DataAugmentation.py | 26 ++- .../augmentations/DoubleQuotesAugmentation.py | 59 +++++++ .../augmentations/KeyboardAugmentation.py | 48 ++++++ .../augmentations/LowercaseAugmentation.py | 16 +- .../RemoveContractionsAugmentation.py | 25 --- .../RemoveEndPunctuationAugmentation.py | 29 +++- .../SimpleMisspellingAugmentation.py | 49 ++++++ .../StripPunctuationAugmentation.py | 63 ++++++++ .../data/augmentations/SynonymAugmentation.py | 42 +++++ .../augmentations/UppercaseAugmentation.py | 20 +++ .../data/augmentations/__init__.py | 14 +- .../data/augmentations/chain_augmentations.py | 11 +- .../modification_helpers/__init__.py | 9 ++ .../modification_helpers/get_sample_text.py | 18 +++ .../modify_text_without_changing_length.py | 53 ++++++ .../modification_helpers/splice_text.py | 103 ++++++++++++ .../data/frame_types.py | 3 + .../framenet17/Framenet17InferenceLoader.py | 35 ++-- .../framenet17/Framenet17TrainingLoader.py | 126 +++++++++++---- .../framenet17/ensure_wordnet_downloaded.py | 8 + .../data/loaders/loader.py | 2 +- .../propbank34/Propbank34TrainingLoader.py | 14 +- frame_semantic_transformer/data/tasks/Task.py | 5 +- .../data/tasks/TaskSample.py | 6 +- .../data/tasks_from_annotated_sentences.py | 12 +- .../training/ModelRecorder.py | 102 ++++++++++++ .../training/TrainingModelWrapper.py | 37 +++-- .../training/find_best_val_model_paths.py | 2 +- frame_semantic_transformer/training/train.py | 2 + pyproject.toml | 1 + setup.cfg | 3 + .../test_frame_semantic_transformer.ambr | 2 +- tests/conftest.py | 7 +- .../test_tasks_from_annotated_sentences.ambr | 140 ++++++++-------- ...est_modify_text_without_changing_length.py | 115 +++++++++++++ .../modification_helpers/test_splice_text.py | 151 ++++++++++++++++++ .../test_DoubleQuotesAugmentation.py | 46 ++++++ .../test_KeyboardAugmentation.py | 48 ++++++ .../test_LowercaseAugmentation.py | 38 +++-- .../test_RemoveContractionsAugmentation.py | 34 ---- .../test_RemoveEndPunctuationAugmentation.py | 35 ++-- .../test_SimpleMisspellingAugmentation.py | 29 ++++ .../test_StripPunctuationAugmentation.py | 29 ++++ .../augmentations/test_SynonymAugmenter.py | 31 ++++ .../test_UppercaseAugmentation.py | 38 +++++ .../augmentations/test_chain_augmentations.py | 24 +-- .../test_Framenet17TrainingLoader.py | 29 ++++ .../test_Propbank34TrainingLoader.ambr | 10 +- tests/data/test_LoaderDataCache.py | 98 ++++++++++-- tests/training/test_ModelRecorder.py | 98 ++++++++++++ .../test_find_best_val_model_paths.py | 2 +- 55 files changed, 1693 insertions(+), 313 deletions(-) create mode 100644 frame_semantic_transformer/data/augmentations/DoubleQuotesAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/KeyboardAugmentation.py delete mode 100644 frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/SimpleMisspellingAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/StripPunctuationAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/SynonymAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/UppercaseAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/modification_helpers/__init__.py create mode 100644 frame_semantic_transformer/data/augmentations/modification_helpers/get_sample_text.py create mode 100644 frame_semantic_transformer/data/augmentations/modification_helpers/modify_text_without_changing_length.py create mode 100644 frame_semantic_transformer/data/augmentations/modification_helpers/splice_text.py create mode 100644 frame_semantic_transformer/data/loaders/framenet17/ensure_wordnet_downloaded.py create mode 100644 frame_semantic_transformer/training/ModelRecorder.py create mode 100644 tests/data/augmentations/modification_helpers/test_modify_text_without_changing_length.py create mode 100644 tests/data/augmentations/modification_helpers/test_splice_text.py create mode 100644 tests/data/augmentations/test_DoubleQuotesAugmentation.py create mode 100644 tests/data/augmentations/test_KeyboardAugmentation.py delete mode 100644 tests/data/augmentations/test_RemoveContractionsAugmentation.py create mode 100644 tests/data/augmentations/test_SimpleMisspellingAugmentation.py create mode 100644 tests/data/augmentations/test_StripPunctuationAugmentation.py create mode 100644 tests/data/augmentations/test_SynonymAugmenter.py create mode 100644 tests/data/augmentations/test_UppercaseAugmentation.py create mode 100644 tests/training/test_ModelRecorder.py rename tests/{data => }/training/test_find_best_val_model_paths.py (74%) diff --git a/README.md b/README.md index da80ae1..3230ef9 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,9 @@ This library uses the same train/dev/test documents and evaluation methodology a | Task | Sesame F1 (dev/test) | Small Model F1 (dev/test) | Base Model F1 (dev/test) | | ---------------------- | -------------------- | ------------------------- | ------------------------ | -| Trigger identification | 0.80 / 0.73 | 0.74 / 0.70 | 0.78 / 0.71 | -| Frame classification | 0.90 / 0.87 | 0.83 / 0.81 | 0.89 / 0.87 | -| Argument extraction | 0.61 / 0.61 | 0.68 / 0.70 | 0.74 / 0.72 | +| Trigger identification | 0.80 / 0.73 | 0.75 / 0.71 | 0.78 / 0.74 | +| Frame classification | 0.90 / 0.87 | 0.87 / 0.86 | 0.91 / 0.89 | +| Argument extraction | 0.61 / 0.61 | 0.76 / 0.73 | 0.78 / 0.75 | The base model performs similarly to Open-Sesame on trigger identification and frame classification tasks, but outperforms it by a significant margin on argument extraction. The small pretrained model has lower F1 than base across the board, but is 1/4 the size and still outperforms Open-Sesame at argument extraction. diff --git a/frame_semantic_transformer/constants.py b/frame_semantic_transformer/constants.py index 9cce23d..9b83932 100644 --- a/frame_semantic_transformer/constants.py +++ b/frame_semantic_transformer/constants.py @@ -5,6 +5,6 @@ MODEL_MAX_LENGTH = 512 OFFICIAL_RELEASES = ["base", "small"] # TODO: small, large -MODEL_REVISION = "v0.1.0" +MODEL_REVISION = "v0.2.0" PADDING_LABEL_ID = -100 DEFAULT_NUM_WORKERS = os.cpu_count() or 2 diff --git a/frame_semantic_transformer/data/LoaderDataCache.py b/frame_semantic_transformer/data/LoaderDataCache.py index be1f4bb..ce8bfeb 100644 --- a/frame_semantic_transformer/data/LoaderDataCache.py +++ b/frame_semantic_transformer/data/LoaderDataCache.py @@ -1,9 +1,12 @@ from __future__ import annotations from collections import defaultdict from functools import lru_cache +from itertools import product +from typing import TYPE_CHECKING -from .loaders.loader import InferenceLoader -from .frame_types import Frame +if TYPE_CHECKING: + from .frame_types import Frame + from .loaders.loader import InferenceLoader class LoaderDataCache: @@ -77,11 +80,13 @@ def get_lexical_unit_bigram_to_frame_lookup_map(self) -> dict[str, list[str]]: for part in parts: # also key this as a mongram if there's only 1 element or the word is rare enough if len(parts) == 1 or self.loader.prioritize_lexical_unit(part): - lu_bigrams.append(self._normalize_lexical_unit_ngram([part])) + for norm_part in self._normalize_lexical_unit_ngram([part]): + lu_bigrams.append(norm_part) if prev_part is not None: - lu_bigrams.append( - self._normalize_lexical_unit_ngram([prev_part, part]) - ) + for norm_parts in self._normalize_lexical_unit_ngram( + [prev_part, part] + ): + lu_bigrams.append(norm_parts) prev_part = part for bigram in lu_bigrams: @@ -97,16 +102,20 @@ def get_possible_frames_for_trigger_bigrams( possible_frames = [] lookup_map = self.get_lexical_unit_bigram_to_frame_lookup_map() for bigram in bigrams: - normalized_bigram = self._normalize_lexical_unit_ngram(bigram) - if normalized_bigram in lookup_map: - bigram_frames = lookup_map[normalized_bigram] - possible_frames += bigram_frames + # sorted here just to get a consistent ordering + for normalized_bigram in sorted(self._normalize_lexical_unit_ngram(bigram)): + if normalized_bigram in lookup_map: + bigram_frames = lookup_map[normalized_bigram] + possible_frames += bigram_frames # remove duplicates, while preserving order # https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set/53657523#53657523 return list(dict.fromkeys(possible_frames)) - def _normalize_lexical_unit_ngram(self, ngram: list[str]) -> str: - return "_".join([self.loader.normalize_lexical_unit_text(tok) for tok in ngram]) + def _normalize_lexical_unit_ngram(self, ngram: list[str]) -> set[str]: + norm_toks = [ + setify(self.loader.normalize_lexical_unit_text(tok)) for tok in ngram + ] + return {"_".join(combo) for combo in product(*norm_toks)} def normalize_name(name: str) -> str: @@ -114,3 +123,12 @@ def normalize_name(name: str) -> str: Normalize a frame or element name to be lowercase and without underscores """ return name.lower().replace("_", "") + + +def setify(input: str | set[str]) -> set[str]: + """ + Convert a string or set to a set + """ + if isinstance(input, str): + return {input} + return input diff --git a/frame_semantic_transformer/data/TaskSampleDataset.py b/frame_semantic_transformer/data/TaskSampleDataset.py index fd827b5..0c9fb23 100644 --- a/frame_semantic_transformer/data/TaskSampleDataset.py +++ b/frame_semantic_transformer/data/TaskSampleDataset.py @@ -19,7 +19,7 @@ class TaskSampleDataset(Dataset[Any]): samples: Sequence[TaskSample] - augmentation: Optional[Callable[[str, str], tuple[str, str]]] = None + augmentation: Optional[Callable[[TaskSample], TaskSample]] = None tokenizer: T5TokenizerFast def __init__( @@ -59,10 +59,11 @@ def parse_sample( self, sample: TaskSample ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - input = sample.get_input() - target = sample.get_target() + aug_sample = sample if self.augmentation: - input, target = self.augmentation(input, target) + aug_sample = self.augmentation(sample) + input = aug_sample.get_input() + target = aug_sample.get_target() input_encoding = self.tokenizer( input, diff --git a/frame_semantic_transformer/data/augmentations/DataAugmentation.py b/frame_semantic_transformer/data/augmentations/DataAugmentation.py index bcf5d0b..11304d0 100644 --- a/frame_semantic_transformer/data/augmentations/DataAugmentation.py +++ b/frame_semantic_transformer/data/augmentations/DataAugmentation.py @@ -1,29 +1,41 @@ from __future__ import annotations + +from typing import Callable, Union from abc import ABC, abstractmethod from random import uniform +from frame_semantic_transformer.data.tasks.TaskSample import TaskSample + + +ProbabilityType = Union[float, Callable[[TaskSample], float]] + class DataAugmentation(ABC): """ Base class for data augmentations on training data """ - probability: float + probability: ProbabilityType - def __init__(self, probability: float): + def __init__(self, probability: ProbabilityType): self.probability = probability - def __call__(self, input: str, output: str) -> tuple[str, str]: + def __call__(self, task_sample: TaskSample) -> TaskSample: """ randomly apply this augmentation in proportion to self.probability """ rand_val = uniform(0, 1.0) - if rand_val > self.probability: - return (input, output) - return self.apply_augmentation(input, output) + if rand_val > self.get_probability(task_sample): + return task_sample + return self.apply_augmentation(task_sample) + + def get_probability(self, task_sample: TaskSample) -> float: + if callable(self.probability): + return self.probability(task_sample) + return self.probability @abstractmethod - def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: """ Main logic for subclasses to implement """ diff --git a/frame_semantic_transformer/data/augmentations/DoubleQuotesAugmentation.py b/frame_semantic_transformer/data/augmentations/DoubleQuotesAugmentation.py new file mode 100644 index 0000000..0af502d --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/DoubleQuotesAugmentation.py @@ -0,0 +1,59 @@ +from __future__ import annotations +import random + +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + splice_text, +) +from frame_semantic_transformer.data.tasks import TaskSample +from .DataAugmentation import DataAugmentation +from .modification_helpers.get_sample_text import get_sample_text + + +LATEX_QUOTES = ["``", "''"] +STANDARD_QUOTE = '"' +ALL_QUOTES = LATEX_QUOTES + [STANDARD_QUOTE] + + +class DoubleQuotesAugmentation(DataAugmentation): + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + sample_text = get_sample_text(task_sample) + + # if standard quotes are used, convert to latex quotes, and vice versa + to_latex = STANDARD_QUOTE in sample_text + from_quotes = [STANDARD_QUOTE] if to_latex else LATEX_QUOTES + to_quotes = LATEX_QUOTES if to_latex else [STANDARD_QUOTE] + + updated_sample = task_sample + while count_instances(sample_text, from_quotes) > 0: + quote, start_loc = find_first_instance(sample_text, from_quotes) + try: + updated_sample = splice_text( + updated_sample, + lambda _text, _critical_indices: ( + start_loc, + len(quote), + random.choice(to_quotes), + ), + ) + sample_text = get_sample_text(updated_sample) + except ValueError: + # The splice failed, so just return the sample + return updated_sample + + return updated_sample + + +def count_instances(text: str, substrings: list[str]) -> int: + return sum(text.count(substring) for substring in substrings) + + +def find_first_instance(text: str, substrings: list[str]) -> tuple[str, int]: + """ + Find the first instance of any of the substrings in the text. Returns the substring and the + start location of the substring. + """ + for substring in substrings: + start_loc = text.find(substring) + if start_loc >= 0: + return substring, start_loc + raise ValueError(f"Could not find any of {substrings} in {text}") diff --git a/frame_semantic_transformer/data/augmentations/KeyboardAugmentation.py b/frame_semantic_transformer/data/augmentations/KeyboardAugmentation.py new file mode 100644 index 0000000..ee471ad --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/KeyboardAugmentation.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from nlpaug.augmenter.char import KeyboardAug + +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + modify_text_without_changing_length, +) +from frame_semantic_transformer.data.tasks import TaskSample +from .DataAugmentation import DataAugmentation, ProbabilityType + + +class KeyboardAugmentation(DataAugmentation): + """ + Wrapper about nlpaug's KeyboardAugmenter + Attempts to make spelling mistakes similar to what a user might make + """ + + augmenter: KeyboardAug + + def __init__(self, probability: ProbabilityType): + super().__init__(probability) + self.augmenter = KeyboardAug( + include_special_char=False, aug_char_p=0.1, aug_word_p=0.1 + ) + self.augmenter.include_detail = True + + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + def augment_sent(sentence: str) -> str: + # this augmentation removes spaces around punctuation, so just manually do the changes + _, changes = self.augmenter.augment(sentence)[0] + new_sentence = sentence + for change in changes: + # sometimes this augmenter changes token lengths, which we don't want + # just skip the changes if that happens + if len(change["orig_token"]) != len(change["new_token"]): + return new_sentence + if change["orig_start_pos"] != change["new_start_pos"]: + return new_sentence + start_pos = change["orig_start_pos"] + end_pos = change["orig_start_pos"] + len(change["orig_token"]) + new_sentence = ( + new_sentence[:start_pos] + + change["new_token"] + + new_sentence[end_pos:] + ) + return new_sentence + + return modify_text_without_changing_length(task_sample, augment_sent) diff --git a/frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py b/frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py index 19c67a6..ecc76ea 100644 --- a/frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py +++ b/frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py @@ -1,14 +1,12 @@ from __future__ import annotations +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + modify_text_without_changing_length, +) + +from frame_semantic_transformer.data.tasks import TaskSample from .DataAugmentation import DataAugmentation class LowercaseAugmentation(DataAugmentation): - def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: - task_def_index = input.find(":") - task_def = input[:task_def_index] - input_contents = input[task_def_index:] - # only lowercase the content, not the task definition - return ( - task_def + input_contents.lower(), - output.lower(), - ) + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + return modify_text_without_changing_length(task_sample, str.lower) diff --git a/frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py b/frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py deleted file mode 100644 index 66dd530..0000000 --- a/frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations -from .DataAugmentation import DataAugmentation -import re - - -def remove_contractions(text: str) -> str: - new_text = text.replace("won't", "will not") - new_text = new_text.replace("can't", "cannot") - new_text = re.sub(r"n't(\b)", r" not\1", new_text) - new_text = re.sub(r"'ll(\b)", r" will\1", new_text) - new_text = re.sub(r"'m(\b)", r" am\1", new_text) - new_text = re.sub(r"'re(\b)", r" are\1", new_text) - new_text = re.sub(r"'ve(\b)", r" have\1", new_text) - return new_text - - -class RemoveContractionsAugmentation(DataAugmentation): - def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: - if "*'" in input or "*'" in output or "*n'" in input or "*n'" in output: - return (input, output) - - return ( - remove_contractions(input), - remove_contractions(output), - ) diff --git a/frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py b/frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py index 5156d88..bf350e9 100644 --- a/frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py +++ b/frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py @@ -1,13 +1,30 @@ from __future__ import annotations -from .DataAugmentation import DataAugmentation import re +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + splice_text, +) +from frame_semantic_transformer.data.augmentations.modification_helpers.splice_text import ( + is_valid_splice, +) + +from frame_semantic_transformer.data.tasks import TaskSample +from .DataAugmentation import DataAugmentation REMOVE_END_PUNCT_RE = r"\s*[.?!]\s*$" class RemoveEndPunctuationAugmentation(DataAugmentation): - def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: - return ( - re.sub(REMOVE_END_PUNCT_RE, "", input), - re.sub(REMOVE_END_PUNCT_RE, "", output), - ) + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + def splice_end_punct_cb( + sentence: str, critical_indices: list[int] + ) -> tuple[int, int, str] | None: + match = re.search(REMOVE_END_PUNCT_RE, sentence) + if match is None: + return None + start, end = match.span() + del_len = end - start + if not is_valid_splice(start, del_len, critical_indices): + return None + return start, del_len, "" + + return splice_text(task_sample, splice_end_punct_cb) diff --git a/frame_semantic_transformer/data/augmentations/SimpleMisspellingAugmentation.py b/frame_semantic_transformer/data/augmentations/SimpleMisspellingAugmentation.py new file mode 100644 index 0000000..f498c9e --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/SimpleMisspellingAugmentation.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import random +import string +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + modify_text_without_changing_length, +) + +from frame_semantic_transformer.data.tasks import TaskSample +from .DataAugmentation import DataAugmentation, ProbabilityType + + +class SimpleMisspellingAugmentation(DataAugmentation): + + max_misspellings_per_sentence: int + min_misspellings_per_sentence: int + + def __init__( + self, + probability: ProbabilityType, + max_misspellings_per_sentence: int = 10, + min_misspellings_per_sentence: int = 1, + ): + super().__init__(probability) + self.max_misspellings_per_sentence = max_misspellings_per_sentence + self.min_misspellings_per_sentence = min_misspellings_per_sentence + + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + def misspell_cb(sentence: str) -> str: + num_mispellings = random.randint( + self.min_misspellings_per_sentence, self.max_misspellings_per_sentence + ) + new_sentence = sentence + for _ in range(num_mispellings): + index = random.randint(0, len(sentence) - 1) + char = sentence[index] + new_char = char + if char.isupper(): + new_char = random.choice(string.ascii_uppercase) + elif char.islower(): + new_char = random.choice(string.ascii_lowercase) + elif char.isdigit(): + new_char = random.choice(string.digits) + new_sentence = ( + new_sentence[:index] + new_char + new_sentence[index + 1 :] + ) + return new_sentence + + return modify_text_without_changing_length(task_sample, misspell_cb) diff --git a/frame_semantic_transformer/data/augmentations/StripPunctuationAugmentation.py b/frame_semantic_transformer/data/augmentations/StripPunctuationAugmentation.py new file mode 100644 index 0000000..17f839a --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/StripPunctuationAugmentation.py @@ -0,0 +1,63 @@ +from __future__ import annotations +import random +import string + +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + splice_text, +) +from frame_semantic_transformer.data.tasks import TaskSample +from .DataAugmentation import DataAugmentation, ProbabilityType +from .modification_helpers.get_sample_text import get_sample_text + + +class StripPunctuationAugmentation(DataAugmentation): + + max_to_remove: int + min_to_remove: int + + def __init__( + self, + probability: ProbabilityType, + max_to_remove: int = 5, + min_to_remove: int = 1, + ): + self.max_to_remove = max_to_remove + self.min_to_remove = min_to_remove + super().__init__(probability) + + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + sample_text = get_sample_text(task_sample) + punctuation_indices = find_punctuation_indices(sample_text) + if len(punctuation_indices) == 0: + return task_sample + + updated_sample = task_sample + for _ in range(random.randint(self.min_to_remove, self.max_to_remove)): + if len(punctuation_indices) == 0: + break + punctuation_index = random.choice(punctuation_indices) + punctuation_indices.remove(punctuation_index) + punctuation_indices = [ + i - 1 if i > punctuation_index else i for i in punctuation_indices + ] + try: + updated_sample = splice_text( + updated_sample, + lambda _text, _critical_indices: ( + punctuation_index, + 1, + "", + ), + ) + except ValueError: + # The splice failed, so just return the sample + return updated_sample + return updated_sample + + +def find_punctuation_indices(text: str) -> list[int]: + """ + Find the indices of all punctuation in the text. + """ + # TODO: This would be more efficient with a regex + return [i for i, char in enumerate(text) if char in string.punctuation] diff --git a/frame_semantic_transformer/data/augmentations/SynonymAugmentation.py b/frame_semantic_transformer/data/augmentations/SynonymAugmentation.py new file mode 100644 index 0000000..f1d309b --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/SynonymAugmentation.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from nlpaug.augmenter.word import SynonymAug + +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + splice_text, +) +from frame_semantic_transformer.data.augmentations.modification_helpers.splice_text import ( + is_valid_splice, +) + +from frame_semantic_transformer.data.tasks import TaskSample +from .DataAugmentation import DataAugmentation, ProbabilityType + + +class SynonymAugmentation(DataAugmentation): + """ + Wrapper about nlpaug's SynonymAugmenter + """ + + augmenter: SynonymAug + + def __init__(self, probability: ProbabilityType): + super().__init__(probability) + self.augmenter = SynonymAug(aug_max=1, aug_min=1) + self.augmenter.include_detail = True + + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + def splice_end_punct_cb( + sentence: str, critical_indices: list[int] + ) -> tuple[int, int, str] | None: + _, changes = self.augmenter.augment(sentence)[0] + if len(changes) == 0: + return None + start = changes[0]["orig_start_pos"] + new_text = changes[0]["new_token"] + del_len = len(changes[0]["orig_token"]) + if not is_valid_splice(start, del_len, critical_indices): + return None + return start, del_len, new_text + + return splice_text(task_sample, splice_end_punct_cb) diff --git a/frame_semantic_transformer/data/augmentations/UppercaseAugmentation.py b/frame_semantic_transformer/data/augmentations/UppercaseAugmentation.py new file mode 100644 index 0000000..4a1ef63 --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/UppercaseAugmentation.py @@ -0,0 +1,20 @@ +from __future__ import annotations +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + modify_text_without_changing_length, +) + +from frame_semantic_transformer.data.tasks import TaskSample +from .DataAugmentation import DataAugmentation + + +class UppercaseAugmentation(DataAugmentation): + def apply_augmentation(self, task_sample: TaskSample) -> TaskSample: + def safe_uppercase(text: str) -> str: + new_text = text.upper() + # it turns out the some characters, like "fi", become 2 chars when uppercased + # just check to make sure we're not in that case here + if len(new_text) != len(text): + return text + return new_text + + return modify_text_without_changing_length(task_sample, safe_uppercase) diff --git a/frame_semantic_transformer/data/augmentations/__init__.py b/frame_semantic_transformer/data/augmentations/__init__.py index a7c9ab4..3ed8ee6 100644 --- a/frame_semantic_transformer/data/augmentations/__init__.py +++ b/frame_semantic_transformer/data/augmentations/__init__.py @@ -1,13 +1,23 @@ from .chain_augmentations import chain_augmentations from .DataAugmentation import DataAugmentation from .LowercaseAugmentation import LowercaseAugmentation -from .RemoveContractionsAugmentation import RemoveContractionsAugmentation +from .UppercaseAugmentation import UppercaseAugmentation +from .SimpleMisspellingAugmentation import SimpleMisspellingAugmentation +from .KeyboardAugmentation import KeyboardAugmentation +from .SynonymAugmentation import SynonymAugmentation +from .DoubleQuotesAugmentation import DoubleQuotesAugmentation from .RemoveEndPunctuationAugmentation import RemoveEndPunctuationAugmentation +from .StripPunctuationAugmentation import StripPunctuationAugmentation __all__ = ( "chain_augmentations", "DataAugmentation", + "DoubleQuotesAugmentation", "LowercaseAugmentation", - "RemoveContractionsAugmentation", + "UppercaseAugmentation", + "KeyboardAugmentation", + "SimpleMisspellingAugmentation", + "StripPunctuationAugmentation", + "SynonymAugmentation", "RemoveEndPunctuationAugmentation", ) diff --git a/frame_semantic_transformer/data/augmentations/chain_augmentations.py b/frame_semantic_transformer/data/augmentations/chain_augmentations.py index cb46ede..c694e55 100644 --- a/frame_semantic_transformer/data/augmentations/chain_augmentations.py +++ b/frame_semantic_transformer/data/augmentations/chain_augmentations.py @@ -1,17 +1,18 @@ from __future__ import annotations from typing import Callable, Sequence +from frame_semantic_transformer.data.tasks import TaskSample + from .DataAugmentation import DataAugmentation def chain_augmentations( augmentations: Sequence[DataAugmentation], -) -> Callable[[str, str], tuple[str, str]]: - def chained_augmentation(input: str, output: str) -> tuple[str, str]: +) -> Callable[[TaskSample], TaskSample]: + def chained_augmentation(input: TaskSample) -> TaskSample: chained_input = input - chained_output = output for augmentation in augmentations: - chained_input, chained_output = augmentation(chained_input, chained_output) - return chained_input, chained_output + chained_input = augmentation(chained_input) + return chained_input return chained_augmentation diff --git a/frame_semantic_transformer/data/augmentations/modification_helpers/__init__.py b/frame_semantic_transformer/data/augmentations/modification_helpers/__init__.py new file mode 100644 index 0000000..35ed69d --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/modification_helpers/__init__.py @@ -0,0 +1,9 @@ +from .modify_text_without_changing_length import modify_text_without_changing_length +from .splice_text import splice_text +from .get_sample_text import get_sample_text + +__all__ = [ + "modify_text_without_changing_length", + "splice_text", + "get_sample_text", +] diff --git a/frame_semantic_transformer/data/augmentations/modification_helpers/get_sample_text.py b/frame_semantic_transformer/data/augmentations/modification_helpers/get_sample_text.py new file mode 100644 index 0000000..9b335e0 --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/modification_helpers/get_sample_text.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from frame_semantic_transformer.data.tasks import ( + ArgumentsExtractionSample, + FrameClassificationSample, + TaskSample, + TriggerIdentificationSample, +) + + +def get_sample_text(sample: TaskSample) -> str: + if isinstance(sample, ArgumentsExtractionSample): + return sample.task.text + if isinstance(sample, FrameClassificationSample): + return sample.task.text + if isinstance(sample, TriggerIdentificationSample): + return sample.task.text + raise ValueError(f"Unknown sample type: {type(sample)}") diff --git a/frame_semantic_transformer/data/augmentations/modification_helpers/modify_text_without_changing_length.py b/frame_semantic_transformer/data/augmentations/modification_helpers/modify_text_without_changing_length.py new file mode 100644 index 0000000..4076468 --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/modification_helpers/modify_text_without_changing_length.py @@ -0,0 +1,53 @@ +from __future__ import annotations +from dataclasses import replace +from typing import Callable + +from frame_semantic_transformer.data.tasks import ( + ArgumentsExtractionSample, + FrameClassificationSample, + TaskSample, + TriggerIdentificationSample, +) + + +def modify_text_without_changing_length( + task_sample: TaskSample, modify_text_cb: Callable[[str], str] +) -> TaskSample: + """ + Helper to modify the text of a TaskSample without changing the length of the text + This is a simple augmentation since it doesn't require rewriting indices + + This takes the task sample and a lambda function that takes the text of the task sample + and returns the modified text. It then modifies the text of the task sample and returns + """ + + def modify_text(text: str) -> str: + new_text = modify_text_cb(text) + if len(new_text) != len(text): + raise ValueError( + f"Text length changed during augmentation: {text} -> {new_text}" + ) + return new_text + + if isinstance(task_sample, ArgumentsExtractionSample): + new_text = modify_text(task_sample.task.text) + return replace( + task_sample, + task=replace(task_sample.task, text=new_text), + ) + + if isinstance(task_sample, FrameClassificationSample): + new_text = modify_text(task_sample.task.text) + return replace( + task_sample, + task=replace(task_sample.task, text=new_text), + ) + + if isinstance(task_sample, TriggerIdentificationSample): + new_text = modify_text(task_sample.task.text) + return replace( + task_sample, + task=replace(task_sample.task, text=new_text), + ) + + raise ValueError(f"Unknown task sample type: {type(task_sample)}") diff --git a/frame_semantic_transformer/data/augmentations/modification_helpers/splice_text.py b/frame_semantic_transformer/data/augmentations/modification_helpers/splice_text.py new file mode 100644 index 0000000..7434b1b --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/modification_helpers/splice_text.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from dataclasses import replace +from typing import Callable + +from frame_semantic_transformer.data.tasks import ( + ArgumentsExtractionSample, + FrameClassificationSample, + TaskSample, + TriggerIdentificationSample, +) + + +def is_valid_splice( + start_loc: int, + delete_num: int, + critical_indices: list[int], +) -> bool: + """ + Helper to check if a splice is valid. A splice is valid if it does not delete any of the critical indices. + """ + for index in critical_indices: + if index >= start_loc and index < start_loc + delete_num: + return False + return True + + +def splice_text( + task_sample: TaskSample, + modify_text_cb: Callable[[str, list[int]], tuple[int, int, str] | None], +) -> TaskSample: + """ + Helper to modify the text of a TaskSample that may change the length of the text. This is + a more complex augmentation since it requires rewriting indices, and can potentially break + the sample. This is loosely modified on the `splice()` function from javascript, where + a start position is given, followed by the number of chars to remove, and then the new + string to insert. + + This takes the task sample and a lambda function. The lambda funtion that takes the text of the + task sample and a list of critical indices to the task, which must not be deleted during the splice. + The lambda returns a tuple of the start position, the number of chars to remove, and the new string + to insert. + """ + + def modify_text( + text: str, critical_indices: list[int] + ) -> tuple[str, Callable[[int], int]]: + modify_results = modify_text_cb(text, critical_indices) + if modify_results is None: + return text, lambda i: i + start_loc, delete_num, insert_text = modify_results + if not is_valid_splice(start_loc, delete_num, critical_indices): + raise ValueError( + f"Critical index was deleted during splice. This is not allowed: {text}, {start_loc}, {delete_num}" + ) + index_modifier = ( + lambda i: i if i <= start_loc else i + len(insert_text) - delete_num + ) + new_text = text[:start_loc] + insert_text + text[start_loc + delete_num :] + return new_text, index_modifier + + if isinstance(task_sample, ArgumentsExtractionSample): + critical_indices = [task_sample.task.trigger_loc] + for frame_element in task_sample.frame_elements: + critical_indices.append(frame_element.start_loc) + critical_indices.append(frame_element.end_loc) + new_text, index_modifier = modify_text(task_sample.task.text, critical_indices) + return ArgumentsExtractionSample( + frame_elements=[ + replace( + elm, + start_loc=index_modifier(elm.start_loc), + end_loc=index_modifier(elm.end_loc), + ) + for elm in task_sample.frame_elements + ], + task=replace( + task_sample.task, + text=new_text, + trigger_loc=index_modifier(task_sample.task.trigger_loc), + ), + ) + + if isinstance(task_sample, FrameClassificationSample): + critical_indices = [task_sample.task.trigger_loc] + new_text, index_modifier = modify_text(task_sample.task.text, critical_indices) + return FrameClassificationSample( + frame=task_sample.frame, + task=replace( + task_sample.task, + text=new_text, + trigger_loc=index_modifier(task_sample.task.trigger_loc), + ), + ) + + if isinstance(task_sample, TriggerIdentificationSample): + critical_indices = task_sample.trigger_locs + new_text, index_modifier = modify_text(task_sample.task.text, critical_indices) + return TriggerIdentificationSample( + trigger_locs=[index_modifier(loc) for loc in task_sample.trigger_locs], + task=replace(task_sample.task, text=new_text), + ) + + raise ValueError(f"Unknown task sample type: {type(task_sample)}") diff --git a/frame_semantic_transformer/data/frame_types.py b/frame_semantic_transformer/data/frame_types.py index a73b3bb..baf8c88 100644 --- a/frame_semantic_transformer/data/frame_types.py +++ b/frame_semantic_transformer/data/frame_types.py @@ -25,6 +25,9 @@ class FrameAnnotatedSentence: text: str annotations: list[FrameAnnotation] + # if this text isn't annotated with every trigger loc, we shouldn't generate a trigger id task from it + # but it's still useful for frame classification and argument extraction tasks + skip_trigger_identification_task: bool = False @dataclass diff --git a/frame_semantic_transformer/data/loaders/framenet17/Framenet17InferenceLoader.py b/frame_semantic_transformer/data/loaders/framenet17/Framenet17InferenceLoader.py index 8c998a7..0764a3a 100644 --- a/frame_semantic_transformer/data/loaders/framenet17/Framenet17InferenceLoader.py +++ b/frame_semantic_transformer/data/loaders/framenet17/Framenet17InferenceLoader.py @@ -1,19 +1,26 @@ from __future__ import annotations import re -from nltk.stem import PorterStemmer -from nltk.corpus import framenet as fn - -from frame_semantic_transformer.data.loaders.framenet17.ensure_framenet_downloaded import ( - ensure_framenet_downloaded, +from nltk.stem import ( + PorterStemmer, + LancasterStemmer, + SnowballStemmer, + WordNetLemmatizer, ) +from nltk.corpus import framenet as fn - +from .ensure_framenet_downloaded import ensure_framenet_downloaded +from .ensure_wordnet_downloaded import ensure_wordnet_downloaded from frame_semantic_transformer.data.frame_types import Frame from ..loader import InferenceLoader -base_stemmer = PorterStemmer() +porter_stemmer = PorterStemmer() +lancaster_stemmer = LancasterStemmer() +snowball_stemmer = SnowballStemmer("english") +wordnet_lemmatizer = WordNetLemmatizer() + +WORDNET_LEMMATIZER_POS = ["a", "r", "n", "v", "s"] LOW_PRIORITY_LONGER_LUS = {"back", "down", "make", "take", "have", "into", "come"} @@ -26,6 +33,7 @@ class Framenet17InferenceLoader(InferenceLoader): def setup(self) -> None: ensure_framenet_downloaded() + ensure_wordnet_downloaded() def load_frames(self) -> list[Frame]: """ @@ -46,14 +54,23 @@ def load_frames(self) -> list[Frame]: frames.append(frame) return frames - def normalize_lexical_unit_text(self, lu: str) -> str: + def normalize_lexical_unit_text(self, lu: str) -> str | set[str]: """ Normalize a lexical unit like "takes.v" to "take". """ normalized_lu = lu.lower() normalized_lu = re.sub(r"\.[a-zA-Z]+$", "", normalized_lu) normalized_lu = re.sub(r"[^a-z0-9 ]", "", normalized_lu) - return base_stemmer.stem(normalized_lu.strip()) + normalized_lu = normalized_lu.strip() + norm_lus = { + porter_stemmer.stem(normalized_lu), + lancaster_stemmer.stem(normalized_lu), + snowball_stemmer.stem(normalized_lu), + } + # try every possible part of speech for the wordnet lemmatizer + for pos in WORDNET_LEMMATIZER_POS: + norm_lus.add(wordnet_lemmatizer.lemmatize(normalized_lu, pos=pos)) + return norm_lus def prioritize_lexical_unit(self, lu: str) -> bool: """ diff --git a/frame_semantic_transformer/data/loaders/framenet17/Framenet17TrainingLoader.py b/frame_semantic_transformer/data/loaders/framenet17/Framenet17TrainingLoader.py index 45095d2..710b4c0 100644 --- a/frame_semantic_transformer/data/loaders/framenet17/Framenet17TrainingLoader.py +++ b/frame_semantic_transformer/data/loaders/framenet17/Framenet17TrainingLoader.py @@ -4,13 +4,19 @@ from nltk.corpus import framenet as fn from frame_semantic_transformer.data.augmentations import ( + DoubleQuotesAugmentation, + KeyboardAugmentation, LowercaseAugmentation, - RemoveContractionsAugmentation, + SimpleMisspellingAugmentation, RemoveEndPunctuationAugmentation, + StripPunctuationAugmentation, + SynonymAugmentation, + UppercaseAugmentation, ) from frame_semantic_transformer.data.augmentations.DataAugmentation import ( DataAugmentation, ) +from frame_semantic_transformer.data.tasks import TriggerIdentificationSample from .ensure_framenet_downloaded import ensure_framenet_downloaded from .sesame_data_splits import SESAME_DEV_FILES, SESAME_TEST_FILES @@ -37,55 +43,111 @@ def load_framenet_samples( return samples +def load_framenet_samples_from_exemplars() -> list[FrameAnnotatedSentence]: + samples: list[FrameAnnotatedSentence] = [] + # make sure we don't include exemplars if we've already included them in the training data + all_doc_samples_text = {sample.text for sample in load_framenet_samples()} + for sent in fn.exemplars(): + annotated_sent = parse_annotated_sentence_from_framenet_sentence( + sent, skip_trigger_identification_task=True + ) + if annotated_sent and annotated_sent.text not in all_doc_samples_text: + samples.append(annotated_sent) + return samples + + def parse_annotated_sentences_from_framenet_doc( fn_doc: dict[str, Any] ) -> list[FrameAnnotatedSentence]: annotated_sentences = [] for sentence in fn_doc["sentence"]: - sentence_text = sentence["text"] - frame_annotations: list[FrameAnnotation] = [] - for fn_annotation in sentence["annotationSet"]: - if ( - "FE" in fn_annotation - and "Target" in fn_annotation - and "frame" in fn_annotation - ): - frame_annotations.append( - FrameAnnotation( - frame=fn_annotation["frame"]["name"], - trigger_locs=[loc[0] for loc in fn_annotation["Target"]], - frame_elements=[ - FrameElementAnnotation( - start_loc=fn_element[0], - end_loc=fn_element[1], - name=fn_element[2], - ) - for fn_element in fn_annotation["FE"][0] - ], - ) - ) - if len(frame_annotations) > 0: - annotated_sentences.append( - FrameAnnotatedSentence( - text=sentence_text, annotations=frame_annotations + annotated_sentence = parse_annotated_sentence_from_framenet_sentence(sentence) + if annotated_sentence: + annotated_sentences.append(annotated_sentence) + return annotated_sentences + + +def parse_annotated_sentence_from_framenet_sentence( + fn_sentence: dict[str, Any], + skip_trigger_identification_task: bool = False, +) -> FrameAnnotatedSentence | None: + sentence_text = fn_sentence["text"] + frame_annotations: list[FrameAnnotation] = [] + for fn_annotation in fn_sentence["annotationSet"]: + if ( + "FE" in fn_annotation + and "Target" in fn_annotation + and "frame" in fn_annotation + ): + trigger_locs = [loc[0] for loc in fn_annotation["Target"]] + # filter out broken annotations + for trigger_loc in trigger_locs: + if trigger_loc >= len(sentence_text): + return None + frame_annotations.append( + FrameAnnotation( + frame=fn_annotation["frame"]["name"], + trigger_locs=trigger_locs, + frame_elements=[ + FrameElementAnnotation( + start_loc=fn_element[0], + end_loc=fn_element[1], + name=fn_element[2], + ) + for fn_element in fn_annotation["FE"][0] + ], ) ) - return annotated_sentences + if len(frame_annotations) > 0: + return FrameAnnotatedSentence( + text=sentence_text, + annotations=frame_annotations, + skip_trigger_identification_task=skip_trigger_identification_task, + ) + return None class Framenet17TrainingLoader(TrainingLoader): + include_exemplars: bool + + def __init__(self, include_exemplars: bool = False) -> None: + super().__init__() + self.include_exemplars = include_exemplars + def setup(self) -> None: ensure_framenet_downloaded() def get_augmentations(self) -> list[DataAugmentation]: return [ - RemoveEndPunctuationAugmentation(0.3), - LowercaseAugmentation(0.2), - RemoveContractionsAugmentation(0.2), + RemoveEndPunctuationAugmentation(0.5), + DoubleQuotesAugmentation(0.2), + StripPunctuationAugmentation(0.2), + SynonymAugmentation( + lambda sample: 0.2 + if isinstance(sample, TriggerIdentificationSample) + else 0.05 + ), + KeyboardAugmentation( + lambda sample: 0.3 + if isinstance(sample, TriggerIdentificationSample) + else 0.05 + ), + SimpleMisspellingAugmentation( + lambda sample: 0.3 + if isinstance(sample, TriggerIdentificationSample) + else 0.05 + ), + LowercaseAugmentation(0.1), + UppercaseAugmentation(0.1), ] def load_training_data(self) -> list[FrameAnnotatedSentence]: - return load_framenet_samples(exclude_docs=SESAME_DEV_FILES + SESAME_TEST_FILES) + training_samples = load_framenet_samples( + exclude_docs=SESAME_DEV_FILES + SESAME_TEST_FILES + ) + if self.include_exemplars: + training_samples += load_framenet_samples_from_exemplars() + return training_samples def load_test_data(self) -> list[FrameAnnotatedSentence]: return load_framenet_samples(include_docs=SESAME_TEST_FILES) diff --git a/frame_semantic_transformer/data/loaders/framenet17/ensure_wordnet_downloaded.py b/frame_semantic_transformer/data/loaders/framenet17/ensure_wordnet_downloaded.py new file mode 100644 index 0000000..98ad527 --- /dev/null +++ b/frame_semantic_transformer/data/loaders/framenet17/ensure_wordnet_downloaded.py @@ -0,0 +1,8 @@ +import nltk + + +def ensure_wordnet_downloaded() -> None: + try: + nltk.data.find("corpora/wordnet.zip") + except LookupError: + nltk.download("wordnet") diff --git a/frame_semantic_transformer/data/loaders/loader.py b/frame_semantic_transformer/data/loaders/loader.py index 52e9081..5a2e762 100644 --- a/frame_semantic_transformer/data/loaders/loader.py +++ b/frame_semantic_transformer/data/loaders/loader.py @@ -39,7 +39,7 @@ def load_frames(self) -> list[Frame]: pass @abstractmethod - def normalize_lexical_unit_text(self, lu: str) -> str: + def normalize_lexical_unit_text(self, lu: str) -> str | set[str]: """ Normalize a lexical unit like "takes.v" to "take". """ diff --git a/frame_semantic_transformer/data/loaders/propbank34/Propbank34TrainingLoader.py b/frame_semantic_transformer/data/loaders/propbank34/Propbank34TrainingLoader.py index aae0ea2..a1a0f50 100644 --- a/frame_semantic_transformer/data/loaders/propbank34/Propbank34TrainingLoader.py +++ b/frame_semantic_transformer/data/loaders/propbank34/Propbank34TrainingLoader.py @@ -8,9 +8,12 @@ from nltk.corpus.reader.conll import ConllCorpusReader from frame_semantic_transformer.data.augmentations import ( + KeyboardAugmentation, LowercaseAugmentation, - RemoveContractionsAugmentation, RemoveEndPunctuationAugmentation, + SimpleMisspellingAugmentation, + SynonymAugmentation, + UppercaseAugmentation, ) from frame_semantic_transformer.data.augmentations.DataAugmentation import ( DataAugmentation, @@ -146,9 +149,12 @@ def setup(self) -> None: def get_augmentations(self) -> list[DataAugmentation]: return [ - RemoveEndPunctuationAugmentation(0.3), - LowercaseAugmentation(0.2), - RemoveContractionsAugmentation(0.2), + RemoveEndPunctuationAugmentation(0.5), + SynonymAugmentation(0.3), + KeyboardAugmentation(0.3), + SimpleMisspellingAugmentation(0.1), + LowercaseAugmentation(0.1), + UppercaseAugmentation(0.1), ] def load_training_data(self) -> list[FrameAnnotatedSentence]: diff --git a/frame_semantic_transformer/data/tasks/Task.py b/frame_semantic_transformer/data/tasks/Task.py index 1c6150f..e32a5d5 100644 --- a/frame_semantic_transformer/data/tasks/Task.py +++ b/frame_semantic_transformer/data/tasks/Task.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Sequence +from typing import Any, Sequence, TYPE_CHECKING -from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache +if TYPE_CHECKING: + from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache class Task(ABC): diff --git a/frame_semantic_transformer/data/tasks/TaskSample.py b/frame_semantic_transformer/data/tasks/TaskSample.py index bfcac4e..b20c5e8 100644 --- a/frame_semantic_transformer/data/tasks/TaskSample.py +++ b/frame_semantic_transformer/data/tasks/TaskSample.py @@ -1,7 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Sequence -from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache +from typing import Sequence, TYPE_CHECKING + +if TYPE_CHECKING: + from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache from frame_semantic_transformer.data.tasks.Task import Task diff --git a/frame_semantic_transformer/data/tasks_from_annotated_sentences.py b/frame_semantic_transformer/data/tasks_from_annotated_sentences.py index 1d942d9..ab196de 100644 --- a/frame_semantic_transformer/data/tasks_from_annotated_sentences.py +++ b/frame_semantic_transformer/data/tasks_from_annotated_sentences.py @@ -44,11 +44,11 @@ def tasks_from_annotated_sentences( frame_elements=annotation.frame_elements, ) ) - - task_samples.append( - TriggerIdentificationSample( - task=TriggerIdentificationTask(text=annotated_sentence.text), - trigger_locs=trigger_locs, + if not annotated_sentence.skip_trigger_identification_task: + task_samples.append( + TriggerIdentificationSample( + task=TriggerIdentificationTask(text=annotated_sentence.text), + trigger_locs=trigger_locs, + ) ) - ) return task_samples diff --git a/frame_semantic_transformer/training/ModelRecorder.py b/frame_semantic_transformer/training/ModelRecorder.py new file mode 100644 index 0000000..24dfce0 --- /dev/null +++ b/frame_semantic_transformer/training/ModelRecorder.py @@ -0,0 +1,102 @@ +from __future__ import annotations +from dataclasses import dataclass +import shutil +from typing import Optional + +from transformers import T5ForConditionalGeneration, T5TokenizerFast + + +class ModelRecorder: + + output_dir: str + records: list[ModelSaveRecord] + + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.records = [] + + def save_model( + self, + model: T5ForConditionalGeneration, + tokenizer: T5TokenizerFast, + epoch: int, + val_loss: float, + task_val_metrics: Optional[dict[str, float]] = None, + ) -> None: + save_path = self.get_save_path(epoch, val_loss, task_val_metrics) + model.save_pretrained(save_path) + tokenizer.save_pretrained(save_path) + self.records.append( + ModelSaveRecord( + epoch=epoch, + val_loss=val_loss, + task_val_metrics=task_val_metrics, + save_path=save_path, + ) + ) + + def remove_non_optimal_models(self) -> None: + """ + Delete any saved model that doesn't have the best validation loss or + best validation metric for any task. + """ + best_val_loss_model = _find_best_val_loss_model(self.records) + best_val_metric_models = _find_best_val_metric_models(self.records) + optimal_models = [best_val_loss_model, *best_val_metric_models.values()] + # clone the list while iterating so we can remove items in the loop + for record in [*self.records]: + if record not in optimal_models: + shutil.rmtree(record.save_path) + self.records.remove(record) + + def get_save_path( + self, + epoch: int, + val_loss: float, + task_val_metrics: Optional[dict[str, float]] = None, + ) -> str: + filename_parts = [ + f"epoch={epoch}", + f"val_loss={val_loss}", + ] + if task_val_metrics: + filename_parts.extend( + [f"{k.replace('-', '_')}={v}" for k, v in task_val_metrics.items()] + ) + + return f"{self.output_dir}/{'-'.join(filename_parts)}" + + +# moved outside of class for easier testing +def _find_best_val_metric_models( + records: list[ModelSaveRecord], +) -> dict[str, ModelSaveRecord]: + best_val_metric_models: dict[str, ModelSaveRecord] = {} + for record in records: + if record.task_val_metrics: + for task, val_metric in record.task_val_metrics.items(): + best_val_model = best_val_metric_models.get(task) + best_val_metrics = ( + best_val_model.task_val_metrics if best_val_model else {} + ) or {} + best_task_metric = best_val_metrics.get(task) + if best_task_metric is None or val_metric > best_task_metric: + best_val_metric_models[task] = record + return best_val_metric_models + + +# moved outside of class for easier testing +def _find_best_val_loss_model(records: list[ModelSaveRecord]) -> ModelSaveRecord: + best_val_loss_model = records[0] + for record in records[1:]: + if record.val_loss < best_val_loss_model.val_loss: + best_val_loss_model = record + return best_val_loss_model + + +@dataclass +class ModelSaveRecord: + epoch: int + val_loss: float + task_val_metrics: Optional[dict[str, float]] + save_path: str diff --git a/frame_semantic_transformer/training/TrainingModelWrapper.py b/frame_semantic_transformer/training/TrainingModelWrapper.py index e672760..af0db6f 100644 --- a/frame_semantic_transformer/training/TrainingModelWrapper.py +++ b/frame_semantic_transformer/training/TrainingModelWrapper.py @@ -12,6 +12,7 @@ from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache from frame_semantic_transformer.data.data_utils import trim_batch +from frame_semantic_transformer.training.ModelRecorder import ModelRecorder from .evaluate_batch import TaskEvalResults, calc_eval_metrics, evaluate_batch @@ -31,6 +32,7 @@ class TrainingModelWrapper(pl.LightningModule): val_metrics: dict[str, float] | None lr_gamma: float log_eval_failures: bool + model_recorder: ModelRecorder def __init__( self, @@ -43,6 +45,7 @@ def __init__( skip_initial_epochs_validation: int = 0, lr_gamma: float = 1.0, log_eval_failures: bool = False, + remove_non_optimal_models: bool = True, ): super().__init__() self.lr = lr @@ -55,6 +58,8 @@ def __init__( self.val_metrics = None self.lr_gamma = lr_gamma self.log_eval_failures = log_eval_failures + self.model_recorder = ModelRecorder(output_dir) + self.remove_non_optimal_models = remove_non_optimal_models def forward( self, @@ -130,21 +135,19 @@ def training_epoch_end(self, training_step_outputs: list[Any]) -> None: 4, ) self.log("train_loss", self.average_training_loss) - filename_parts = [ - f"epoch={self.current_epoch}", - f"train_loss={self.average_training_loss}", - f"val_loss={self.average_validation_loss}", - ] - if self.val_metrics: - filename_parts.extend([f"{k}={v}" for k, v in self.val_metrics.items()]) - - path = f"{self.output_dir}/{'--'.join(filename_parts)}" if ( not self.save_only_last_epoch or self.current_epoch == (self.trainer.max_epochs or 0) - 1 ): - self.tokenizer.save_pretrained(path) - self.model.save_pretrained(path) + self.model_recorder.save_model( + self.model, + self.tokenizer, + epoch=self.current_epoch, + val_loss=self.average_validation_loss, + task_val_metrics=self.val_metrics, + ) + if self.remove_non_optimal_models: + self.model_recorder.remove_non_optimal_models() def validation_epoch_end(self, validation_step_outputs: list[Any]) -> None: losses = [out["loss"].cpu() for out in validation_step_outputs] @@ -160,10 +163,13 @@ def validation_epoch_end(self, validation_step_outputs: list[Any]) -> None: metrics = merge_metrics([out["metrics"] for out in validation_step_outputs]) self.val_metrics = {} for task_name, results in metrics.items(): + scores = calc_eval_metrics(results.scores) name = f"val_{task_name}_f1" - f_score = calc_eval_metrics(results.scores)["f_score"] + f_score = scores["f_score"] self.val_metrics[name] = f_score self.log(name, f_score) + self.log(f"val_{task_name}_recall", scores["recall"]) + self.log(f"val_{task_name}_precision", scores["precision"]) if self.log_eval_failures: log_eval_failures( @@ -180,9 +186,10 @@ def test_epoch_end(self, test_step_outputs: list[Any]) -> None: self.log("test_loss", average_test_loss) metrics = merge_metrics([out["metrics"] for out in test_step_outputs]) for task_name, results in metrics.items(): - self.log( - f"test_{task_name}_f1", calc_eval_metrics(results.scores)["f_score"] - ) + scores = calc_eval_metrics(results.scores) + self.log(f"test_{task_name}_f1", scores["f_score"]) + self.log(f"test_{task_name}_recall", scores["recall"]) + self.log(f"test_{task_name}_precision", scores["precision"]) if self.log_eval_failures: log_eval_failures( self.output_dir, diff --git a/frame_semantic_transformer/training/find_best_val_model_paths.py b/frame_semantic_transformer/training/find_best_val_model_paths.py index 2ce10d3..32a260b 100644 --- a/frame_semantic_transformer/training/find_best_val_model_paths.py +++ b/frame_semantic_transformer/training/find_best_val_model_paths.py @@ -41,7 +41,7 @@ def get_model_scores(output_name: str) -> dict[str, float]: Helper function to get the scores for a given model """ scores = {} - for name_part in output_name.split("--"): + for name_part in output_name.split("-"): if "=" in name_part: key, value = name_part.split("=") if key in KEYS_TO_CHECK: diff --git a/frame_semantic_transformer/training/train.py b/frame_semantic_transformer/training/train.py index 8b5d557..eddbbb3 100644 --- a/frame_semantic_transformer/training/train.py +++ b/frame_semantic_transformer/training/train.py @@ -54,6 +54,7 @@ def train( pl_callbacks: Optional[list[Callback]] = None, pl_loggers: Optional[list[Logger]] = None, resume_from_checkpoint: Optional[str] = None, + remove_non_optimal_models: bool = True, ) -> tuple[T5ForConditionalGeneration, T5TokenizerFast]: device = torch.device("cuda" if use_gpu else "cpu") logger.info("loading base T5 model") @@ -113,6 +114,7 @@ def train( save_only_last_epoch=save_only_last_epoch, skip_initial_epochs_validation=skip_initial_epochs_validation, loader_cache=loader_cache, + remove_non_optimal_models=remove_non_optimal_models, ) # add callbacks diff --git a/pyproject.toml b/pyproject.toml index 4bebc29..32b744a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ pytorch-lightning = "^1.6.2" tqdm = "^4.64.0" sentencepiece = "^0.1.97" protobuf = "^3.20.1" +nlpaug = "^1.1.11" [tool.poetry.dev-dependencies] pytest = "^5.2" diff --git a/setup.cfg b/setup.cfg index d4a9947..e09ef4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,9 @@ ignore_missing_imports = True [mypy-nltk.*] ignore_missing_imports = True +[mypy-nlpaug.*] +ignore_missing_imports = True + [mypy-flask_cors.*] ignore_missing_imports = True diff --git a/tests/__snapshots__/test_frame_semantic_transformer.ambr b/tests/__snapshots__/test_frame_semantic_transformer.ambr index 1017cbe..2dd44e3 100644 --- a/tests/__snapshots__/test_frame_semantic_transformer.ambr +++ b/tests/__snapshots__/test_frame_semantic_transformer.ambr @@ -1,3 +1,3 @@ # name: test_basic_detect_frames_functionality - DetectFramesResult(sentence="I'm getting quite hungry, but I can wait a bit longer.", trigger_locations=[4, 18, 36], frames=[FrameResult(name='Getting', trigger_location=4, frame_elements=[FrameElementResult(name='Recipient', text='I'), FrameElementResult(name='Theme', text='quite hungry')]), FrameResult(name='Biological_urge', trigger_location=18, frame_elements=[FrameElementResult(name='Experiencer', text='I'), FrameElementResult(name='Degree', text='quite')]), FrameResult(name='Waiting', trigger_location=36, frame_elements=[FrameElementResult(name='Protagonist', text='I'), FrameElementResult(name='Time', text='a bit longer')])]) + DetectFramesResult(sentence="I'm getting quite hungry, but I can wait a bit longer.", trigger_locations=[4, 18, 36], frames=[FrameResult(name='Transition_to_state', trigger_location=4, frame_elements=[FrameElementResult(name='Entity', text='I'), FrameElementResult(name='Final_quality', text='quite hungry')]), FrameResult(name='Biological_urge', trigger_location=18, frame_elements=[FrameElementResult(name='Experiencer', text='I'), FrameElementResult(name='Degree', text='quite')]), FrameResult(name='Waiting', trigger_location=36, frame_elements=[FrameElementResult(name='Protagonist', text='I'), FrameElementResult(name='Duration', text='a bit longer')])]) # --- diff --git a/tests/conftest.py b/tests/conftest.py index 2266edd..b77fba4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,15 +9,20 @@ from frame_semantic_transformer.data.loaders.framenet17.ensure_framenet_downloaded import ( ensure_framenet_downloaded, ) +from frame_semantic_transformer.data.loaders.framenet17.ensure_wordnet_downloaded import ( + ensure_wordnet_downloaded, +) from frame_semantic_transformer.data.loaders.propbank34.ensure_propbank_downloaded import ( ensure_propbank_downloaded, ) +ensure_wordnet_downloaded() ensure_framenet_downloaded() ensure_propbank_downloaded() _loader_cache = LoaderDataCache(Framenet17InferenceLoader()) -_training_loader = Framenet17TrainingLoader() +# exemplars are really slow to load, so skip those for this fixture +_training_loader = Framenet17TrainingLoader(include_exemplars=False) @pytest.fixture diff --git a/tests/data/__snapshots__/test_tasks_from_annotated_sentences.ambr b/tests/data/__snapshots__/test_tasks_from_annotated_sentences.ambr index 91e6c44..af17e08 100644 --- a/tests/data/__snapshots__/test_tasks_from_annotated_sentences.ambr +++ b/tests/data/__snapshots__/test_tasks_from_annotated_sentences.ambr @@ -12,7 +12,7 @@ ), tuple( 'frame_classification', - 'FRAME Becoming Biological_urge Cause_change Cause_to_move_in_place Change_direction Change_operational_state Contingency Moving_in_place Respond_to_proposal Submitting_documents Surrendering Temporal_subregion Turning_out Undergo_change : Because * turning welfare recipients into tax payers just makes sense.', + 'FRAME Becoming Cause_change Cause_to_move_in_place Change_direction Contingency Moving_in_place Temporal_subregion Undergo_change : Because * turning welfare recipients into tax payers just makes sense.', 'Cause_change', ), tuple( @@ -37,7 +37,7 @@ ), tuple( 'frame_classification', - 'FRAME Being_named First_experience Ordinal_numbers : When I * first came to Goodwill I was a single parent with little or no self-esteem.', + 'FRAME Desirability First_experience Ordinal_numbers : When I * first came to Goodwill I was a single parent with little or no self-esteem.', 'First_experience', ), tuple( @@ -47,7 +47,7 @@ ), tuple( 'frame_classification', - 'FRAME : When I first * came to Goodwill I was a single parent with little or no self-esteem.', + 'FRAME Coming_to_be Process_end Waking_up Arriving Motion Relative_time Transition_to_a_situation : When I first * came to Goodwill I was a single parent with little or no self-esteem.', 'Arriving', ), tuple( @@ -57,7 +57,7 @@ ), tuple( 'frame_classification', - 'FRAME Personal_relationship Sole_instance : When I first came to Goodwill I was a * single parent with little or no self-esteem.', + 'FRAME Accompaniment Personal_relationship Sole_instance : When I first came to Goodwill I was a * single parent with little or no self-esteem.', 'Personal_relationship', ), tuple( @@ -67,7 +67,7 @@ ), tuple( 'frame_classification', - 'FRAME Kinship : When I first came to Goodwill I was a single * parent with little or no self-esteem.', + 'FRAME Cutting Kinship People_by_jurisdiction Political_locales : When I first came to Goodwill I was a single * parent with little or no self-esteem.', 'Kinship', ), tuple( @@ -87,7 +87,7 @@ ), tuple( 'frame_classification', - 'FRAME Goal Locative_relation : When I first came * to Goodwill I was a single parent with little or no self-esteem.', + 'FRAME Coming_to_be Process_end Waking_up Goal Locative_relation : When I first came * to Goodwill I was a single parent with little or no self-esteem.', 'Goal', ), tuple( @@ -102,7 +102,7 @@ ), tuple( 'frame_classification', - 'FRAME Documents : I was on welfare and without my * diploma.', + 'FRAME Documents Leadership Social_interaction_evaluation : I was on welfare and without my * diploma.', 'Documents', ), tuple( @@ -117,7 +117,7 @@ ), tuple( 'frame_classification', - 'FRAME Coming_to_be Process_end Waking_up Arriving Motion Relative_time Transition_to_a_situation : * Coming to Goodwill was the first step toward my becoming totally independent.', + 'FRAME Coming_to_be Process_end Waking_up Activity_start Arriving Motion Process_start Relative_time Statement Stimulus_focus Transition_to_a_situation : * Coming to Goodwill was the first step toward my becoming totally independent.', 'Arriving', ), tuple( @@ -137,7 +137,7 @@ ), tuple( 'frame_classification', - 'FRAME Being_named First_experience Ordinal_numbers : Coming to Goodwill was the * first step toward my becoming totally independent.', + 'FRAME Desirability First_experience Ordinal_numbers : Coming to Goodwill was the * first step toward my becoming totally independent.', 'Ordinal_numbers', ), tuple( @@ -147,7 +147,7 @@ ), tuple( 'frame_classification', - 'FRAME Cause_change_of_position_on_a_scale Connecting_architecture Cotheme Intentionally_act Quitting Self_motion : Coming to Goodwill was the first * step toward my becoming totally independent.', + 'FRAME Connecting_architecture Intentionally_act Self_motion : Coming to Goodwill was the first * step toward my becoming totally independent.', 'Intentionally_act', ), tuple( @@ -157,7 +157,7 @@ ), tuple( 'frame_classification', - 'FRAME Becoming Eventive_affecting Suitability : Coming to Goodwill was the first step toward my * becoming totally independent.', + 'FRAME Becoming Suitability : Coming to Goodwill was the first step toward my * becoming totally independent.', 'Becoming', ), tuple( @@ -167,7 +167,7 @@ ), tuple( 'frame_classification', - 'FRAME Adding_up Amounting_to Completeness Degree Render_nonfunctional : Coming to Goodwill was the first step toward my becoming * totally independent.', + 'FRAME Adding_up Amounting_to Bringing Completeness Degree Render_nonfunctional Self_motion : Coming to Goodwill was the first step toward my becoming * totally independent.', 'Degree', ), tuple( @@ -192,7 +192,7 @@ ), tuple( 'frame_classification', - 'FRAME Adding_up Amounting_to Completeness Degree Render_nonfunctional : I am now... * totally off of welfare.', + 'FRAME Adding_up Amounting_to Bringing Completeness Degree Render_nonfunctional Self_motion : I am now... * totally off of welfare.', 'Degree', ), tuple( @@ -217,7 +217,7 @@ ), tuple( 'frame_classification', - "FRAME Desiring Difficulty Experiencer_focus Hedging Likelihood Similarity Sleep Statement : I really * like my job. '' -- Sherry", + "FRAME Experiencer_focus Hedging Likelihood Probability Similarity : I really * like my job. '' -- Sherry", 'Experiencer_focus', ), tuple( @@ -262,7 +262,7 @@ ), tuple( 'frame_classification', - 'FRAME Being_employed Being_operational Collaboration Coming_to_believe Dimension Exercising Experiencer_focus Labor_product Locale_by_use Resolve_problem Usefulness Version_sequence Work Working_a_post : Because people want to * work.', + 'FRAME Being_employed Being_operational Dimension Employing Labor_product Locale_by_use Usefulness Version_sequence Work Working_a_post : Because people want to * work.', 'Being_employed', ), tuple( @@ -287,7 +287,7 @@ ), tuple( 'frame_classification', - "FRAME Activity_done_state Activity_finish Process_completed_state : I'd never * finished high school.", + "FRAME Activity_done_state Activity_finish Body_parts Desirability Domain Emotion_directed Fields Fining Funding Ordinal_numbers Origin Process_completed_state Process_end Time_vector Usefulness Version_sequence : I'd never * finished high school.", 'Activity_finish', ), tuple( @@ -302,7 +302,7 @@ ), tuple( 'frame_classification', - "FRAME : I * had no experience or skills... The only thing I did know for sure was here's a chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", + "FRAME Giving_birth Have_associated Inclusion Ingest_substance Ingestion Possession Sex : I * had no experience or skills... The only thing I did know for sure was here's a chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", 'Possession', ), tuple( @@ -342,7 +342,7 @@ ), tuple( 'frame_classification', - "FRAME Sole_instance : I had no experience or skills... The * only thing I did know for sure was here's a chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", + "FRAME Accuracy Being_in_operation Cardinal_numbers Means Non-gradable_proximity Process_continue Sole_instance Spatial_contact Temporal_collocation Topic : I had no experience or skills... The * only thing I did know for sure was here's a chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", 'Sole_instance', ), tuple( @@ -352,7 +352,7 @@ ), tuple( 'frame_classification', - "FRAME Awareness Certainty Differentiation Familiarity Telling : I had no experience or skills... The only thing I did * know for sure was here's a chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", + "FRAME Awareness Awareness_status Being_named Certainty Differentiation Fame Familiarity : I had no experience or skills... The only thing I did * know for sure was here's a chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", 'Awareness', ), tuple( @@ -362,7 +362,7 @@ ), tuple( 'frame_classification', - "FRAME Alternatives Becoming_aware Coincidence Daring Likelihood Opportunity Probability : I had no experience or skills... The only thing I did know for sure was here's a * chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", + "FRAME Alternatives Coincidence Daring Likelihood Opportunity Probability Communication_manner : I had no experience or skills... The only thing I did know for sure was here's a * chance to change things for me and my children... I rode a bike to Goodwill in the rain and snow.", 'Likelihood', ), tuple( @@ -382,7 +382,7 @@ ), tuple( 'frame_classification', - "FRAME : I had no experience or skills... The only thing I did know for sure was here's a chance to change things for me and my * children... I rode a bike to Goodwill in the rain and snow.", + "FRAME Appellations Kinship People_by_age : I had no experience or skills... The only thing I did know for sure was here's a chance to change things for me and my * children... I rode a bike to Goodwill in the rain and snow.", 'Kinship', ), tuple( @@ -392,7 +392,7 @@ ), tuple( 'frame_classification', - "FRAME : I had no experience or skills... The only thing I did know for sure was here's a chance to change things for me and my children... I * rode a bike to Goodwill in the rain and snow.", + "FRAME Operate_vehicle Ride_vehicle : I had no experience or skills... The only thing I did know for sure was here's a chance to change things for me and my children... I * rode a bike to Goodwill in the rain and snow.", 'Operate_vehicle', ), tuple( @@ -457,7 +457,7 @@ ), tuple( 'frame_classification', - "FRAME Attention_getting Existence Locative_relation Spatial_co-location Success_or_failure : I wanted to be * there... I had my second chance to change my life. '' -- Donna", + "FRAME Locative_relation Spatial_co-location : I wanted to be * there... I had my second chance to change my life. '' -- Donna", 'Locative_relation', ), tuple( @@ -467,7 +467,7 @@ ), tuple( 'frame_classification', - "FRAME : I wanted to be there... I * had my second chance to change my life. '' -- Donna", + "FRAME Giving_birth Have_associated Inclusion Ingest_substance Ingestion Possession Sex : I wanted to be there... I * had my second chance to change my life. '' -- Donna", 'Possession', ), tuple( @@ -497,7 +497,7 @@ ), tuple( 'frame_classification', - "FRAME Alternatives Becoming_aware Coincidence Daring Likelihood Opportunity Probability : I wanted to be there... I had my second * chance to change my life. '' -- Donna", + "FRAME Alternatives Coincidence Daring Likelihood Opportunity Probability Communication_manner : I wanted to be there... I had my second * chance to change my life. '' -- Donna", 'Opportunity', ), tuple( @@ -512,7 +512,7 @@ ), tuple( 'frame_classification', - 'FRAME Education_teaching : Because * teaching a man to fish will keep him fed for his entire life.', + 'FRAME Education_teaching People_by_vocation : Because * teaching a man to fish will keep him fed for his entire life.', 'Education_teaching', ), tuple( @@ -522,7 +522,7 @@ ), tuple( 'frame_classification', - 'FRAME People Working_a_post : Because teaching a * man to fish will keep him fed for his entire life.', + 'FRAME Being_in_control Body_parts Buildings Conduct Grooming Manner Operating_a_system People People_by_vocation Success_or_failure Successful_action Text Working_a_post : Because teaching a * man to fish will keep him fed for his entire life.', 'People', ), tuple( @@ -542,7 +542,7 @@ ), tuple( 'frame_classification', - 'FRAME : Because teaching a man to fish will keep him * fed for his entire life.', + 'FRAME Political_locales Capacity Ingestion : Because teaching a man to fish will keep him * fed for his entire life.', 'Ingestion', ), tuple( @@ -562,7 +562,7 @@ ), tuple( 'frame_classification', - 'FRAME Activity_ongoing Attention Avoiding Cause_to_continue Compliance Preventing_or_letting Retaining Storing : Because teaching a man to fish will * keep him fed for his entire life.', + 'FRAME Activity_ongoing Cause_to_continue Compliance Retaining Storing : Because teaching a man to fish will * keep him fed for his entire life.', 'Cause_to_continue', ), tuple( @@ -587,7 +587,7 @@ ), tuple( 'frame_classification', - 'FRAME Required_event Possession : Before I * got to Goodwill, I was on a mission.', + 'FRAME Required_event Arriving Board_vehicle Bringing Come_down_with Disembarking Getting Giving_birth Grasp Intentional_deception Possession Transition_to_state : Before I * got to Goodwill, I was on a mission.', 'Arriving', ), tuple( @@ -607,7 +607,7 @@ ), tuple( 'frame_classification', - 'FRAME Being_obligated : Before I got to Goodwill, I was on a * mission.', + 'FRAME Appellations Attention_getting Being_obligated Hit_or_miss Success_or_failure Text : Before I got to Goodwill, I was on a * mission.', 'Being_obligated', ), tuple( @@ -707,7 +707,7 @@ ), tuple( 'frame_classification', - "FRAME Possession : I've * got more than a job ; I've got a career.", + "FRAME Arriving Board_vehicle Bringing Come_down_with Disembarking Getting Giving_birth Grasp Intentional_deception Possession Transition_to_state : I've * got more than a job ; I've got a career.", 'Possession', ), tuple( @@ -737,7 +737,7 @@ ), tuple( 'frame_classification', - "FRAME Possession : I've got more than a job ; I've * got a career.", + "FRAME Resolve_problem Arriving Board_vehicle Bringing Come_down_with Disembarking Getting Giving_birth Grasp Intentional_deception Possession Transition_to_state : I've got more than a job ; I've * got a career.", 'Possession', ), tuple( @@ -762,7 +762,7 @@ ), tuple( 'frame_classification', - "FRAME Desiring Difficulty Experiencer_focus Hedging Likelihood Similarity Sleep Statement : My instructor played a role * like no other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", + "FRAME Experiencer_focus Hedging Likelihood Probability Similarity : My instructor played a role * like no other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", 'Similarity', ), tuple( @@ -772,7 +772,7 @@ ), tuple( 'frame_classification', - "FRAME Increment Personal_relationship : My instructor played a role like no * other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", + "FRAME Increment : My instructor played a role like no * other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", 'Increment', ), tuple( @@ -782,7 +782,7 @@ ), tuple( 'frame_classification', - "FRAME : My instructor played a role like no other instructor I've ever * had I appreciate everything that Goodwill has done for me. '' -- Cornell", + "FRAME Giving_birth Have_associated Inclusion Ingest_substance Ingestion Possession Sex : My instructor played a role like no other instructor I've ever * had I appreciate everything that Goodwill has done for me. '' -- Cornell", 'Possession', ), tuple( @@ -802,7 +802,7 @@ ), tuple( 'frame_classification', - "FRAME Activity_done_state Process_completed_state : My instructor played a role like no other instructor I've ever had I appreciate everything that Goodwill has * done for me. '' -- Cornell", + "FRAME Activity_done_state Ingest_substance Intentionally_act Intentionally_affect Process_completed_state Sex Thriving Touring Dressing Giving : My instructor played a role like no other instructor I've ever had I appreciate everything that Goodwill has * done for me. '' -- Cornell", 'Intentionally_affect', ), tuple( @@ -812,7 +812,7 @@ ), tuple( 'frame_classification', - "FRAME Performers_and_roles : My instructor played a * role like no other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", + "FRAME Body_movement Cause_motion Cause_to_move_in_place Food Mass_motion Motion Moving_in_place Performers_and_roles Reshaping Sound_movement : My instructor played a * role like no other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", 'Performers_and_roles', ), tuple( @@ -822,7 +822,7 @@ ), tuple( 'frame_classification', - "FRAME Being_relevant Cause_to_make_noise Competition Compliance Difficulty Make_noise Performers_and_roles Performing_arts : My instructor * played a role like no other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", + "FRAME Cause_to_make_noise Competition Gizmo Make_noise Participation Performers Performers_and_roles Performing_arts : My instructor * played a role like no other instructor I've ever had I appreciate everything that Goodwill has done for me. '' -- Cornell", 'Performers_and_roles', ), tuple( @@ -837,7 +837,7 @@ ), tuple( 'frame_classification', - 'FRAME Calendric_unit Measure_duration : Each * year, we help thousands of people who face tremendous obstacles.', + 'FRAME Calendric_unit Frequency Measure_duration : Each * year, we help thousands of people who face tremendous obstacles.', 'Calendric_unit', ), tuple( @@ -847,7 +847,7 @@ ), tuple( 'frame_classification', - 'FRAME Assistance Self_control : Each year, we * help thousands of people who face tremendous obstacles.', + 'FRAME Assistance : Each year, we * help thousands of people who face tremendous obstacles.', 'Assistance', ), tuple( @@ -887,7 +887,7 @@ ), tuple( 'frame_classification', - 'FRAME Body_parts Compatibility Confronting_problem Contrary_circumstances Facial_expression Part_orientational : Each year, we help thousands of people who * face tremendous obstacles.', + 'FRAME Body_parts Confronting_problem Facial_expression Grooming Part_orientational : Each year, we help thousands of people who * face tremendous obstacles.', 'Confronting_problem', ), tuple( @@ -902,7 +902,7 @@ ), tuple( 'frame_classification', - 'FRAME Cardinal_numbers : Their * one common goal: they all want to work.', + 'FRAME Accuracy Being_in_operation Cardinal_numbers Means Non-gradable_proximity Process_continue Sole_instance Spatial_contact Temporal_collocation Topic : Their * one common goal: they all want to work.', 'Cardinal_numbers', ), tuple( @@ -932,7 +932,7 @@ ), tuple( 'frame_classification', - 'FRAME Being_employed Being_operational Collaboration Coming_to_believe Dimension Exercising Experiencer_focus Labor_product Locale_by_use Resolve_problem Usefulness Version_sequence Work Working_a_post : Their one common goal: they all want to * work.', + 'FRAME Being_employed Being_operational Dimension Employing Labor_product Locale_by_use Usefulness Version_sequence Work Working_a_post : Their one common goal: they all want to * work.', 'Being_employed', ), tuple( @@ -947,7 +947,7 @@ ), tuple( 'frame_classification', - "FRAME Economy Frugality : A robust * economy helps by providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who face the greatest obstacles.", + "FRAME Domain Economy Frugality : A robust * economy helps by providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who face the greatest obstacles.", 'Economy', ), tuple( @@ -957,7 +957,7 @@ ), tuple( 'frame_classification', - "FRAME Assistance Self_control : A robust economy * helps by providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who face the greatest obstacles.", + "FRAME Assistance : A robust economy * helps by providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who face the greatest obstacles.", 'Assistance', ), tuple( @@ -967,7 +967,7 @@ ), tuple( 'frame_classification', - "FRAME Conditional_occurrence Supply : A robust economy helps by * providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who face the greatest obstacles.", + "FRAME Conditional_occurrence Supply Terms_of_agreement : A robust economy helps by * providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who face the greatest obstacles.", 'Supply', ), tuple( @@ -1017,7 +1017,7 @@ ), tuple( 'frame_classification', - "FRAME Being_employed Being_operational Collaboration Coming_to_believe Dimension Exercising Experiencer_focus Labor_product Locale_by_use Resolve_problem Usefulness Version_sequence Work Working_a_post : A robust economy helps by providing job opportunities, but to be honest, most of the people who aren't * working today are quite simply the ones who face the greatest obstacles.", + "FRAME Being_employed Being_operational Dimension Employing Labor_product Locale_by_use Usefulness Version_sequence Work Working_a_post : A robust economy helps by providing job opportunities, but to be honest, most of the people who aren't * working today are quite simply the ones who face the greatest obstacles.", 'Being_employed', ), tuple( @@ -1027,7 +1027,7 @@ ), tuple( 'frame_classification', - "FRAME Body_parts Compatibility Confronting_problem Contrary_circumstances Facial_expression Part_orientational : A robust economy helps by providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who * face the greatest obstacles.", + "FRAME Body_parts Confronting_problem Facial_expression Grooming Part_orientational : A robust economy helps by providing job opportunities, but to be honest, most of the people who aren't working today are quite simply the ones who * face the greatest obstacles.", 'Confronting_problem', ), tuple( @@ -1052,7 +1052,7 @@ ), tuple( 'frame_classification', - 'FRAME Member_of_military Public_services Rite : The kinds of * services we provide help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', + 'FRAME Assistance Being_incarcerated Capacity Function Offering People_by_vocation Public_services Rite Serving_in_capacity Sports_jargon Subordinates_and_superiors Successful_action Sufficiency : The kinds of * services we provide help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', 'Public_services', ), tuple( @@ -1062,7 +1062,7 @@ ), tuple( 'frame_classification', - 'FRAME Conditional_occurrence Supply : The kinds of services we * provide help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', + 'FRAME Conditional_occurrence Supply Terms_of_agreement : The kinds of services we * provide help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', 'Supply', ), tuple( @@ -1072,7 +1072,7 @@ ), tuple( 'frame_classification', - 'FRAME Assistance Self_control : The kinds of services we provide * help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', + 'FRAME Assistance : The kinds of services we provide * help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', 'Assistance', ), tuple( @@ -1102,7 +1102,7 @@ ), tuple( 'frame_classification', - 'FRAME Appellations Difficulty Kinship Offenses People_by_age : The kinds of services we provide help people deal with obstacles like health care, transportation and * child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', + 'FRAME Appellations Kinship People_by_age : The kinds of services we provide help people deal with obstacles like health care, transportation and * child care - problems that are big enough on their own without being compounded by factors like physical and mental disabilities, illiteracy and lack of job skills.', 'People_by_age', ), tuple( @@ -1142,7 +1142,7 @@ ), tuple( 'frame_classification', - 'FRAME Contingency : The kinds of services we provide help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by * factors like physical and mental disabilities, illiteracy and lack of job skills.', + 'FRAME Contingency Locale_by_use Offshoot : The kinds of services we provide help people deal with obstacles like health care, transportation and child care - problems that are big enough on their own without being compounded by * factors like physical and mental disabilities, illiteracy and lack of job skills.', 'Contingency', ), tuple( @@ -1187,7 +1187,7 @@ ), tuple( 'frame_classification', - 'FRAME Being_named Duration_description Duration_relation Ordinal_numbers Relative_time : * Last year, Goodwill helped 3,300 people find jobs that increased their self - sufficiency.', + 'FRAME Duration_description Duration_relation Ordinal_numbers Relative_time : * Last year, Goodwill helped 3,300 people find jobs that increased their self - sufficiency.', 'Relative_time', ), tuple( @@ -1197,7 +1197,7 @@ ), tuple( 'frame_classification', - 'FRAME Calendric_unit Measure_duration : Last * year, Goodwill helped 3,300 people find jobs that increased their self - sufficiency.', + 'FRAME Calendric_unit Frequency Measure_duration : Last * year, Goodwill helped 3,300 people find jobs that increased their self - sufficiency.', 'Calendric_unit', ), tuple( @@ -1207,7 +1207,7 @@ ), tuple( 'frame_classification', - 'FRAME Assistance Self_control : Last year, Goodwill * helped 3,300 people find jobs that increased their self - sufficiency.', + 'FRAME Assistance : Last year, Goodwill * helped 3,300 people find jobs that increased their self - sufficiency.', 'Assistance', ), tuple( @@ -1247,7 +1247,7 @@ ), tuple( 'frame_classification', - 'FRAME Achieving_first Arriving Becoming_aware Being_located Coming_to_believe Coming_up_with Documents Locating Regard Verdict : Last year, Goodwill helped 3,300 people * find jobs that increased their self - sufficiency.', + 'FRAME Achieving_first Arriving Becoming_aware Being_located Circumscribed_existence Coming_to_believe Coming_up_with Documents Intentionally_create Locating Regard Verdict : Last year, Goodwill helped 3,300 people * find jobs that increased their self - sufficiency.', 'Locating', ), tuple( @@ -1282,7 +1282,7 @@ ), tuple( 'frame_classification', - 'FRAME Member_of_military Public_services Rite : Your gift to Goodwill will be used directly to support * services that will help even more find jobs.', + 'FRAME Assistance Being_incarcerated Capacity Function Offering People_by_vocation Public_services Rite Serving_in_capacity Sports_jargon Subordinates_and_superiors Successful_action Sufficiency : Your gift to Goodwill will be used directly to support * services that will help even more find jobs.', 'Public_services', ), tuple( @@ -1292,7 +1292,7 @@ ), tuple( 'frame_classification', - 'FRAME Assistance Self_control : Your gift to Goodwill will be used directly to support services that will * help even more find jobs.', + 'FRAME Assistance : Your gift to Goodwill will be used directly to support services that will * help even more find jobs.', 'Assistance', ), tuple( @@ -1327,7 +1327,7 @@ ), tuple( 'frame_classification', - 'FRAME Attention_getting Existence Locative_relation Spatial_co-location Success_or_failure : Right now, * there are thousands of people who do not know what it feels like to support themselves.', + 'FRAME Existence Locative_relation Spatial_co-location : Right now, * there are thousands of people who do not know what it feels like to support themselves.', 'Existence', ), tuple( @@ -1337,7 +1337,7 @@ ), tuple( 'frame_classification', - 'FRAME : Right now, there * are thousands of people who do not know what it feels like to support themselves.', + 'FRAME Existence Arriving Biological_urge Cause_to_start Coming_to_be Experiencer_obj Personal_success Dimension Fields Locale Performers_and_roles : Right now, there * are thousands of people who do not know what it feels like to support themselves.', 'Existence', ), tuple( @@ -1377,7 +1377,7 @@ ), tuple( 'frame_classification', - 'FRAME Awareness Certainty Differentiation Familiarity Telling : Right now, there are thousands of people who do not * know what it feels like to support themselves.', + 'FRAME Awareness Awareness_status Being_named Certainty Differentiation Fame Familiarity : Right now, there are thousands of people who do not * know what it feels like to support themselves.', 'Awareness', ), tuple( @@ -1387,7 +1387,7 @@ ), tuple( 'frame_classification', - 'FRAME Desiring Feeling Give_impression Opinion Others_situation_as_stimulus Perception_active Perception_experience Seeking Sensation : Right now, there are thousands of people who do not know what it * feels like to support themselves.', + 'FRAME Desiring Body_parts Feeling Give_impression Opinion Perception_active Perception_experience Seeking Sensation : Right now, there are thousands of people who do not know what it * feels like to support themselves.', 'Feeling', ), tuple( @@ -1412,7 +1412,7 @@ ), tuple( 'frame_classification', - 'FRAME Capability Containers Firing Likelihood Measure_volume Possibility Preserving : You * can help them to know that feeling.', + 'FRAME Capability Cause_harm Containers Corporal_punishment Firing Likelihood Locale_by_use Measure_volume Possibility Preserving : You * can help them to know that feeling.', 'Capability', ), tuple( @@ -1422,7 +1422,7 @@ ), tuple( 'frame_classification', - 'FRAME Assistance Self_control : You can * help them to know that feeling.', + 'FRAME Assistance : You can * help them to know that feeling.', 'Assistance', ), tuple( @@ -1432,7 +1432,7 @@ ), tuple( 'frame_classification', - 'FRAME Awareness Certainty Differentiation Familiarity Telling : You can help them to * know that feeling.', + 'FRAME Awareness Awareness_status Being_named Certainty Differentiation Fame Familiarity : You can help them to * know that feeling.', 'Awareness', ), tuple( @@ -1447,7 +1447,7 @@ ), tuple( 'frame_classification', - 'FRAME Assistance Self_control : Please * help Goodwill and help people find jobs.', + 'FRAME Assistance : Please * help Goodwill and help people find jobs.', 'Assistance', ), tuple( @@ -1467,7 +1467,7 @@ ), tuple( 'frame_classification', - 'FRAME Achieving_first Arriving Becoming_aware Being_located Coming_to_believe Coming_up_with Documents Locating Regard Verdict : Please help Goodwill and help people * find jobs.', + 'FRAME Achieving_first Arriving Becoming_aware Being_located Circumscribed_existence Coming_to_believe Coming_up_with Documents Intentionally_create Locating Regard Verdict : Please help Goodwill and help people * find jobs.', 'Locating', ), tuple( @@ -1487,7 +1487,7 @@ ), tuple( 'frame_classification', - 'FRAME Assistance Self_control : Please help Goodwill and * help people find jobs.', + 'FRAME Assistance : Please help Goodwill and * help people find jobs.', 'Assistance', ), tuple( diff --git a/tests/data/augmentations/modification_helpers/test_modify_text_without_changing_length.py b/tests/data/augmentations/modification_helpers/test_modify_text_without_changing_length.py new file mode 100644 index 0000000..15fbadc --- /dev/null +++ b/tests/data/augmentations/modification_helpers/test_modify_text_without_changing_length.py @@ -0,0 +1,115 @@ +from __future__ import annotations +from typing import cast +from unittest.mock import MagicMock + +import pytest + +from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + modify_text_without_changing_length, +) +from frame_semantic_transformer.data.frame_types import FrameElementAnnotation +from frame_semantic_transformer.data.tasks import ( + ArgumentsExtractionSample, + ArgumentsExtractionTask, + FrameClassificationSample, + FrameClassificationTask, + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_arg_extraction_sample( + sentence: str, loader_cache: LoaderDataCache +) -> ArgumentsExtractionSample: + return ArgumentsExtractionSample( + task=ArgumentsExtractionTask( + text=sentence, + trigger_loc=15, + frame="blah", + loader_cache=loader_cache, + ), + frame_elements=[ + FrameElementAnnotation( + name="The_elm", + start_loc=4, + end_loc=9, + ) + ], + ) + + +def create_frame_classification_sample( + sentence: str, loader_cache: LoaderDataCache +) -> FrameClassificationSample: + return FrameClassificationSample( + task=FrameClassificationTask( + text=sentence, + trigger_loc=15, + loader_cache=loader_cache, + ), + frame="blah", + ) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[15], + ) + + +def test_modify_text_without_changing_length_for_arg_extraction_sample( + loader_cache: LoaderDataCache, +) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + new_sentence = "The quick BROWN fox jumps over the LAZY MAN" + sample = create_arg_extraction_sample(sentence, loader_cache) + callback = MagicMock(return_value=new_sentence) + new_sample = cast( + ArgumentsExtractionSample, modify_text_without_changing_length(sample, callback) + ) + assert new_sample.task.text == new_sentence + assert new_sample.task.frame == sample.task.frame + assert new_sample.task.trigger_loc == sample.task.trigger_loc + assert new_sample.frame_elements == sample.frame_elements + callback.assert_called_with(sentence) + + +def test_throw_error_if_sentence_length_changes(loader_cache: LoaderDataCache) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + new_sentence = "The LOOOOOOOOOONG quick BROWN fox jumps over the LAZY MAN" + sample = create_arg_extraction_sample(sentence, loader_cache) + callback = MagicMock(return_value=new_sentence) + with pytest.raises(ValueError): + modify_text_without_changing_length(sample, callback) + + +def test_modify_text_without_changing_length_for_frame_classification_sample( + loader_cache: LoaderDataCache, +) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + new_sentence = "The quick BROWN fox jumps over the LAZY MAN" + sample = create_frame_classification_sample(sentence, loader_cache) + callback = MagicMock(return_value=new_sentence) + new_sample = cast( + FrameClassificationSample, modify_text_without_changing_length(sample, callback) + ) + assert new_sample.task.text == new_sentence + assert new_sample.frame == sample.frame + assert new_sample.task.trigger_loc == sample.task.trigger_loc + callback.assert_called_with(sentence) + + +def test_modify_text_without_changing_length_for_trigger_identification_sample() -> None: + sentence = "The quick brown fox jumps over the lazy dog" + new_sentence = "The quick BROWN fox jumps over the LAZY MAN" + sample = create_trigger_identification_sample(sentence) + callback = MagicMock(return_value=new_sentence) + new_sample = cast( + TriggerIdentificationSample, + modify_text_without_changing_length(sample, callback), + ) + assert new_sample.task.text == new_sentence + assert new_sample.trigger_locs == sample.trigger_locs + callback.assert_called_with(sentence) diff --git a/tests/data/augmentations/modification_helpers/test_splice_text.py b/tests/data/augmentations/modification_helpers/test_splice_text.py new file mode 100644 index 0000000..b1fbfb6 --- /dev/null +++ b/tests/data/augmentations/modification_helpers/test_splice_text.py @@ -0,0 +1,151 @@ +from __future__ import annotations +from typing import cast +from unittest.mock import MagicMock + +import pytest + +from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache +from frame_semantic_transformer.data.augmentations.modification_helpers import ( + splice_text, +) +from frame_semantic_transformer.data.frame_types import FrameElementAnnotation +from frame_semantic_transformer.data.tasks import ( + ArgumentsExtractionSample, + ArgumentsExtractionTask, + FrameClassificationSample, + FrameClassificationTask, + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_arg_extraction_sample( + sentence: str, loader_cache: LoaderDataCache +) -> ArgumentsExtractionSample: + return ArgumentsExtractionSample( + task=ArgumentsExtractionTask( + text=sentence, + trigger_loc=16, + frame="blah", + loader_cache=loader_cache, + ), + frame_elements=[ + FrameElementAnnotation( + name="The_elm", + start_loc=4, + end_loc=9, + ) + ], + ) + + +def create_frame_classification_sample( + sentence: str, loader_cache: LoaderDataCache +) -> FrameClassificationSample: + return FrameClassificationSample( + task=FrameClassificationTask( + text=sentence, + trigger_loc=16, + loader_cache=loader_cache, + ), + frame="blah", + ) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[16], + ) + + +def test_splice_text_splices_the_text_into_the_sentence( + loader_cache: LoaderDataCache, +) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + sample = create_arg_extraction_sample(sentence, loader_cache) + callback = MagicMock(return_value=(20, 10, "EATS")) + new_sample = cast(ArgumentsExtractionSample, splice_text(sample, callback)) + assert new_sample.task.text == "The quick brown fox EATS the lazy dog" + assert new_sample.task.frame == sample.task.frame + assert new_sample.task.trigger_loc == sample.task.trigger_loc + assert new_sample.frame_elements == sample.frame_elements + callback.assert_called_with(sentence, [16, 4, 9]) + + +def test_splice_text_leaves_sentence_unchanged_if_callback_returns_none( + loader_cache: LoaderDataCache, +) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + sample = create_arg_extraction_sample(sentence, loader_cache) + callback = MagicMock(return_value=None) + new_sample = cast(ArgumentsExtractionSample, splice_text(sample, callback)) + assert new_sample.task.text == sentence + assert new_sample.task.frame == sample.task.frame + assert new_sample.task.trigger_loc == sample.task.trigger_loc + assert new_sample.frame_elements == sample.frame_elements + callback.assert_called_with(sentence, [16, 4, 9]) + + +def test_splice_text_modified_indices_after_the_changes( + loader_cache: LoaderDataCache, +) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + sample = create_arg_extraction_sample(sentence, loader_cache) + # replace 'The' with 'The nonexistant' + callback = MagicMock(return_value=(0, 3, "The nonexistant")) + new_sample = cast(ArgumentsExtractionSample, splice_text(sample, callback)) + assert ( + new_sample.task.text + == "The nonexistant quick brown fox jumps over the lazy dog" + ) + assert new_sample.task.frame == sample.task.frame + assert new_sample.task.trigger_loc == sample.task.trigger_loc + 12 + assert len(new_sample.frame_elements) == 1 + assert new_sample.frame_elements[0].start_loc == 16 + assert new_sample.frame_elements[0].end_loc == 21 + callback.assert_called_with(sentence, [16, 4, 9]) + + +def test_throw_error_if_splice_text_delete_a_critical_index( + loader_cache: LoaderDataCache, +) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + sample = create_arg_extraction_sample(sentence, loader_cache) + callback = MagicMock(return_value=(4, 2, "")) + with pytest.raises(ValueError): + splice_text(sample, callback) + + +def test_splice_text_for_frame_classification_sample( + loader_cache: LoaderDataCache, +) -> None: + sentence = "The quick brown fox jumps over the lazy dog" + sample = create_frame_classification_sample(sentence, loader_cache) + # replace 'The' with 'The nonexistant' + callback = MagicMock(return_value=(0, 3, "The nonexistant")) + new_sample = cast(FrameClassificationSample, splice_text(sample, callback)) + assert ( + new_sample.task.text + == "The nonexistant quick brown fox jumps over the lazy dog" + ) + assert new_sample.frame == sample.frame + assert new_sample.task.trigger_loc == sample.task.trigger_loc + 12 + callback.assert_called_with(sentence, [16]) + + +def test_splice_text_for_trigger_identification_sample() -> None: + sentence = "The quick brown fox jumps over the lazy dog" + sample = create_trigger_identification_sample(sentence) + # replace 'The' with 'The nonexistant' + callback = MagicMock(return_value=(0, 3, "The nonexistant")) + new_sample = cast( + TriggerIdentificationSample, + splice_text(sample, callback), + ) + assert ( + new_sample.task.text + == "The nonexistant quick brown fox jumps over the lazy dog" + ) + assert new_sample.trigger_locs == [28] + callback.assert_called_with(sentence, [16]) diff --git a/tests/data/augmentations/test_DoubleQuotesAugmentation.py b/tests/data/augmentations/test_DoubleQuotesAugmentation.py new file mode 100644 index 0000000..62bbe42 --- /dev/null +++ b/tests/data/augmentations/test_DoubleQuotesAugmentation.py @@ -0,0 +1,46 @@ +from __future__ import annotations +from typing import cast +import pytest + +from frame_semantic_transformer.data.augmentations import DoubleQuotesAugmentation +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[0], + ) + + +@pytest.mark.parametrize( + "input,expected", + [ + ("I am a ``banana'' .", 'I am a "banana" .'), + ("she ``says'' ``hi''", 'she "says" "hi"'), + ], +) +def test_DoubleQuotesAugmentation_changes_latex_quotes_to_standard_quotes( + input: str, expected: str +) -> None: + augmentation = DoubleQuotesAugmentation(1.0) + sample = create_trigger_identification_sample(input) + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == expected + + +def test_DoubleQuotesAugmentation_changes_standard_quotes_to_latex_quotes() -> None: + augmentation = DoubleQuotesAugmentation(1.0) + sample = create_trigger_identification_sample('This is a quote: " .') + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text in {"This is a quote: '' .", "This is a quote: `` ."} + + +def test_DoubleQuotesAugmentation_leaves_samples_unchanged_if_no_quotes_are_present() -> None: + augmentation = DoubleQuotesAugmentation(1.0) + sample = create_trigger_identification_sample("Nothing to see here !") + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == sample.task.text diff --git a/tests/data/augmentations/test_KeyboardAugmentation.py b/tests/data/augmentations/test_KeyboardAugmentation.py new file mode 100644 index 0000000..4e1cc83 --- /dev/null +++ b/tests/data/augmentations/test_KeyboardAugmentation.py @@ -0,0 +1,48 @@ +from __future__ import annotations +from typing import cast + +from frame_semantic_transformer.data.augmentations import KeyboardAugmentation +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[16], + ) + + +def test_KeyboardAugmentation() -> None: + sentence = "I like , to , eat food 1234 and I like in a boat ." + augmentation = KeyboardAugmentation(1.0) + sample = create_trigger_identification_sample(sentence) + + is_same = True + # do this 20 times since it's not guaranteed to change anything every time + for _ in range(20): + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + new_sentence = new_sample.task.text + if new_sentence != sentence: + is_same = False + assert len(new_sentence) == len(sentence) + assert len(new_sentence.split()) == len(sentence.split()) + assert not is_same + + +def test_KeyboardAugmentation_with_complex_sentence() -> None: + sentence = "Totally absorbed , the ringers stare straight ahead , using peripheral vision ( they call it `` rope-sight '' ) to watch the other ropes and thus time their pulls ." + augmentation = KeyboardAugmentation(1.0) + sample = create_trigger_identification_sample(sentence) + + is_same = True + # do this 20 times since it's not guaranteed to change anything every time + for _ in range(20): + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + new_sentence = new_sample.task.text + if new_sentence != sentence: + is_same = False + assert len(new_sentence) == len(sentence) + assert not is_same diff --git a/tests/data/augmentations/test_LowercaseAugmentation.py b/tests/data/augmentations/test_LowercaseAugmentation.py index eb13782..dc3848d 100644 --- a/tests/data/augmentations/test_LowercaseAugmentation.py +++ b/tests/data/augmentations/test_LowercaseAugmentation.py @@ -1,32 +1,30 @@ from __future__ import annotations +from typing import cast import pytest from frame_semantic_transformer.data.augmentations import LowercaseAugmentation +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[16], + ) @pytest.mark.parametrize( "input,expected", [ - ( - ("TASK: I am a banana.", "I am a banana."), - ("TASK: i am a banana.", "i am a banana."), - ), - ( - ("TASK: I AM A BANANA !", "I AM A BANANA !"), - ("TASK: i am a banana !", "i am a banana !"), - ), - ( - ("TASK | Param1 | Param 2 : I AM A BANANA !", "I AM A BANANA !"), - ("TASK | Param1 | Param 2 : i am a banana !", "i am a banana !"), - ), - ( - ("TASK: Ch 1: I AM A BANANA !", "Ch 1: I AM A BANANA !"), - ("TASK: ch 1: i am a banana !", "ch 1: i am a banana !"), - ), + ("I am a banana.", "i am a banana."), + ("I AM A BANANA !", "i am a banana !"), ], ) -def test_LowercaseAugmentation( - input: tuple[str, str], expected: tuple[str, str] -) -> None: +def test_LowercaseAugmentation(input: str, expected: str) -> None: augmentation = LowercaseAugmentation(1.0) - assert augmentation(*input) == expected + sample = create_trigger_identification_sample(input) + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == expected diff --git a/tests/data/augmentations/test_RemoveContractionsAugmentation.py b/tests/data/augmentations/test_RemoveContractionsAugmentation.py deleted file mode 100644 index a8419cd..0000000 --- a/tests/data/augmentations/test_RemoveContractionsAugmentation.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations -import pytest - -from frame_semantic_transformer.data.augmentations import RemoveContractionsAugmentation - - -@pytest.mark.parametrize( - "input,expected", - [ - ( - ("TASK: I can't go I won't go", "I can't go I won't go"), - ("TASK: I cannot go I will not go", "I cannot go I will not go"), - ), - ( - ( - "TASK: shouldn't couldn't they're we'll they've", - "shouldn't couldn't they're we'll they've", - ), - ( - "TASK: should not could not they are we will they have", - "should not could not they are we will they have", - ), - ), - ( - ("TASK | Param1 | Param 2 : We're didn*'t", "We're didn't"), - ("TASK | Param1 | Param 2 : We're didn*'t", "We're didn't"), - ), - ], -) -def test_RemoveContractionsAugmentation( - input: tuple[str, str], expected: tuple[str, str] -) -> None: - augmentation = RemoveContractionsAugmentation(1.0) - assert augmentation(*input) == expected diff --git a/tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py b/tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py index 4b03286..950687b 100644 --- a/tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py +++ b/tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py @@ -1,30 +1,33 @@ from __future__ import annotations +from typing import cast import pytest from frame_semantic_transformer.data.augmentations import ( RemoveEndPunctuationAugmentation, ) +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[2], + ) @pytest.mark.parametrize( "input,expected", [ - ( - ("TASK: I am a banana.", "I am a banana."), - ("TASK: I am a banana", "I am a banana"), - ), - ( - ("TASK: I am a banana!", "I am a banana!"), - ("TASK: I am a banana", "I am a banana"), - ), - ( - ("TASK: I am a banana .", "I am a banana ."), - ("TASK: I am a banana", "I am a banana"), - ), + ("I am a banana.", "I am a banana"), + ("I am a banana!", "I am a banana"), + ("I ! am a banana .", "I ! am a banana"), ], ) -def test_RemoveEndPunctuationAugmentation( - input: tuple[str, str], expected: tuple[str, str] -) -> None: +def test_RemoveEndPunctuationAugmentation(input: str, expected: str) -> None: augmentation = RemoveEndPunctuationAugmentation(1.0) - assert augmentation(*input) == expected + sample = create_trigger_identification_sample(input) + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == expected diff --git a/tests/data/augmentations/test_SimpleMisspellingAugmentation.py b/tests/data/augmentations/test_SimpleMisspellingAugmentation.py new file mode 100644 index 0000000..dac71d3 --- /dev/null +++ b/tests/data/augmentations/test_SimpleMisspellingAugmentation.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from typing import cast + +from frame_semantic_transformer.data.augmentations import SimpleMisspellingAugmentation +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[16], + ) + + +def test_SimpleMisspellingAugmentation() -> None: + sentence = "I like to eat food 1234" + # just to make it almost certain something will be changed + augmentation = SimpleMisspellingAugmentation( + 1.0, min_misspellings_per_sentence=20, max_misspellings_per_sentence=20 + ) + sample = create_trigger_identification_sample(sentence) + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + new_sentence = new_sample.task.text + assert new_sentence != sentence + assert len(new_sentence) == len(sentence) + assert len(new_sentence.split()) == len(sentence.split()) diff --git a/tests/data/augmentations/test_StripPunctuationAugmentation.py b/tests/data/augmentations/test_StripPunctuationAugmentation.py new file mode 100644 index 0000000..19b3bc0 --- /dev/null +++ b/tests/data/augmentations/test_StripPunctuationAugmentation.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from typing import cast + +from frame_semantic_transformer.data.augmentations import StripPunctuationAugmentation +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[0], + ) + + +def test_StripPunctuationAugmentation_removes_punctuation() -> None: + augmentation = StripPunctuationAugmentation(1.0, min_to_remove=5, max_to_remove=5) + sample = create_trigger_identification_sample("This! is? A! sentence.") + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == "This is A sentence" + + +def test_StripPunctuationAugmentation_removes_up_to_max_to_remove() -> None: + augmentation = StripPunctuationAugmentation(1.0, min_to_remove=5, max_to_remove=5) + sample = create_trigger_identification_sample("This! is! A! sentence !!! !") + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text.count("!") == 2 diff --git a/tests/data/augmentations/test_SynonymAugmenter.py b/tests/data/augmentations/test_SynonymAugmenter.py new file mode 100644 index 0000000..2bfa30f --- /dev/null +++ b/tests/data/augmentations/test_SynonymAugmenter.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from typing import cast + +from frame_semantic_transformer.data.augmentations import SynonymAugmentation +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[16], + ) + + +def test_SynonymAugmentation() -> None: + sentence = "I like to eat food 1234 and I like in a boat ." + # just to make it almost certain something will be changed + augmentation = SynonymAugmentation(1.0) + sample = create_trigger_identification_sample(sentence) + + is_same = True + # do this 20 times since it's not guaranteed to change anything every time + for _ in range(20): + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + new_sentence = new_sample.task.text + if new_sentence != sentence: + is_same = False + assert not is_same diff --git a/tests/data/augmentations/test_UppercaseAugmentation.py b/tests/data/augmentations/test_UppercaseAugmentation.py new file mode 100644 index 0000000..5e0f81a --- /dev/null +++ b/tests/data/augmentations/test_UppercaseAugmentation.py @@ -0,0 +1,38 @@ +from __future__ import annotations +from typing import cast +import pytest + +from frame_semantic_transformer.data.augmentations import UppercaseAugmentation +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[16], + ) + + +@pytest.mark.parametrize( + "input,expected", + [ + ("I am a banana.", "I AM A BANANA."), + ("I AM A banana !", "I AM A BANANA !"), + ], +) +def test_UppercaseAugmentation(input: str, expected: str) -> None: + augmentation = UppercaseAugmentation(1.0) + sample = create_trigger_identification_sample(input) + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == expected + + +def test_UppercaseAugmentation_returns_original_sentence_if_contains_ligature() -> None: + sentence = "this is one character: fi ." + augmentation = UppercaseAugmentation(1.0) + sample = create_trigger_identification_sample(sentence) + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == sentence diff --git a/tests/data/augmentations/test_chain_augmentations.py b/tests/data/augmentations/test_chain_augmentations.py index f88fedf..0ba8df2 100644 --- a/tests/data/augmentations/test_chain_augmentations.py +++ b/tests/data/augmentations/test_chain_augmentations.py @@ -1,25 +1,31 @@ from __future__ import annotations +from typing import cast from frame_semantic_transformer.data.augmentations import ( chain_augmentations, LowercaseAugmentation, - RemoveContractionsAugmentation, RemoveEndPunctuationAugmentation, ) +from frame_semantic_transformer.data.tasks import ( + TriggerIdentificationSample, + TriggerIdentificationTask, +) + + +def create_trigger_identification_sample(sentence: str) -> TriggerIdentificationSample: + return TriggerIdentificationSample( + task=TriggerIdentificationTask(text=sentence), + trigger_locs=[16], + ) def test_chain_augmentations_applys_all_augmentations() -> None: augmentation = chain_augmentations( [ LowercaseAugmentation(1.0), - RemoveContractionsAugmentation(1.0), RemoveEndPunctuationAugmentation(1.0), ] ) - - input = "TASK: I don't like BANANAS!" - target = "I don't like BANANAS!" - assert augmentation(input, target) == ( - "TASK: i do not like bananas", - "i do not like bananas", - ) + sample = create_trigger_identification_sample("I don't like BANANAS!") + new_sample = cast(TriggerIdentificationSample, augmentation(sample)) + assert new_sample.task.text == "i don't like bananas" diff --git a/tests/data/loaders/framenet17/test_Framenet17TrainingLoader.py b/tests/data/loaders/framenet17/test_Framenet17TrainingLoader.py index e4ede68..9c9d2c8 100644 --- a/tests/data/loaders/framenet17/test_Framenet17TrainingLoader.py +++ b/tests/data/loaders/framenet17/test_Framenet17TrainingLoader.py @@ -16,13 +16,42 @@ def test_load_sesame_test_samples() -> None: sentences = training_loader.load_test_data() samples = tasks_from_annotated_sentences(sentences, loader_cache) + trigger_id_samples = [ + sample + for sample in samples + if sample.get_task_name() == "trigger_identification" + ] assert len(samples) == 15126 + assert len(trigger_id_samples) == 1354 def test_load_sesame_dev_samples() -> None: sentences = training_loader.load_validation_data() samples = tasks_from_annotated_sentences(sentences, loader_cache) + trigger_id_samples = [ + sample + for sample in samples + if sample.get_task_name() == "trigger_identification" + ] assert len(samples) == 5166 + assert len(trigger_id_samples) == 328 + + +def test_load_sesame_train_samples_with_exemplars() -> None: + training_loader_with_exemplars = Framenet17TrainingLoader(include_exemplars=True) + sentences = training_loader_with_exemplars.load_training_data() + samples = tasks_from_annotated_sentences(sentences, loader_cache) + trigger_id_samples = [ + sample + for sample in samples + if sample.get_task_name() == "trigger_identification" + ] + frame_id_samples = [ + sample for sample in samples if sample.get_task_name() == "frame_classification" + ] + assert len(trigger_id_samples) == 3425 + assert len(frame_id_samples) == 198482 + assert len(samples) == 400389 def test_load_sesame_train_samples() -> None: diff --git a/tests/data/loaders/propbank34/__snapshots__/test_Propbank34TrainingLoader.ambr b/tests/data/loaders/propbank34/__snapshots__/test_Propbank34TrainingLoader.ambr index c3c459d..81ad10c 100644 --- a/tests/data/loaders/propbank34/__snapshots__/test_Propbank34TrainingLoader.ambr +++ b/tests/data/loaders/propbank34/__snapshots__/test_Propbank34TrainingLoader.ambr @@ -1,9 +1,9 @@ # name: test_load_propbank_samples list([ - FrameAnnotatedSentence(text='Paris was once the sex capital of the world .', annotations=[FrameAnnotation(frame='be.01', trigger_locs=[6], frame_elements=[FrameElementAnnotation(name='ARG1', start_loc=0, end_loc=5), FrameElementAnnotation(name='ARGM-TMP', start_loc=10, end_loc=14), FrameElementAnnotation(name='ARG2', start_loc=15, end_loc=43)])]), - FrameAnnotatedSentence(text='But a crackdown on prostitution and the rise of porn megastores are destroying a unique , secret heritage .', annotations=[FrameAnnotation(frame='be.03', trigger_locs=[64], frame_elements=[]), FrameAnnotation(frame='destroy.01', trigger_locs=[68], frame_elements=[FrameElementAnnotation(name='ARGM-DIS', start_loc=0, end_loc=3), FrameElementAnnotation(name='ARG0', start_loc=4, end_loc=63), FrameElementAnnotation(name='ARG1', start_loc=79, end_loc=105)])]), - FrameAnnotatedSentence(text='Andrew Hussey reports', annotations=[FrameAnnotation(frame='report.01', trigger_locs=[14], frame_elements=[FrameElementAnnotation(name='ARG0', start_loc=0, end_loc=13)])]), - FrameAnnotatedSentence(text='From the outside at least , there is little to distinguish it from the dozen or so rival sex shops which line the streets in this part of Paris , in between the stalls flogging cheap hip - hop gear and drugs paraphernalia .', annotations=[FrameAnnotation(frame='be.02', trigger_locs=[34], frame_elements=[FrameElementAnnotation(name='ARGM-DIR', start_loc=0, end_loc=25), FrameElementAnnotation(name='ARG1', start_loc=37, end_loc=221)]), FrameAnnotation(frame='distinguish.01', trigger_locs=[47], frame_elements=[FrameElementAnnotation(name='ARG0', start_loc=37, end_loc=43), FrameElementAnnotation(name='ARG1', start_loc=59, end_loc=61), FrameElementAnnotation(name='ARG2', start_loc=62, end_loc=221)]), FrameAnnotation(frame='line.01', trigger_locs=[105], frame_elements=[FrameElementAnnotation(name='ARG2', start_loc=67, end_loc=98), FrameElementAnnotation(name='R-ARG2', start_loc=99, end_loc=104), FrameElementAnnotation(name='ARG1', start_loc=110, end_loc=121), FrameElementAnnotation(name='ARGM-LOC', start_loc=122, end_loc=143), FrameElementAnnotation(name='ARGM-LOC', start_loc=146, end_loc=221)]), FrameAnnotation(frame='flog.01', trigger_locs=[168], frame_elements=[FrameElementAnnotation(name='ARG0', start_loc=157, end_loc=167), FrameElementAnnotation(name='ARG1', start_loc=177, end_loc=221)])]), - FrameAnnotatedSentence(text='This part of the city has always had a seedy reputation .', annotations=[FrameAnnotation(frame='have.01', trigger_locs=[22], frame_elements=[])]), + FrameAnnotatedSentence(text='Paris was once the sex capital of the world .', annotations=[FrameAnnotation(frame='be.01', trigger_locs=[6], frame_elements=[FrameElementAnnotation(name='ARG1', start_loc=0, end_loc=5), FrameElementAnnotation(name='ARGM-TMP', start_loc=10, end_loc=14), FrameElementAnnotation(name='ARG2', start_loc=15, end_loc=43)])], skip_trigger_identification_task=False), + FrameAnnotatedSentence(text='But a crackdown on prostitution and the rise of porn megastores are destroying a unique , secret heritage .', annotations=[FrameAnnotation(frame='be.03', trigger_locs=[64], frame_elements=[]), FrameAnnotation(frame='destroy.01', trigger_locs=[68], frame_elements=[FrameElementAnnotation(name='ARGM-DIS', start_loc=0, end_loc=3), FrameElementAnnotation(name='ARG0', start_loc=4, end_loc=63), FrameElementAnnotation(name='ARG1', start_loc=79, end_loc=105)])], skip_trigger_identification_task=False), + FrameAnnotatedSentence(text='Andrew Hussey reports', annotations=[FrameAnnotation(frame='report.01', trigger_locs=[14], frame_elements=[FrameElementAnnotation(name='ARG0', start_loc=0, end_loc=13)])], skip_trigger_identification_task=False), + FrameAnnotatedSentence(text='From the outside at least , there is little to distinguish it from the dozen or so rival sex shops which line the streets in this part of Paris , in between the stalls flogging cheap hip - hop gear and drugs paraphernalia .', annotations=[FrameAnnotation(frame='be.02', trigger_locs=[34], frame_elements=[FrameElementAnnotation(name='ARGM-DIR', start_loc=0, end_loc=25), FrameElementAnnotation(name='ARG1', start_loc=37, end_loc=221)]), FrameAnnotation(frame='distinguish.01', trigger_locs=[47], frame_elements=[FrameElementAnnotation(name='ARG0', start_loc=37, end_loc=43), FrameElementAnnotation(name='ARG1', start_loc=59, end_loc=61), FrameElementAnnotation(name='ARG2', start_loc=62, end_loc=221)]), FrameAnnotation(frame='line.01', trigger_locs=[105], frame_elements=[FrameElementAnnotation(name='ARG2', start_loc=67, end_loc=98), FrameElementAnnotation(name='R-ARG2', start_loc=99, end_loc=104), FrameElementAnnotation(name='ARG1', start_loc=110, end_loc=121), FrameElementAnnotation(name='ARGM-LOC', start_loc=122, end_loc=143), FrameElementAnnotation(name='ARGM-LOC', start_loc=146, end_loc=221)]), FrameAnnotation(frame='flog.01', trigger_locs=[168], frame_elements=[FrameElementAnnotation(name='ARG0', start_loc=157, end_loc=167), FrameElementAnnotation(name='ARG1', start_loc=177, end_loc=221)])], skip_trigger_identification_task=False), + FrameAnnotatedSentence(text='This part of the city has always had a seedy reputation .', annotations=[FrameAnnotation(frame='have.01', trigger_locs=[22], frame_elements=[])], skip_trigger_identification_task=False), ]) # --- diff --git a/tests/data/test_LoaderDataCache.py b/tests/data/test_LoaderDataCache.py index 94cc4ae..5424ed5 100644 --- a/tests/data/test_LoaderDataCache.py +++ b/tests/data/test_LoaderDataCache.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache @@ -13,18 +15,57 @@ def test_get_lexical_unit_bigram_to_frame_lookup_map( def test_normalize_lexical_unit_ngram(loader_cache: LoaderDataCache) -> None: - assert loader_cache._normalize_lexical_unit_ngram(["can't", "stop"]) == "cant_stop" - assert loader_cache._normalize_lexical_unit_ngram(["he", "eats"]) == "he_eat" - assert loader_cache._normalize_lexical_unit_ngram(["eats"]) == "eat" + assert loader_cache._normalize_lexical_unit_ngram(["can't", "stop"]) == { + "cant_stop" + } + assert loader_cache._normalize_lexical_unit_ngram(["he", "eats"]) == { + "he_eat", + "he_eats", + } + assert loader_cache._normalize_lexical_unit_ngram(["eats"]) == {"eat", "eats"} -def test_get_possible_frames_for_trigger_bigrams(loader_cache: LoaderDataCache) -> None: - assert loader_cache.get_possible_frames_for_trigger_bigrams( - [["can't", "help"], ["help", "it"], ["help"]] - ) == ["Self_control", "Assistance"] - assert loader_cache.get_possible_frames_for_trigger_bigrams( - [["can't", "help"]] - ) == ["Self_control"] +@pytest.mark.parametrize( + "ngrams,expected", + [ + ([["can't", "help"], ["help", "it"], ["help"]], ["Self_control", "Assistance"]), + ([["can't", "help"]], ["Self_control"]), + ( + [["and", "staffed"], ["staffed", "by"], ["staffed"]], + ["Employing", "Working_a_post"], + ), + ( + [["strongest"]], + [ + "Chemical_potency", + "Expertise", + "Judgment_of_intensity", + "Level_of_force_exertion", + "Level_of_force_resistance", + "Usefulness", + ], + ), + ( + [["done"]], + [ + "Activity_done_state", + "Ingest_substance", + "Intentionally_act", + "Intentionally_affect", + "Process_completed_state", + "Sex", + "Thriving", + "Touring", + "Dressing", + "Giving", + ], + ), + ], +) +def test_get_possible_frames_for_trigger_bigrams( + ngrams: list[list[str]], expected: list[str], loader_cache: LoaderDataCache +) -> None: + assert loader_cache.get_possible_frames_for_trigger_bigrams(ngrams) == expected def test_get_possible_frames_for_trigger_bigrams_stems_bigrams( @@ -33,3 +74,40 @@ def test_get_possible_frames_for_trigger_bigrams_stems_bigrams( assert loader_cache.get_possible_frames_for_trigger_bigrams( [["can't", "helps"]] ) == ["Self_control"] + + +@pytest.mark.parametrize( + "ngrams,expected", + [ + ( + [["use", "trying"], ["trying"], ["trying", "to"]], + [ + "Attempt", + "Attempt_means", + "Operational_testing", + "Tasting", + "Trial", + "Try_defendant", + "Trying_out", + ], + ), + ( + [["the lift"], ["lift"]], + [ + "Body_movement", + "Building_subparts", + "Cause_change_of_position_on_a_scale", + "Cause_motion", + "Cause_to_end", + "Connecting_architecture", + "Theft", + ], + ), + ], +) +def test_get_possible_frames_for_trigger_bigrams_paper_examples( + ngrams: list[list[str]], + expected: list[str], + loader_cache: LoaderDataCache, +) -> None: + assert loader_cache.get_possible_frames_for_trigger_bigrams(ngrams) == expected diff --git a/tests/training/test_ModelRecorder.py b/tests/training/test_ModelRecorder.py new file mode 100644 index 0000000..322ecef --- /dev/null +++ b/tests/training/test_ModelRecorder.py @@ -0,0 +1,98 @@ +from __future__ import annotations +from unittest.mock import MagicMock + +from frame_semantic_transformer.training.ModelRecorder import ( + ModelRecorder, + ModelSaveRecord, + _find_best_val_loss_model, + _find_best_val_metric_models, +) + + +def test_ModelRecorder_get_save_path() -> None: + recorder = ModelRecorder("output_dir") + save_path = recorder.get_save_path(1, 0.5, {"task1": 0.5, "task2": 0.6}) + assert save_path == "output_dir/epoch=1-val_loss=0.5-task1=0.5-task2=0.6" + + +def test_ModelRecorder_get_save_path_no_task_val_metrics() -> None: + recorder = ModelRecorder("output_dir") + save_path = recorder.get_save_path(1, 0.5) + assert save_path == "output_dir/epoch=1-val_loss=0.5" + + +def test_ModelRecorder_get_save_path_replaces_dashes_with_underscores_in_tasks() -> None: + recorder = ModelRecorder("output_dir") + save_path = recorder.get_save_path(1, 0.5, {"task1-f1": 0.5, "task2-f1": 0.6}) + assert save_path == "output_dir/epoch=1-val_loss=0.5-task1_f1=0.5-task2_f1=0.6" + + +def test_ModelRecorder_save_model() -> None: + model = MagicMock() + tokenizer = MagicMock() + recorder = ModelRecorder("output_dir") + recorder.save_model(model, tokenizer, 1, 0.5, {"task1": 0.5, "task2": 0.6}) + model.save_pretrained.assert_called_once_with( + "output_dir/epoch=1-val_loss=0.5-task1=0.5-task2=0.6" + ) + tokenizer.save_pretrained.assert_called_once_with( + "output_dir/epoch=1-val_loss=0.5-task1=0.5-task2=0.6" + ) + assert recorder.records == [ + ModelSaveRecord( + epoch=1, + val_loss=0.5, + task_val_metrics={"task1": 0.5, "task2": 0.6}, + save_path="output_dir/epoch=1-val_loss=0.5-task1=0.5-task2=0.6", + ) + ] + + +def test_find_best_val_metric_models() -> None: + epoch1 = ModelSaveRecord( + epoch=1, + val_loss=0.5, + task_val_metrics=None, + save_path="...", + ) + epoch2 = ModelSaveRecord( + epoch=2, + val_loss=0.4, + task_val_metrics={"task1": 0.6, "task2": 0.7, "task3": 0.8}, + save_path="...", + ) + epoch3 = ModelSaveRecord( + epoch=3, + val_loss=0.3, + task_val_metrics={"task1": 0.7, "task2": 0.8, "task3": 0.7}, + save_path="...", + ) + model_records = [epoch1, epoch2, epoch3] + assert _find_best_val_metric_models(model_records) == { + "task1": epoch3, + "task2": epoch3, + "task3": epoch2, + } + + +def test_find_best_val_loss_model() -> None: + epoch1 = ModelSaveRecord( + epoch=1, + val_loss=0.43, + task_val_metrics=None, + save_path="...", + ) + epoch2 = ModelSaveRecord( + epoch=2, + val_loss=0.4, + task_val_metrics=None, + save_path="...", + ) + epoch3 = ModelSaveRecord( + epoch=3, + val_loss=0.45, + task_val_metrics={"task1": 0.7, "task2": 0.8, "task3": 0.7}, + save_path="...", + ) + model_records = [epoch1, epoch2, epoch3] + assert _find_best_val_loss_model(model_records) == epoch2 diff --git a/tests/data/training/test_find_best_val_model_paths.py b/tests/training/test_find_best_val_model_paths.py similarity index 74% rename from tests/data/training/test_find_best_val_model_paths.py rename to tests/training/test_find_best_val_model_paths.py index 449a094..d3b73b7 100644 --- a/tests/data/training/test_find_best_val_model_paths.py +++ b/tests/training/test_find_best_val_model_paths.py @@ -6,7 +6,7 @@ def test_get_model_scores() -> None: - model_path = "epoch=1--val_loss=1.0000--val_args_extraction_f1=0.5000--val_trigger_identification_f1=0.4000--val_frame_classification_f1=0.6000" + model_path = "epoch=1-val_loss=1.0000-val_args_extraction_f1=0.5000-val_trigger_identification_f1=0.4000-val_frame_classification_f1=0.6000" expected_scores = { "val_loss": 1.0, "val_args_extraction_f1": 0.5,