feat: new models trained on Framenet exemplars (#18)
* include exemplars in framenet training

* skipping invalid trigger exemplars

* skip exemplars by default during training

* fixing tests

* improving data augmentations

* ensure wordnet download for inference

* updating snapshots

* adding more info when augmentations fail validation

* adding more augmentations from nlpaug

* fixing linting

* fixing keyboard augmentation

* more checks on keyboard augmentation

* tweaking augmentations

* fixing tests

* adding safety check to uppercase augmentation

* lower augmentation rate

* adding more augmentations

* tweaking augs

* removing debugging output

* reduce augmentation

* tweaking augmentation probs

* tweaking augmentation probs

* fixing type import

* adding option to delete non-optimal models as training progresses

* tweaking augmentations

* updating models

* updating README with new model stats
chanind authored Mar 15, 2023
1 parent a061507 commit 3f937fb
Showing 55 changed files with 1,693 additions and 313 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -21,9 +21,9 @@ This library uses the same train/dev/test documents and evaluation methodology a

 | Task                   | Sesame F1 (dev/test) | Small Model F1 (dev/test) | Base Model F1 (dev/test) |
 | ---------------------- | -------------------- | ------------------------- | ------------------------ |
-| Trigger identification | 0.80 / 0.73          | 0.74 / 0.70               | 0.78 / 0.71              |
-| Frame classification   | 0.90 / 0.87          | 0.83 / 0.81               | 0.89 / 0.87              |
-| Argument extraction    | 0.61 / 0.61          | 0.68 / 0.70               | 0.74 / 0.72              |
+| Trigger identification | 0.80 / 0.73          | 0.75 / 0.71               | 0.78 / 0.74              |
+| Frame classification   | 0.90 / 0.87          | 0.87 / 0.86               | 0.91 / 0.89              |
+| Argument extraction    | 0.61 / 0.61          | 0.76 / 0.73               | 0.78 / 0.75              |

The base model performs similarly to Open-Sesame on trigger identification and frame classification tasks, but outperforms it by a significant margin on argument extraction. The small pretrained model has lower F1 than base across the board, but is 1/4 the size and still outperforms Open-Sesame at argument extraction.
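For context, either released model can be selected at load time. A minimal usage sketch, assuming the library's `FrameSemanticTransformer` entry point and its `detect_frames` method; the `"small"`/`"base"` names match `OFFICIAL_RELEASES` in `constants.py` below:

```python
from frame_semantic_transformer import FrameSemanticTransformer

# "base" is the default; "small" is roughly 1/4 the size per the table above
transformer = FrameSemanticTransformer("small")
result = transformer.detect_frames(
    "The hallway smelt of boiled cabbage and old rag mats."
)
```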

2 changes: 1 addition & 1 deletion frame_semantic_transformer/constants.py
@@ -5,6 +5,6 @@

 MODEL_MAX_LENGTH = 512
 OFFICIAL_RELEASES = ["base", "small"]  # TODO: small, large
-MODEL_REVISION = "v0.1.0"
+MODEL_REVISION = "v0.2.0"
 PADDING_LABEL_ID = -100
 DEFAULT_NUM_WORKERS = os.cpu_count() or 2
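The `MODEL_REVISION` bump points inference at the newly trained weights. A hedged sketch of how a revision pin like this is typically consumed; the `chanind/frame-semantic-transformer-base` hub id is an assumption, not part of this diff:

```python
from transformers import T5ForConditionalGeneration, T5TokenizerFast

MODEL_REVISION = "v0.2.0"
MODEL_ID = "chanind/frame-semantic-transformer-base"  # assumed hub id

# revision= pins a specific tagged snapshot of the weights on the Hugging Face Hub
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID, revision=MODEL_REVISION)
tokenizer = T5TokenizerFast.from_pretrained(MODEL_ID, revision=MODEL_REVISION)
```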
42 changes: 30 additions & 12 deletions frame_semantic_transformer/data/LoaderDataCache.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 from collections import defaultdict
 from functools import lru_cache
+from itertools import product
+from typing import TYPE_CHECKING

-from .loaders.loader import InferenceLoader
-from .frame_types import Frame
+if TYPE_CHECKING:
+    from .frame_types import Frame
+    from .loaders.loader import InferenceLoader


class LoaderDataCache:
@@ -77,11 +80,13 @@ def get_lexical_unit_bigram_to_frame_lookup_map(self) -> dict[str, list[str]]:
             for part in parts:
                 # also key this as a monogram if there's only 1 element or the word is rare enough
                 if len(parts) == 1 or self.loader.prioritize_lexical_unit(part):
-                    lu_bigrams.append(self._normalize_lexical_unit_ngram([part]))
+                    for norm_part in self._normalize_lexical_unit_ngram([part]):
+                        lu_bigrams.append(norm_part)
                 if prev_part is not None:
-                    lu_bigrams.append(
-                        self._normalize_lexical_unit_ngram([prev_part, part])
-                    )
+                    for norm_parts in self._normalize_lexical_unit_ngram(
+                        [prev_part, part]
+                    ):
+                        lu_bigrams.append(norm_parts)
                 prev_part = part

             for bigram in lu_bigrams:
@@ -97,20 +102,33 @@ def get_possible_frames_for_trigger_bigrams(
         possible_frames = []
         lookup_map = self.get_lexical_unit_bigram_to_frame_lookup_map()
         for bigram in bigrams:
-            normalized_bigram = self._normalize_lexical_unit_ngram(bigram)
-            if normalized_bigram in lookup_map:
-                bigram_frames = lookup_map[normalized_bigram]
-                possible_frames += bigram_frames
+            # sorted here just to get a consistent ordering
+            for normalized_bigram in sorted(self._normalize_lexical_unit_ngram(bigram)):
+                if normalized_bigram in lookup_map:
+                    bigram_frames = lookup_map[normalized_bigram]
+                    possible_frames += bigram_frames
         # remove duplicates, while preserving order
         # https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set/53657523#53657523
         return list(dict.fromkeys(possible_frames))

-    def _normalize_lexical_unit_ngram(self, ngram: list[str]) -> str:
-        return "_".join([self.loader.normalize_lexical_unit_text(tok) for tok in ngram])
+    def _normalize_lexical_unit_ngram(self, ngram: list[str]) -> set[str]:
+        norm_toks = [
+            setify(self.loader.normalize_lexical_unit_text(tok)) for tok in ngram
+        ]
+        return {"_".join(combo) for combo in product(*norm_toks)}


 def normalize_name(name: str) -> str:
     """
     Normalize a frame or element name to be lowercase and without underscores
     """
     return name.lower().replace("_", "")


+def setify(input: str | set[str]) -> set[str]:
+    """
+    Convert a string or set to a set
+    """
+    if isinstance(input, str):
+        return {input}
+    return input
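Since `normalize_lexical_unit_text` may now return several candidate forms per token, `_normalize_lexical_unit_ngram` emits every combination. A standalone sketch of that expansion, and of the ordered de-duplication used above:

```python
from __future__ import annotations
from itertools import product


def setify(value: str | set[str]) -> set[str]:
    # mirror of the setify helper above: wrap bare strings in a set
    return {value} if isinstance(value, str) else value


# e.g. a loader could normalize "ran" to both "run" and "ran"
norm_toks = [setify({"run", "ran"}), setify("home")]
print({"_".join(combo) for combo in product(*norm_toks)})
# {'run_home', 'ran_home'}  (set order may vary)

# ordered de-dup, as in get_possible_frames_for_trigger_bigrams
print(list(dict.fromkeys(["Motion", "Speed", "Motion"])))  # ['Motion', 'Speed']
```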
9 changes: 5 additions & 4 deletions frame_semantic_transformer/data/TaskSampleDataset.py
@@ -19,7 +19,7 @@

 class TaskSampleDataset(Dataset[Any]):
     samples: Sequence[TaskSample]
-    augmentation: Optional[Callable[[str, str], tuple[str, str]]] = None
+    augmentation: Optional[Callable[[TaskSample], TaskSample]] = None
     tokenizer: T5TokenizerFast

     def __init__(
@@ -59,10 +59,11 @@ def parse_sample(
         self, sample: TaskSample
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

-        input = sample.get_input()
-        target = sample.get_target()
+        aug_sample = sample
         if self.augmentation:
-            input, target = self.augmentation(input, target)
+            aug_sample = self.augmentation(sample)
+        input = aug_sample.get_input()
+        target = aug_sample.get_target()

         input_encoding = self.tokenizer(
             input,
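With this change, anything passed as `augmentation` must map a `TaskSample` to a `TaskSample` rather than transforming raw input/target strings. A minimal conforming callable, for illustration only:

```python
from frame_semantic_transformer.data.tasks import TaskSample


def no_op_augmentation(sample: TaskSample) -> TaskSample:
    # satisfies the new Callable[[TaskSample], TaskSample] contract;
    # a real augmentation would return a modified copy of the sample
    return sample
```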
26 changes: 19 additions & 7 deletions frame_semantic_transformer/data/augmentations/DataAugmentation.py
@@ -1,29 +1,41 @@
 from __future__ import annotations

+from typing import Callable, Union
 from abc import ABC, abstractmethod
 from random import uniform

+from frame_semantic_transformer.data.tasks.TaskSample import TaskSample
+
+
+ProbabilityType = Union[float, Callable[[TaskSample], float]]
+

 class DataAugmentation(ABC):
     """
     Base class for data augmentations on training data
     """

-    probability: float
+    probability: ProbabilityType

-    def __init__(self, probability: float):
+    def __init__(self, probability: ProbabilityType):
         self.probability = probability

-    def __call__(self, input: str, output: str) -> tuple[str, str]:
+    def __call__(self, task_sample: TaskSample) -> TaskSample:
         """
         randomly apply this augmentation in proportion to self.probability
         """
         rand_val = uniform(0, 1.0)
-        if rand_val > self.probability:
-            return (input, output)
-        return self.apply_augmentation(input, output)
+        if rand_val > self.get_probability(task_sample):
+            return task_sample
+        return self.apply_augmentation(task_sample)
+
+    def get_probability(self, task_sample: TaskSample) -> float:
+        if callable(self.probability):
+            return self.probability(task_sample)
+        return self.probability

     @abstractmethod
-    def apply_augmentation(self, input: str, output: str) -> tuple[str, str]:
+    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
         """
         Main logic for subclasses to implement
         """
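A `ProbabilityType` can now be a plain float or a per-sample callable, so augmentation rates can vary by sample; `get_probability` resolves whichever form was supplied at call time. A minimal sketch of a callable probability, where the length heuristic is illustrative and not from this PR:

```python
from frame_semantic_transformer.data.tasks import TaskSample


def short_sample_probability(sample: TaskSample) -> float:
    # illustrative heuristic: augment short samples more often
    return 0.3 if len(sample.get_input()) < 100 else 0.1
```

Any `DataAugmentation` subclass constructed with such a callable is applied stochastically per sample via `__call__`.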
59 changes: 59 additions & 0 deletions frame_semantic_transformer/data/augmentations/DoubleQuotesAugmentation.py
@@ -0,0 +1,59 @@
from __future__ import annotations
import random

from frame_semantic_transformer.data.augmentations.modification_helpers import (
    splice_text,
)
from frame_semantic_transformer.data.tasks import TaskSample
from .DataAugmentation import DataAugmentation
from .modification_helpers.get_sample_text import get_sample_text


LATEX_QUOTES = ["``", "''"]
STANDARD_QUOTE = '"'
ALL_QUOTES = LATEX_QUOTES + [STANDARD_QUOTE]


class DoubleQuotesAugmentation(DataAugmentation):
    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
        sample_text = get_sample_text(task_sample)

        # if standard quotes are used, convert to latex quotes, and vice versa
        to_latex = STANDARD_QUOTE in sample_text
        from_quotes = [STANDARD_QUOTE] if to_latex else LATEX_QUOTES
        to_quotes = LATEX_QUOTES if to_latex else [STANDARD_QUOTE]

        updated_sample = task_sample
        while count_instances(sample_text, from_quotes) > 0:
            quote, start_loc = find_first_instance(sample_text, from_quotes)
            try:
                updated_sample = splice_text(
                    updated_sample,
                    lambda _text, _critical_indices: (
                        start_loc,
                        len(quote),
                        random.choice(to_quotes),
                    ),
                )
                sample_text = get_sample_text(updated_sample)
            except ValueError:
                # The splice failed, so just return the sample
                return updated_sample

        return updated_sample


def count_instances(text: str, substrings: list[str]) -> int:
    return sum(text.count(substring) for substring in substrings)


def find_first_instance(text: str, substrings: list[str]) -> tuple[str, int]:
    """
    Find the first instance of any of the substrings in the text. Returns the
    substring and the start location of the substring.
    """
    for substring in substrings:
        start_loc = text.find(substring)
        if start_loc >= 0:
            return substring, start_loc
    raise ValueError(f"Could not find any of {substrings} in {text}")
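A quick standalone illustration of the swap this augmentation performs, simplified with `str.replace`; the real code above splices one quote at a time so annotation indices can be validated:

```python
LATEX_QUOTES = ["``", "''"]
STANDARD_QUOTE = '"'

text = "He said ``hello'' to the crowd."
for quote in LATEX_QUOTES:  # latex -> standard
    text = text.replace(quote, STANDARD_QUOTE)
print(text)  # He said "hello" to the crowd.
```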
48 changes: 48 additions & 0 deletions frame_semantic_transformer/data/augmentations/KeyboardAugmentation.py
@@ -0,0 +1,48 @@
from __future__ import annotations

from nlpaug.augmenter.char import KeyboardAug

from frame_semantic_transformer.data.augmentations.modification_helpers import (
    modify_text_without_changing_length,
)
from frame_semantic_transformer.data.tasks import TaskSample
from .DataAugmentation import DataAugmentation, ProbabilityType


class KeyboardAugmentation(DataAugmentation):
    """
    Wrapper around nlpaug's KeyboardAug.
    Attempts to make spelling mistakes similar to what a user might make.
    """

    augmenter: KeyboardAug

    def __init__(self, probability: ProbabilityType):
        super().__init__(probability)
        self.augmenter = KeyboardAug(
            include_special_char=False, aug_char_p=0.1, aug_word_p=0.1
        )
        self.augmenter.include_detail = True

    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
        def augment_sent(sentence: str) -> str:
            # this augmenter removes spaces around punctuation, so apply the
            # reported changes manually instead of using its output directly
            _, changes = self.augmenter.augment(sentence)[0]
            new_sentence = sentence
            for change in changes:
                # sometimes this augmenter changes token lengths, which we
                # don't want; just skip the changes if that happens
                if len(change["orig_token"]) != len(change["new_token"]):
                    return new_sentence
                if change["orig_start_pos"] != change["new_start_pos"]:
                    return new_sentence
                start_pos = change["orig_start_pos"]
                end_pos = change["orig_start_pos"] + len(change["orig_token"])
                new_sentence = (
                    new_sentence[:start_pos]
                    + change["new_token"]
                    + new_sentence[end_pos:]
                )
            return new_sentence

        return modify_text_without_changing_length(task_sample, augment_sent)
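The guard above can be shown in isolation: a substitution is applied only when it preserves token length and position, keeping every other character offset in the sample valid. A standalone sketch, where the change-record keys mirror what the code reads from nlpaug's detail output:

```python
def apply_change(sentence: str, change: dict) -> str:
    # skip any substitution that would shift other character offsets
    if len(change["orig_token"]) != len(change["new_token"]):
        return sentence
    if change["orig_start_pos"] != change["new_start_pos"]:
        return sentence
    start = change["orig_start_pos"]
    end = start + len(change["orig_token"])
    return sentence[:start] + change["new_token"] + sentence[end:]


print(apply_change(
    "The dog ran home",
    {"orig_token": "dog", "new_token": "dof", "orig_start_pos": 4, "new_start_pos": 4},
))  # The dof ran home
```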
frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py
@@ -1,14 +1,12 @@
 from __future__ import annotations
+from frame_semantic_transformer.data.augmentations.modification_helpers import (
+    modify_text_without_changing_length,
+)

+from frame_semantic_transformer.data.tasks import TaskSample
 from .DataAugmentation import DataAugmentation


 class LowercaseAugmentation(DataAugmentation):
-    def apply_augmentation(self, input: str, output: str) -> tuple[str, str]:
-        task_def_index = input.find(":")
-        task_def = input[:task_def_index]
-        input_contents = input[task_def_index:]
-        # only lowercase the content, not the task definition
-        return (
-            task_def + input_contents.lower(),
-            output.lower(),
-        )
+    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
+        return modify_text_without_changing_length(task_sample, str.lower)
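The `modify_text_without_changing_length` helper itself is collapsed out of this diff. A hypothetical sketch of its contract, simplified to plain strings (the name suffix and body here are assumptions): apply a text transform and reject it if the length changes, since triggers and frame elements are tracked as character offsets.

```python
from typing import Callable


def modify_text_without_changing_length_sketch(
    text: str, transform: Callable[[str], str]
) -> str:
    # hypothetical, simplified: the real helper operates on TaskSamples
    # and keeps their character offsets valid
    new_text = transform(text)
    if len(new_text) != len(text):
        raise ValueError("transform may not change text length")
    return new_text


print(modify_text_without_changing_length_sketch("It was GREAT.", str.lower))
# it was great.
```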

This file was deleted.

frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py
@@ -1,13 +1,30 @@
 from __future__ import annotations
-from .DataAugmentation import DataAugmentation
 import re
+from frame_semantic_transformer.data.augmentations.modification_helpers import (
+    splice_text,
+)
+from frame_semantic_transformer.data.augmentations.modification_helpers.splice_text import (
+    is_valid_splice,
+)
+
+from frame_semantic_transformer.data.tasks import TaskSample
+from .DataAugmentation import DataAugmentation

 REMOVE_END_PUNCT_RE = r"\s*[.?!]\s*$"


 class RemoveEndPunctuationAugmentation(DataAugmentation):
-    def apply_augmentation(self, input: str, output: str) -> tuple[str, str]:
-        return (
-            re.sub(REMOVE_END_PUNCT_RE, "", input),
-            re.sub(REMOVE_END_PUNCT_RE, "", output),
-        )
+    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
+        def splice_end_punct_cb(
+            sentence: str, critical_indices: list[int]
+        ) -> tuple[int, int, str] | None:
+            match = re.search(REMOVE_END_PUNCT_RE, sentence)
+            if match is None:
+                return None
+            start, end = match.span()
+            del_len = end - start
+            if not is_valid_splice(start, del_len, critical_indices):
+                return None
+            return start, del_len, ""
+
+        return splice_text(task_sample, splice_end_punct_cb)
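For reference, `REMOVE_END_PUNCT_RE` strips one trailing `.`, `?`, or `!` along with surrounding whitespace; a quick check of its behavior:

```python
import re

REMOVE_END_PUNCT_RE = r"\s*[.?!]\s*$"

print(re.sub(REMOVE_END_PUNCT_RE, "", "He ran fast ."))  # "He ran fast"
match = re.search(REMOVE_END_PUNCT_RE, "Did he run ?")
print(match.span() if match else None)                   # (10, 12)
```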
(The remaining changed files are not shown here.)