Skip to content

Commit

Permalink
Refactor/remove access mutate (#111)
Browse files Browse the repository at this point in the history
* add: fn to make unix timestamp (ms).

* add: export utils.

* add: fn to check unix timestamp.

* update: validations, json.

Set type validations for relevant attributes.

A method to convert the object to dict.

* update: use list for easy json.

* add: prevent plugin run on condition.

* refactor: types reference circular dependency.

* update: remove `access` and `mutate`.

* refactor: use immutable `Input` and `Output`.

* refactor: remove getattr and decrease opaque code.

* style: lint.

* update: export from plugin.base.

* style: support annotations of self.

* update: export is_utterance from utils.

* refactor: support `Input` `Output`

we will pass `Input`, `Output` to plugins.

* refactor: remove access, mutate.

* test: support removal of access mutate.

* fix: deecopy -> deepcopy.

* fix: remove self on chain.

* fix: argument collisions.

* update: default output is Output.

* update: reference time is optional.

* update: method signature.

(*args) to (Input, Output)

* update: Intent is an attrs class.

* update: use score for ordering.

* fix: prevent argument collision.

* update: clf_feature is a `List[str]`.

* refactor: use method with simpler signature.

* add: method to create new instances from a dict.

* add: normalize utterances.

* add: method to create instances from dict.

* add: semi-serializer.

* style: annotations support

* fix: indentation.

* update: reference for serializer can be optional.

* add: constant entity_type.

* style: embed args in method call.

* refactor: support new plugin form.

* test: updated for new plugin format.

* refactor: supports new plugin io.

* fix: lower() may have unknown side-effects.

* update: make attrs class.

* update: serializer.

* update: entity type is one of "value" or "interval".

* update: coverage requirement shifted to 90%.

* refactor: new plugin format.

* refactor: new plugin format.

* add: export Slot.

* refactor: new plugin format.

* refactor: break into smaller methods.

* fix: entity_type and type have same value.

entity-type tells the type of the entity.
type tells if the entity is one of "value" or "interval".

* test: check entity_type vs type change.

* fix: instantiation needs utterances.

* fix: guards and signature.

* refactor: fn to get operators.

* fix: lb plugin must call ducklingplugin.

* refactor: move fn related to dt in same mod.

* update: default entity types.

* test: coverage 100%.

* coverage: 100%

* fix: typecheck.

* update: ignore type check.

* test: cov 100%.

* style: lint black isort.

* update: warn_unused_ignores=False

* update: coverage requirement to 100.

* update: merge master.

* add: fn to check unix timestamp.

* update: rebase master.

* update: merge master.

* update: rebase master.

* update: merge master.

* update: merge master.

* fix: duplicate impl.

* style: lint.

* style: lint.
  • Loading branch information
ltbringer authored Feb 7, 2022
1 parent 33ae14c commit cef7272
Show file tree
Hide file tree
Showing 57 changed files with 1,099 additions and 1,163 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tag_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ jobs:
fail_ci_if_error: true
verbose: true
- name: Check if coverage less than 100%
if: env.covpercentage != '100'
if: env.covpercentage < '100'
uses: actions/github-script@v3
with:
script: |
core.setFailed('Code Coverage is not 100%')
core.setFailed('Code Coverage is less than 100%')
- name: Build and Publish to PYPI
env:
TWINE_USERNAME: __token__
Expand Down
4 changes: 4 additions & 0 deletions dialogy/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from dialogy.base.entity_extractor import EntityScoringMixin
from dialogy.base.input import Input
from dialogy.base.output import Output
from dialogy.base.plugin import Guard, Plugin
51 changes: 35 additions & 16 deletions dialogy/base/input.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
from typing import Optional, List
from __future__ import annotations

from typing import Any, Dict, List, Optional

import attr

from dialogy.types import Utterance
from dialogy.utils import is_unix_ts, normalize, is_utterance
from dialogy.utils import is_unix_ts, normalize


@attr.frozen
class Input:
utterances: List[Utterance] = attr.ib(kw_only=True)
reference_time: Optional[int] = attr.ib(kw_only=True)

reference_time: Optional[int] = attr.ib(default=None, kw_only=True)
latent_entities: bool = attr.ib(default=False, kw_only=True, converter=bool)
transcripts: List[str] = attr.ib(default=None)
clf_feature: Optional[str] = attr.ib(
clf_feature: Optional[List[str]] = attr.ib( # type: ignore
kw_only=True,
default=None,
validator=attr.validators.optional(attr.validators.instance_of(str)),
factory=list,
validator=attr.validators.optional(attr.validators.instance_of(list)),
)
lang: str = attr.ib(
default="en", kw_only=True, validator=attr.validators.instance_of(str)
)
lang: str = attr.ib(default="en", kw_only=True, validator=attr.validators.instance_of(str))
locale: str = attr.ib(
default="en_IN",
kw_only=True,
validator=attr.validators.optional(attr.validators.instance_of(str)),
validator=attr.validators.optional(attr.validators.instance_of(str)), # type: ignore
)
timezone: str = attr.ib(
default="UTC",
kw_only=True,
validator=attr.validators.optional(attr.validators.instance_of(str)),
validator=attr.validators.optional(attr.validators.instance_of(str)), # type: ignore
)
slot_tracker: Optional[list] = attr.ib(
slot_tracker: Optional[List[Dict[str, Any]]] = attr.ib(
default=None,
kw_only=True,
validator=attr.validators.optional(attr.validators.instance_of(list)),
Expand All @@ -45,15 +49,30 @@ class Input:
validator=attr.validators.optional(attr.validators.instance_of(str)),
)

def __attrs_post_init__(self):
object.__setattr__(self, "transcript", normalize(self.utterances))
def __attrs_post_init__(self) -> None:
try:
object.__setattr__(self, "transcripts", normalize(self.utterances))
except TypeError:
...

@reference_time.validator
def _check_reference_time(self, attribute: attr.Attribute, reference_time: int):
@reference_time.validator # type: ignore
def _check_reference_time(
self, attribute: attr.Attribute, reference_time: Optional[int] # type: ignore
) -> None:
if reference_time is None:
return
if not isinstance(reference_time, int):
raise TypeError(f"{attribute.name} must be an integer.")
if not is_unix_ts(reference_time):
raise ValueError(f"{attribute.name} must be a unix timestamp but got {reference_time}.")
raise ValueError(
f"{attribute.name} must be a unix timestamp but got {reference_time}."
)

def json(self):
def json(self) -> Dict[str, Any]:
return attr.asdict(self)

@classmethod
def from_dict(cls, d: Dict[str, Any], reference: Optional[Input] = None) -> Input:
if reference:
return attr.evolve(reference, **d)
return attr.evolve(cls(utterances=d["utterances"]), **d) # type: ignore
15 changes: 12 additions & 3 deletions dialogy/base/output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List
from typing import Any, Dict, List, Optional

import attr

Expand All @@ -20,5 +20,14 @@ class Output:
kw_only=True,
)

def json(self: Output) -> dict:
return attr.asdict(self)
def json(self: Output) -> Dict[str, List[Dict[str, Any]]]:
return {
"intents": [intent.json() for intent in self.intents],
"entities": [entity.json() for entity in self.entities],
}

@classmethod
def from_dict(cls, d: Dict[str, Any], reference: Optional[Output] = None) -> Output:
if reference:
return attr.evolve(reference, **d)
return attr.evolve(cls(), **d)
68 changes: 44 additions & 24 deletions dialogy/base/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,22 @@ def train(self, training_data: pd.DataFrame)
- If your plugin needs a model but it need not be trained frequently or uses an off the shelf pre-trained model, then you must build a Solitary plugin.
- A trainable plugin can also have transform methods if it needs to modify a dataframe for other plugins in place.
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Callable, List, Optional

import dialogy.constants as const
from dialogy.types import PluginFn
from dialogy.base.input import Input
from dialogy.base.output import Output
from dialogy.utils.logger import logger

if TYPE_CHECKING: # pragma: no cover
from dialogy.workflow import Workflow


Guard = Callable[[Input, Output], bool]


class Plugin(ABC):
"""
Expand Down Expand Up @@ -366,22 +375,22 @@ class Plugin(ABC):

def __init__(
self,
access: Optional[PluginFn] = None,
mutate: Optional[PluginFn] = None,
input_column: str = const.ALTERNATIVES,
output_column: Optional[str] = None,
use_transform: bool = False,
dest: Optional[str] = None,
guards: Optional[List[Guard]] = None,
debug: bool = False,
) -> None:
self.access = access
self.mutate = mutate
self.debug = debug
self.guards = guards
self.dest = dest
self.input_column = input_column
self.output_column = output_column or input_column
self.use_transform = use_transform

@abstractmethod
def utility(self, *args: Any) -> Any:
def utility(self, input: Input, output: Output) -> Any:
"""
Transform X -> y.
Expand All @@ -391,7 +400,7 @@ def utility(self, *args: Any) -> Any:
:rtype: Any
"""

def __call__(self, workflow: Any) -> None:
def __call__(self, workflow: Workflow) -> None:
"""
Abstraction for plugin io.
Expand All @@ -400,27 +409,38 @@ def __call__(self, workflow: Any) -> None:
:raises TypeError: If access method is missing, we can't get inputs for transformation.
"""
logger.enable(str(self)) if self.debug else logger.disable(str(self))
if self.access:
args = self.access(workflow)
value = self.utility(*args) # pylint: disable=assignment-from-none
if value is not None and self.mutate:
self.mutate(workflow, value)
else:
raise TypeError(
"Expected access to be functions" f" but {type(self.access)} was found."
)

def train(
self, _: Any
) -> Any: # pylint: disable=unused-argument disable=no-self-use

if workflow.input is None:
return

if workflow.output is None:
return

if self.prevent(workflow.input, workflow.output):
return

value = self.utility(workflow.input, workflow.output)
if value is not None and isinstance(self.dest, str):
workflow.set(self.dest, value)

def prevent(self, input_: Input, output: Output) -> bool:
"""
Decide if the plugin should execute.
:return: prevent plugin execution if True.
:rtype: bool
"""
if not self.guards:
return False
return any(guard(input_, output) for guard in self.guards)

def train(self, _: Any) -> Any:
"""
Train a plugin.
"""
return None

def transform(
self, training_data: Any
) -> Any: # pylint: disable=unused-argument disable=no-self-use
def transform(self, training_data: Any) -> Any:
"""
Transform data for a plugin in the workflow.
"""
Expand Down
1 change: 1 addition & 0 deletions dialogy/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class EntityKeys:
START = "start"
TO = "to"
TYPE = "type"
ENTITY_TYPE = "entity_type"
UNIT = "unit"
VALUE = "value"
VALUES = "values"
Expand Down
25 changes: 14 additions & 11 deletions dialogy/plugins/text/calibration/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from xgboost import XGBRegressor

from dialogy import constants as const
from dialogy.base.plugin import Plugin, PluginFn
from dialogy.types import Transcript, Utterance
from dialogy.base import Guard, Input, Output, Plugin
from dialogy.types import Transcript, Utterance, utterances
from dialogy.utils import normalize


Expand Down Expand Up @@ -92,18 +92,18 @@ class CalibrationModel(Plugin):

def __init__(
self,
access: Optional[PluginFn],
mutate: Optional[PluginFn],
threshold: float,
dest: Optional[str] = None,
guards: Optional[List[Guard]] = None,
debug: bool = False,
input_column: str = const.ALTERNATIVES,
output_column: Optional[str] = const.ALTERNATIVES,
use_transform: bool = False,
model_name: str = "calibration.pkl",
) -> None:
super().__init__(
access,
mutate,
dest=dest,
guards=guards,
debug=debug,
input_column=input_column,
output_column=output_column,
Expand Down Expand Up @@ -180,8 +180,9 @@ def transform(self, training_data: pd.DataFrame) -> pd.DataFrame:
)
return training_data_

def inference(self, utterances: List[Utterance]) -> List[str]:
transcripts: List[Transcript] = normalize(utterances)
def inference(
self, transcripts: List[str], utterances: List[Utterance]
) -> List[str]:
transcript_lengths: List[int] = [
len(transcript.split()) for transcript in transcripts
]
Expand All @@ -194,15 +195,17 @@ def inference(self, utterances: List[Utterance]) -> List[str]:
# a classifier's prediction to a fallback label.
# If the transcripts have less than WORD_THRESHOLD words, we will always predict the fallback label.
if average_word_count <= const.WORD_THRESHOLD:
return normalize(utterances)
return transcripts

return normalize(self.filter_asr_output(utterances))

def save(self, fname: str) -> None:
pickle.dump(self, open(fname, "wb"))

def utility(self, *args: Any) -> Any:
return self.inference(*args) # pylint: disable=no-value-for-parameter
def utility(self, input: Input, _: Output) -> Any:
return self.inference(
input.transcripts, input.utterances
) # pylint: disable=no-value-for-parameter

def validate(self, df: pd.DataFrame) -> bool:
"""
Expand Down
21 changes: 11 additions & 10 deletions dialogy/plugins/text/canonicalization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from tqdm import tqdm

import dialogy.constants as const
from dialogy.base.plugin import Plugin
from dialogy.types import BaseEntity, plugin
from dialogy.base import Guard, Input, Output, Plugin
from dialogy.types import BaseEntity
from dialogy.utils import normalize


def get_entity_type(entity: BaseEntity) -> str:
return f"<{entity.type}>"
return f"<{entity.entity_type}>"


class CanonicalizationPlugin(Plugin):
Expand All @@ -28,8 +28,8 @@ def __init__(
serializer: Callable[[BaseEntity], str] = get_entity_type,
mask: str = "MASK",
mask_tokens: Optional[List[str]] = None,
access: Optional[plugin.PluginFn] = None,
mutate: Optional[plugin.PluginFn] = None,
dest: Optional[str] = None,
guards: Optional[List[Guard]] = None,
entity_column: str = const.ENTITY_COLUMN,
input_column: str = const.ALTERNATIVES,
output_column: Optional[str] = None,
Expand All @@ -38,8 +38,8 @@ def __init__(
debug: bool = False,
) -> None:
super().__init__(
access,
mutate,
dest=dest,
guards=guards,
debug=debug,
use_transform=use_transform,
input_column=input_column,
Expand Down Expand Up @@ -77,8 +77,9 @@ def mask_transcript(

return canonicalized_transcripts

def utility(self, *args: Any) -> Any:
entities, transcripts = args
def utility(self, input: Input, output: Output) -> Any:
entities = output.entities
transcripts = input.transcripts
return self.mask_transcript(entities, transcripts)

def transform(self, training_data: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -89,7 +90,7 @@ def transform(self, training_data: pd.DataFrame) -> pd.DataFrame:

for i, row in tqdm(training_data.iterrows(), total=len(training_data)):
try:
canonicalized_transcripts = self.utility(
canonicalized_transcripts = self.mask_transcript(
row[self.entity_column],
normalize(json.loads(row[self.input_column])),
)
Expand Down
Loading

0 comments on commit cef7272

Please sign in to comment.