diff --git a/.github/workflows/tag_publish.yml b/.github/workflows/tag_publish.yml index 94214a04..e9329d6b 100644 --- a/.github/workflows/tag_publish.yml +++ b/.github/workflows/tag_publish.yml @@ -47,11 +47,11 @@ jobs: fail_ci_if_error: true verbose: true - name: Check if coverage less than 100% - if: env.covpercentage != '100' + if: env.covpercentage < '100' uses: actions/github-script@v3 with: script: | - core.setFailed('Code Coverage is not 100%') + core.setFailed('Code Coverage is less than 100%') - name: Build and Publish to PYPI env: TWINE_USERNAME: __token__ diff --git a/dialogy/base/__init__.py b/dialogy/base/__init__.py index e69de29b..01b54b2b 100644 --- a/dialogy/base/__init__.py +++ b/dialogy/base/__init__.py @@ -0,0 +1,4 @@ +from dialogy.base.entity_extractor import EntityScoringMixin +from dialogy.base.input import Input +from dialogy.base.output import Output +from dialogy.base.plugin import Guard, Plugin diff --git a/dialogy/base/input.py b/dialogy/base/input.py index 0d0c39bd..564ca103 100644 --- a/dialogy/base/input.py +++ b/dialogy/base/input.py @@ -1,35 +1,39 @@ -from typing import Optional, List +from __future__ import annotations + +from typing import Any, Dict, List, Optional import attr from dialogy.types import Utterance -from dialogy.utils import is_unix_ts, normalize, is_utterance +from dialogy.utils import is_unix_ts, normalize @attr.frozen class Input: utterances: List[Utterance] = attr.ib(kw_only=True) - reference_time: Optional[int] = attr.ib(kw_only=True) + reference_time: Optional[int] = attr.ib(default=None, kw_only=True) latent_entities: bool = attr.ib(default=False, kw_only=True, converter=bool) transcripts: List[str] = attr.ib(default=None) - clf_feature: Optional[str] = attr.ib( + clf_feature: Optional[List[str]] = attr.ib( # type: ignore kw_only=True, - default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)), + factory=list, + validator=attr.validators.optional(attr.validators.instance_of(list)), + ) + lang: str = attr.ib( + default="en", kw_only=True, validator=attr.validators.instance_of(str) ) - lang: str = attr.ib(default="en", kw_only=True, validator=attr.validators.instance_of(str)) locale: str = attr.ib( default="en_IN", kw_only=True, - validator=attr.validators.optional(attr.validators.instance_of(str)), + validator=attr.validators.optional(attr.validators.instance_of(str)), # type: ignore ) timezone: str = attr.ib( default="UTC", kw_only=True, - validator=attr.validators.optional(attr.validators.instance_of(str)), + validator=attr.validators.optional(attr.validators.instance_of(str)), # type: ignore ) - slot_tracker: Optional[list] = attr.ib( + slot_tracker: Optional[List[Dict[str, Any]]] = attr.ib( default=None, kw_only=True, validator=attr.validators.optional(attr.validators.instance_of(list)), @@ -45,15 +49,30 @@ class Input: validator=attr.validators.optional(attr.validators.instance_of(str)), ) - def __attrs_post_init__(self): - object.__setattr__(self, "transcript", normalize(self.utterances)) + def __attrs_post_init__(self) -> None: + try: + object.__setattr__(self, "transcripts", normalize(self.utterances)) + except TypeError: + ... 
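Reviewer note on the `Input` hunks above: `Input` is now a frozen attrs class, and `transcripts` is derived from `utterances` inside `__attrs_post_init__` (the `TypeError` guard tolerates partially-formed inputs, e.g. while evolving a reference). A small sketch of constructing one, assuming the defaults shown above; the payload is made up:

```python
from dialogy.base import Input

# Hypothetical ASR payload: one utterance with two alternatives.
input_ = Input(
    utterances=[
        [
            {"transcript": "hello world", "confidence": 0.9},
            {"transcript": "hello word", "confidence": 0.1},
        ]
    ],
    lang="en",
    locale="en_IN",
)

# normalize(...) flattens the alternatives into plain strings; note that
# later in this diff (normalize_utterance.py) they are no longer lower-cased.
assert input_.transcripts == ["hello world", "hello word"]
```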
- @reference_time.validator - def _check_reference_time(self, attribute: attr.Attribute, reference_time: int): + @reference_time.validator # type: ignore + def _check_reference_time( + self, attribute: attr.Attribute, reference_time: Optional[int] # type: ignore + ) -> None: + if reference_time is None: + return if not isinstance(reference_time, int): raise TypeError(f"{attribute.name} must be an integer.") if not is_unix_ts(reference_time): - raise ValueError(f"{attribute.name} must be a unix timestamp but got {reference_time}.") + raise ValueError( + f"{attribute.name} must be a unix timestamp but got {reference_time}." + ) - def json(self): + def json(self) -> Dict[str, Any]: return attr.asdict(self) + + @classmethod + def from_dict(cls, d: Dict[str, Any], reference: Optional[Input] = None) -> Input: + if reference: + return attr.evolve(reference, **d) + return attr.evolve(cls(utterances=d["utterances"]), **d) # type: ignore diff --git a/dialogy/base/output.py b/dialogy/base/output.py index 72933f4a..e45118c1 100644 --- a/dialogy/base/output.py +++ b/dialogy/base/output.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List +from typing import Any, Dict, List, Optional import attr @@ -20,5 +20,14 @@ class Output: kw_only=True, ) - def json(self: Output) -> dict: - return attr.asdict(self) + def json(self: Output) -> Dict[str, List[Dict[str, Any]]]: + return { + "intents": [intent.json() for intent in self.intents], + "entities": [entity.json() for entity in self.entities], + } + + @classmethod + def from_dict(cls, d: Dict[str, Any], reference: Optional[Output] = None) -> Output: + if reference: + return attr.evolve(reference, **d) + return attr.evolve(cls(), **d) diff --git a/dialogy/base/plugin.py b/dialogy/base/plugin.py index d8bc56f6..f004e30b 100644 --- a/dialogy/base/plugin.py +++ b/dialogy/base/plugin.py @@ -268,13 +268,22 @@ def train(self, training_data: pd.DataFrame) - If your plugin needs a model but it need not be trained frequently or uses an off the shelf pre-trained model, then you must build a Solitary plugin. - A trainable plugin can also have transform methods if it needs to modify a dataframe for other plugins in place. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Callable, List, Optional import dialogy.constants as const -from dialogy.types import PluginFn +from dialogy.base.input import Input +from dialogy.base.output import Output from dialogy.utils.logger import logger +if TYPE_CHECKING: # pragma: no cover + from dialogy.workflow import Workflow + + +Guard = Callable[[Input, Output], bool] + class Plugin(ABC): """ @@ -366,22 +375,22 @@ class Plugin(ABC): def __init__( self, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, input_column: str = const.ALTERNATIVES, output_column: Optional[str] = None, use_transform: bool = False, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, debug: bool = False, ) -> None: - self.access = access - self.mutate = mutate self.debug = debug + self.guards = guards + self.dest = dest self.input_column = input_column self.output_column = output_column or input_column self.use_transform = use_transform @abstractmethod - def utility(self, *args: Any) -> Any: + def utility(self, input: Input, output: Output) -> Any: """ Transform X -> y. 
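Reviewer note: the hunk above is the heart of this refactor. Plugins stop receiving `access`/`mutate` callables; they read a frozen `Input`, return a value that `Plugin.__call__` writes to the `dest` path via `workflow.set(...)`, and are skipped whenever any `Guard` fires. A minimal sketch of a plugin under the new contract; the class name and its logic are made up for illustration:

```python
from typing import Any, List, Optional

from dialogy.base import Guard, Input, Output, Plugin


class LowerCasePlugin(Plugin):
    """Hypothetical plugin: lower-cases transcripts as a classifier feature."""

    def __init__(
        self,
        dest: Optional[str] = None,
        guards: Optional[List[Guard]] = None,
        debug: bool = False,
    ) -> None:
        super().__init__(dest=dest, guards=guards, debug=debug)

    def utility(self, input: Input, output: Output) -> Any:
        # The return value is written back by Plugin.__call__ through
        # workflow.set(self.dest, value); returning None writes nothing.
        # clf_feature is the List[str] field the classifier plugins in
        # this diff (mlp.py, xlmr.py) consume.
        return [transcript.lower() for transcript in input.transcripts]


# Skipped entirely when any guard returns True, e.g. for non-English turns.
lower_case = LowerCasePlugin(
    dest="input.clf_feature",
    guards=[lambda i, o: i.lang != "en"],
)
```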
@@ -391,7 +400,7 @@ def utility(self, *args: Any) -> Any: :rtype: Any """ - def __call__(self, workflow: Any) -> None: + def __call__(self, workflow: Workflow) -> None: """ Abstraction for plugin io. @@ -400,27 +409,38 @@ def __call__(self, workflow: Any) -> None: :raises TypeError: If access method is missing, we can't get inputs for transformation. """ logger.enable(str(self)) if self.debug else logger.disable(str(self)) - if self.access: - args = self.access(workflow) - value = self.utility(*args) # pylint: disable=assignment-from-none - if value is not None and self.mutate: - self.mutate(workflow, value) - else: - raise TypeError( - "Expected access to be functions" f" but {type(self.access)} was found." - ) - - def train( - self, _: Any - ) -> Any: # pylint: disable=unused-argument disable=no-self-use + + if workflow.input is None: + return + + if workflow.output is None: + return + + if self.prevent(workflow.input, workflow.output): + return + + value = self.utility(workflow.input, workflow.output) + if value is not None and isinstance(self.dest, str): + workflow.set(self.dest, value) + + def prevent(self, input_: Input, output: Output) -> bool: + """ + Decide if the plugin should execute. + + :return: prevent plugin execution if True. + :rtype: bool + """ + if not self.guards: + return False + return any(guard(input_, output) for guard in self.guards) + + def train(self, _: Any) -> Any: """ Train a plugin. """ return None - def transform( - self, training_data: Any - ) -> Any: # pylint: disable=unused-argument disable=no-self-use + def transform(self, training_data: Any) -> Any: """ Transform data for a plugin in the workflow. """ diff --git a/dialogy/constants/__init__.py b/dialogy/constants/__init__.py index 1b961fe6..66dfd804 100644 --- a/dialogy/constants/__init__.py +++ b/dialogy/constants/__init__.py @@ -33,6 +33,7 @@ class EntityKeys: START = "start" TO = "to" TYPE = "type" + ENTITY_TYPE = "entity_type" UNIT = "unit" VALUE = "value" VALUES = "values" diff --git a/dialogy/plugins/text/calibration/xgb.py b/dialogy/plugins/text/calibration/xgb.py index a335fcf1..b2275507 100644 --- a/dialogy/plugins/text/calibration/xgb.py +++ b/dialogy/plugins/text/calibration/xgb.py @@ -20,8 +20,8 @@ from xgboost import XGBRegressor from dialogy import constants as const -from dialogy.base.plugin import Plugin, PluginFn -from dialogy.types import Transcript, Utterance +from dialogy.base import Guard, Input, Output, Plugin +from dialogy.types import Transcript, Utterance, utterances from dialogy.utils import normalize @@ -92,9 +92,9 @@ class CalibrationModel(Plugin): def __init__( self, - access: Optional[PluginFn], - mutate: Optional[PluginFn], threshold: float, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, debug: bool = False, input_column: str = const.ALTERNATIVES, output_column: Optional[str] = const.ALTERNATIVES, @@ -102,8 +102,8 @@ def __init__( model_name: str = "calibration.pkl", ) -> None: super().__init__( - access, - mutate, + dest=dest, + guards=guards, debug=debug, input_column=input_column, output_column=output_column, @@ -180,8 +180,9 @@ def transform(self, training_data: pd.DataFrame) -> pd.DataFrame: ) return training_data_ - def inference(self, utterances: List[Utterance]) -> List[str]: - transcripts: List[Transcript] = normalize(utterances) + def inference( + self, transcripts: List[str], utterances: List[Utterance] + ) -> List[str]: transcript_lengths: List[int] = [ len(transcript.split()) for transcript in transcripts ] @@ -194,15 +195,17 @@ def 
inference(self, utterances: List[Utterance]) -> List[str]: # a classifier's prediction to a fallback label. # If the transcripts have less than WORD_THRESHOLD words, we will always predict the fallback label. if average_word_count <= const.WORD_THRESHOLD: - return normalize(utterances) + return transcripts return normalize(self.filter_asr_output(utterances)) def save(self, fname: str) -> None: pickle.dump(self, open(fname, "wb")) - def utility(self, *args: Any) -> Any: - return self.inference(*args) # pylint: disable=no-value-for-parameter + def utility(self, input: Input, _: Output) -> Any: + return self.inference( + input.transcripts, input.utterances + ) # pylint: disable=no-value-for-parameter def validate(self, df: pd.DataFrame) -> bool: """ diff --git a/dialogy/plugins/text/canonicalization/__init__.py b/dialogy/plugins/text/canonicalization/__init__.py index 573b4e1b..edc0d9bb 100644 --- a/dialogy/plugins/text/canonicalization/__init__.py +++ b/dialogy/plugins/text/canonicalization/__init__.py @@ -9,13 +9,13 @@ from tqdm import tqdm import dialogy.constants as const -from dialogy.base.plugin import Plugin -from dialogy.types import BaseEntity, plugin +from dialogy.base import Guard, Input, Output, Plugin +from dialogy.types import BaseEntity from dialogy.utils import normalize def get_entity_type(entity: BaseEntity) -> str: - return f"<{entity.type}>" + return f"<{entity.entity_type}>" class CanonicalizationPlugin(Plugin): @@ -28,8 +28,8 @@ def __init__( serializer: Callable[[BaseEntity], str] = get_entity_type, mask: str = "MASK", mask_tokens: Optional[List[str]] = None, - access: Optional[plugin.PluginFn] = None, - mutate: Optional[plugin.PluginFn] = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, entity_column: str = const.ENTITY_COLUMN, input_column: str = const.ALTERNATIVES, output_column: Optional[str] = None, @@ -38,8 +38,8 @@ def __init__( debug: bool = False, ) -> None: super().__init__( - access, - mutate, + dest=dest, + guards=guards, debug=debug, use_transform=use_transform, input_column=input_column, @@ -77,8 +77,9 @@ def mask_transcript( return canonicalized_transcripts - def utility(self, *args: Any) -> Any: - entities, transcripts = args + def utility(self, input: Input, output: Output) -> Any: + entities = output.entities + transcripts = input.transcripts return self.mask_transcript(entities, transcripts) def transform(self, training_data: pd.DataFrame) -> pd.DataFrame: @@ -89,7 +90,7 @@ def transform(self, training_data: pd.DataFrame) -> pd.DataFrame: for i, row in tqdm(training_data.iterrows(), total=len(training_data)): try: - canonicalized_transcripts = self.utility( + canonicalized_transcripts = self.mask_transcript( row[self.entity_column], normalize(json.loads(row[self.input_column])), ) diff --git a/dialogy/plugins/text/classification/mlp.py b/dialogy/plugins/text/classification/mlp.py index a405d5d4..6b38feb3 100644 --- a/dialogy/plugins/text/classification/mlp.py +++ b/dialogy/plugins/text/classification/mlp.py @@ -21,7 +21,7 @@ tqdm.pandas() import dialogy.constants as const -from dialogy.base.plugin import Plugin, PluginFn +from dialogy.base import Guard, Input, Output, Plugin from dialogy.plugins.text.classification.tokenizers import identity_tokenizer from dialogy.types import Intent from dialogy.utils import load_file, logger, save_file @@ -35,8 +35,8 @@ class MLPMultiClass(Plugin): def __init__( self, model_dir: str, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, + dest: Optional[str] = None, + 
guards: Optional[List[Guard]] = None, debug: bool = False, threshold: float = 0.1, score_round_off: int = 5, @@ -49,7 +49,7 @@ def __init__( kwargs: Optional[Dict[str, Any]] = None, ) -> None: - super().__init__(access, mutate, debug=debug) + super().__init__(dest=dest, guards=guards, debug=debug) self.model_pipeline: Any = None self.fallback_label = fallback_label self.data_column = data_column @@ -178,7 +178,7 @@ def get_formatted_gridparams(params: List[Any]) -> List[Any]: def valid_mlpmodel(self) -> bool: return hasattr(self.model_pipeline, "classes_") - def inference(self, texts: List[str]) -> List[Intent]: + def inference(self, texts: Optional[List[str]]) -> List[Intent]: """ Predict the intent of a list of texts. @@ -312,5 +312,7 @@ def load(self) -> None: self.mlp_model_path, mode="rb", loader=joblib.load ) - def utility(self, *args: Any) -> Any: - return self.inference(*args) # pylint: disable=no-value-for-parameter + def utility(self, input: Input, _: Output) -> Any: + return self.inference( + input.clf_feature + ) # pylint: disable=no-value-for-parameter diff --git a/dialogy/plugins/text/classification/xlmr.py b/dialogy/plugins/text/classification/xlmr.py index 2c897518..5b85dc47 100644 --- a/dialogy/plugins/text/classification/xlmr.py +++ b/dialogy/plugins/text/classification/xlmr.py @@ -14,7 +14,7 @@ from sklearn import preprocessing import dialogy.constants as const -from dialogy.base.plugin import Plugin, PluginFn +from dialogy.base import Guard, Input, Output, Plugin from dialogy.types import Intent from dialogy.utils import load_file, logger, save_file @@ -27,8 +27,8 @@ class XLMRMultiClass(Plugin): def __init__( self, model_dir: str, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, debug: bool = False, threshold: float = 0.1, use_cuda: bool = False, @@ -52,7 +52,7 @@ def __init__( "Plugin requires simpletransformers -- https://simpletransformers.ai/docs/installation/" ) from error - super().__init__(access, mutate, debug=debug) + super().__init__(dest=dest, guards=guards, debug=debug) self.labelencoder = preprocessing.LabelEncoder() self.classifier = classifer self.model: Any = None @@ -135,7 +135,7 @@ def init_model(self, label_count: Optional[int] = None) -> None: def valid_labelencoder(self) -> bool: return hasattr(self.labelencoder, "classes_") - def inference(self, texts: List[str]) -> List[Intent]: + def inference(self, texts: Optional[List[str]]) -> List[Intent]: """ Predict the intent of a list of texts. @@ -147,6 +147,9 @@ def inference(self, texts: List[str]) -> List[Intent]: """ logger.debug(f"Classifier input:\n{texts}") fallback_output = Intent(name=self.fallback_label, score=1.0).add_parser(self) + if not texts: + logger.error(f"texts passed to model {texts}!") + return [fallback_output] if self.model is None: logger.error(f"No model found for plugin {self.__class__.__name__}!") @@ -165,17 +168,18 @@ def inference(self, texts: List[str]) -> List[Intent]: f"save the {self.__class__.__name__} plugin." 
) - if not texts: - return [fallback_output] - predictions, logits = self.model.predict(texts) if not predictions: return [fallback_output] - + confidence_scores = [np.exp(logit) / sum(np.exp(logit)) for logit in logits] intents_confidence_order = np.argsort(confidence_scores)[0][::-1] - predicted_intents = self.labelencoder.inverse_transform(intents_confidence_order) - ordered_confidence_scores = [confidence_scores[0][idx] for idx in intents_confidence_order] + predicted_intents = self.labelencoder.inverse_transform( + intents_confidence_order + ) + ordered_confidence_scores = [ + confidence_scores[0][idx] for idx in intents_confidence_order + ] return [ Intent(name=intent, score=round(score, self.round)).add_parser( @@ -268,5 +272,5 @@ def load(self) -> None: self.labelencoder_file_path, mode="rb", loader=pickle.load ) - def utility(self, *args: Any) -> Any: - return self.inference(*args) # pylint: disable=no-value-for-parameter + def utility(self, input: Input, _: Output) -> List[Intent]: + return self.inference(input.clf_feature) diff --git a/dialogy/plugins/text/combine_date_time/__init__.py b/dialogy/plugins/text/combine_date_time/__init__.py index bf21b2c9..6352ac78 100644 --- a/dialogy/plugins/text/combine_date_time/__init__.py +++ b/dialogy/plugins/text/combine_date_time/__init__.py @@ -1,20 +1,22 @@ -from datetime import datetime from typing import Any, Dict, List, Optional +import attr +from pydash import py_ + from dialogy import constants as const -from dialogy.base.plugin import Plugin, PluginFn +from dialogy.base import Guard, Input, Output, Plugin from dialogy.types import BaseEntity, TimeEntity def has_time_component(entity: BaseEntity) -> bool: - return entity.type in [ + return entity.entity_type in [ CombineDateTimeOverSlots.TIME, CombineDateTimeOverSlots.DATETIME, ] def is_date(entity: BaseEntity) -> bool: - return entity.type == CombineDateTimeOverSlots.DATE + return entity.entity_type == CombineDateTimeOverSlots.DATE class CombineDateTimeOverSlots(Plugin): @@ -74,8 +76,8 @@ class CombineDateTimeOverSlots(Plugin): def __init__( self, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, input_column: str = const.ALTERNATIVES, output_column: Optional[str] = None, use_transform: bool = False, @@ -83,8 +85,8 @@ def __init__( debug: bool = False, ) -> None: super().__init__( - access=access, - mutate=mutate, + dest=dest, + guards=guards, input_column=input_column, output_column=output_column, use_transform=use_transform, @@ -92,8 +94,9 @@ def __init__( ) self.trigger_intents = trigger_intents - def join(self, current_entity: TimeEntity, previous_entity: TimeEntity) -> None: - combined_value = None + def join( + self, current_entity: TimeEntity, previous_entity: TimeEntity + ) -> TimeEntity: current_turn_datetime = current_entity.get_value() previous_turn_datetime = previous_entity.get_value() @@ -109,61 +112,73 @@ def join(self, current_entity: TimeEntity, previous_entity: TimeEntity) -> None: month=previous_turn_datetime.month, day=previous_turn_datetime.day, ) + else: + return current_entity - if not combined_value: - return - - current_entity.value = combined_value.isoformat() - current_entity.set_value(current_entity.value) - - def utility(self, *args: Any) -> Any: - """ - Combine the date and time entities collected across turns into a single entity. 
- """ - tracker: List[Dict[str, Any]] - entities: List[BaseEntity] + return attr.evolve( + current_entity, **{const.EntityKeys.VALUE: combined_value.isoformat()} + ) - tracked_entity = None - tracked_intent = None + def get_tracked_slots( + self, slot_tracker: Optional[List[Dict[str, Any]]] + ) -> List[Dict[str, Any]]: + if not self.trigger_intents or not slot_tracker: + return [] - tracker, entities = args + tracked_intents = [ + intent + for intent in slot_tracker + if intent[const.NAME] in self.trigger_intents + ] - if not self.trigger_intents: - return + if not tracked_intents: + return [] - if not tracker: - return + tracked_intent, *_ = tracked_intents + return tracked_intent.get(const.SLOTS, []) - for entity in entities: - if entity.type not in CombineDateTimeOverSlots.SUPPORTED_ENTITIES: - continue + def pick_previously_filled_time_entity( + self, tracked_slots: Optional[List[Dict[str, Any]]] + ) -> Optional[TimeEntity]: + if not tracked_slots: + return None - tracked_intents = [ - intent - for intent in tracker - if intent[const.NAME] in self.trigger_intents - ] + filled_entities_json = tracked_slots[0][const.EntityKeys.VALUES] - if not tracked_intents: - continue + if not filled_entities_json or not isinstance(filled_entities_json, list): + return None - tracked_intent = tracked_intents[0] - tracked_slots = tracked_intent.get(const.SLOTS, []) + filled_entity_json, *_ = filled_entities_json + filled_entity_json[const.EntityKeys.VALUES] = [ + {const.VALUE: filled_entity_json[const.VALUE]} + ] - if not tracked_slots: - continue + return TimeEntity(**filled_entity_json) - tracked_entities_metadata = tracked_slots[0][const.EntityKeys.VALUES] + def combine_time_entities_from_slots( + self, slot_tracker: Optional[List[Dict[str, Any]]], entities: List[BaseEntity] + ) -> List[BaseEntity]: + previously_filled_time_entity = self.pick_previously_filled_time_entity( + self.get_tracked_slots(slot_tracker) + ) - if not tracked_entities_metadata: - continue + if not previously_filled_time_entity: + return entities - if not isinstance(tracked_entities_metadata, list): - continue + time_entities, other_entities = py_.partition( + entities, + lambda entity: entity.entity_type + in CombineDateTimeOverSlots.SUPPORTED_ENTITIES, + ) + combined_time_entities = [ + self.join(entity, previously_filled_time_entity) for entity in time_entities + ] + return combined_time_entities + other_entities - tracked_entity_metadata = tracked_entities_metadata[0] - tracked_entity_metadata[const.EntityKeys.VALUES] = [ - {const.VALUE: tracked_entity_metadata[const.VALUE]} - ] - tracked_entity = TimeEntity(**tracked_entity_metadata) - self.join(entity, tracked_entity) # type: ignore + def utility(self, input: Input, output: Output) -> List[BaseEntity]: + """ + Combine the date and time entities collected across turns into a single entity. 
+ """ + return self.combine_time_entities_from_slots( + input.slot_tracker, output.entities + ) diff --git a/dialogy/plugins/text/duckling_plugin/__init__.py b/dialogy/plugins/text/duckling_plugin/__init__.py index 1cf65107..d7843288 100644 --- a/dialogy/plugins/text/duckling_plugin/__init__.py +++ b/dialogy/plugins/text/duckling_plugin/__init__.py @@ -63,9 +63,10 @@ def update_workflow(workflow, entities): from tqdm import tqdm from dialogy import constants as const +from dialogy.base import Guard, Input, Output, Plugin from dialogy.base.entity_extractor import EntityScoringMixin -from dialogy.base.plugin import Plugin, PluginFn from dialogy.constants import EntityKeys +from dialogy.types import PluginFn from dialogy.types.entity import BaseEntity, dimension_entity_map from dialogy.utils import dt2timestamp, lang_detect_from_text, logger @@ -121,10 +122,10 @@ def __init__( timeout: float = 0.5, url: str = "http://0.0.0.0:8000/parse", locale: str = "en_IN", + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, datetime_filters: Optional[str] = None, threshold: Optional[float] = None, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, entity_map: Optional[Dict[str, Any]] = None, activate_latent_entities: Union[Callable[..., bool], bool] = False, reference_time_column: str = const.REFERENCE_TIME, @@ -137,8 +138,8 @@ def __init__( constructor """ super().__init__( - access=access, - mutate=mutate, + dest=dest, + guards=guards, debug=debug, input_column=input_column, output_column=output_column, @@ -196,7 +197,7 @@ def __create_req_body( text: str, reference_time: Optional[int] = None, locale: str = "en_IN", - use_latent: bool = False, + use_latent: Union[Callable[..., bool], bool] = False, ) -> Dict[str, Any]: """ create request body for entity parsing @@ -236,6 +237,16 @@ def __create_req_body( return payload + def get_operator(self, filter_type: Any) -> Any: + try: + return getattr(operator, filter_type) + except (AttributeError, TypeError) as exception: + logger.debug(traceback.format_exc()) + raise ValueError( + f"Expected datetime_filters to be one of {self.FUTURE}, {self.PAST} " + "or a valid comparison operator here: https://docs.python.org/3/library/operator.html" + ) from exception + def select_datetime( self, entities: List[BaseEntity], filter_type: Any ) -> List[BaseEntity]: @@ -258,14 +269,7 @@ def select_datetime( if filter_type in self.DATETIME_OPERATION_ALIAS: operation = self.DATETIME_OPERATION_ALIAS[filter_type] else: - try: - operation = getattr(operator, filter_type) - except (AttributeError, TypeError) as exception: - logger.debug(traceback.format_exc()) - raise ValueError( - f"Expected datetime_filters to be one of {self.FUTURE}, {self.PAST} " - "or a valid comparison operator here: https://docs.python.org/3/library/operator.html" - ) from exception + operation = self.get_operator(filter_type) time_entities, other_entities = py_.partition( entities, lambda entity: entity.dim == const.TIME @@ -375,7 +379,7 @@ def _get_entities( text: str, locale: str = "en_IN", reference_time: Optional[int] = None, - use_latent: bool = False, + use_latent: Union[Callable[..., bool], bool] = False, sort_idx: int = 0, ) -> Dict[str, Any]: """ @@ -427,7 +431,7 @@ def _get_entities_concurrent( texts: List[str], locale: str = "en_IN", reference_time: Optional[int] = None, - use_latent: bool = False, + use_latent: Union[Callable[..., bool], bool] = False, ) -> List[List[Dict[str, Any]]]: """ Make multiple-parallel API calls to duckling-server . 
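Reviewer note: two things stand out in these duckling hunks. `get_operator` now maps `datetime_filters` onto the stdlib `operator` module (after checking the `FUTURE`/`PAST` aliases, which the error message suggests are class-level constants), and `use_latent` is widened to accept a callable. A hedged sketch of wiring the plugin under the new constructor; `dimensions` and `timezone` are assumed from the full signature, which this hunk truncates:

```python
from dialogy.plugins import DucklingPlugin

duckling = DucklingPlugin(
    dest="output.entities",          # utility(input, _) writes entities here
    dimensions=["time", "number"],   # assumed: not visible in this hunk
    timezone="Asia/Kolkata",         # assumed: not visible in this hunk
    url="http://0.0.0.0:8000/parse",
    locale="en_IN",
    timeout=0.5,
    datetime_filters=DucklingPlugin.FUTURE,  # or an operator name such as "ge"
)
```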
@@ -494,31 +498,34 @@ def validate( def extract( self, - input_: Union[str, List[str]], + transcripts: Union[str, List[str]], locale: str, reference_time: Optional[int] = None, - use_latent: bool = False, + use_latent: Union[Callable[..., bool], bool] = False, ) -> List[BaseEntity]: list_of_entities: List[List[Dict[str, Any]]] = [] entities: List[BaseEntity] = [] - self.validate(input_, reference_time) + self.validate(transcripts, reference_time) self.reference_time = reference_time - if isinstance(input_, str): - input_ = [input_] + if isinstance(transcripts, str): + transcripts = [transcripts] # pragma: no cover try: list_of_entities = self._get_entities_concurrent( - input_, locale, reference_time=reference_time, use_latent=use_latent + transcripts, + locale, + reference_time=reference_time, + use_latent=use_latent, ) entities = self.apply_entity_classes(list_of_entities) - entities = self.entity_consensus(entities, len(input_)) + entities = self.entity_consensus(entities, len(transcripts)) return self.apply_filters(entities) except ValueError as value_error: raise ValueError(str(value_error)) from value_error - def utility(self, *args: Any) -> List[BaseEntity]: + def utility(self, input: Input, _: Output) -> List[BaseEntity]: """ Produces Duckling entities, runs with a :ref:`Workflow's run` method. @@ -527,9 +534,12 @@ def utility(self, *args: Any) -> List[BaseEntity]: :return: A list of duckling entities. :rtype: List[BaseEntity] """ - input_, reference_time, locale, use_latent = args + transcripts = input.transcripts + reference_time = input.reference_time + locale = input.locale + use_latent = input.latent_entities return self.extract( - input_, locale, reference_time=reference_time, use_latent=use_latent + transcripts, locale, reference_time=reference_time, use_latent=use_latent ) def transform(self, training_data: pd.DataFrame) -> pd.DataFrame: @@ -563,11 +573,11 @@ def transform(self, training_data: pd.DataFrame) -> pd.DataFrame: f"{reference_time=} should be isoformat date or unix timestamp integer." 
) transcripts = self.make_transform_values(row[self.input_column]) - entities = self.utility( + entities = self.extract( transcripts, - reference_time, lang_detect_from_text(self.input_column), - self.activate_latent_entities, + reference_time=reference_time, + use_latent=self.activate_latent_entities, ) if row[self.output_column] is None or pd.isnull(row[self.output_column]): training_data.at[i, self.output_column] = entities diff --git a/dialogy/plugins/text/lb_plugin/__init__.py b/dialogy/plugins/text/lb_plugin/__init__.py index a14473a7..4a0c62e6 100644 --- a/dialogy/plugins/text/lb_plugin/__init__.py +++ b/dialogy/plugins/text/lb_plugin/__init__.py @@ -3,13 +3,12 @@ from pydash import partition from dialogy import constants as const -from dialogy.base.plugin import PluginFn +from dialogy.base import Guard, Input, Output from dialogy.plugins import DucklingPlugin -from dialogy.types.entity import BaseEntity +from dialogy.types import BaseEntity class DucklingPluginLB(DucklingPlugin): - # Constructor def __init__( self, @@ -19,8 +18,8 @@ def __init__( url: str = "http://0.0.0.0:8000/parse", locale: str = "en_IN", datetime_filters: Optional[str] = None, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, entity_map: Optional[Dict[str, Any]] = None, reference_time_column: str = const.REFERENCE_TIME, input_column: str = const.ALTERNATIVES, @@ -36,8 +35,8 @@ def __init__( locale=locale, datetime_filters=datetime_filters, threshold=0, - access=access, - mutate=mutate, + dest=dest, + guards=guards, entity_map=entity_map, reference_time_column=reference_time_column, input_column=input_column, @@ -46,10 +45,10 @@ def __init__( debug=debug, ) - def utility(self, *args: Any) -> List[BaseEntity]: - entity_list = super().utility(*args) + def utility(self, input_: Input, output: Output) -> List[BaseEntity]: + entity_list = super().utility(input_, output) datetime_list, other_list = partition( - entity_list, lambda x: x.type in ["datetime", "date", "time"] + entity_list, lambda x: x.entity_type in ["datetime", "date", "time"] ) if datetime_list: other_list.append(min(datetime_list, key=lambda x: x.alternative_index)) diff --git a/dialogy/plugins/text/list_entity_plugin/__init__.py b/dialogy/plugins/text/list_entity_plugin/__init__.py index b1fb0510..5b51cad2 100644 --- a/dialogy/plugins/text/list_entity_plugin/__init__.py +++ b/dialogy/plugins/text/list_entity_plugin/__init__.py @@ -17,8 +17,8 @@ from tqdm import tqdm from dialogy import constants as const +from dialogy.base import Guard, Input, Output, Plugin from dialogy.base.entity_extractor import EntityScoringMixin -from dialogy.base.plugin import Plugin, PluginFn from dialogy.types import BaseEntity, KeywordEntity Text = str @@ -61,10 +61,10 @@ def __init__( style: Optional[str] = None, candidates: Optional[Dict[str, Dict[str, List[Any]]]] = None, spacy_nlp: Any = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, labels: Optional[List[str]] = None, threshold: Optional[float] = None, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, input_column: str = const.ALTERNATIVES, output_column: Optional[str] = None, use_transform: bool = True, @@ -72,8 +72,8 @@ def __init__( debug: bool = False, ): super().__init__( - access=access, - mutate=mutate, + dest=dest, + guards=guards, debug=debug, input_column=input_column, output_column=output_column, @@ -214,8 +214,9 @@ def get_entities(self, transcripts: List[str]) 
-> List[BaseEntity]: aggregated_entities = self.entity_consensus(entities, len(transcripts)) return self.apply_filters(aggregated_entities) - def utility(self, *args: Any) -> Any: - return self.get_entities(*args) # pylint: disable=no-value-for-parameter + def utility(self, input: Input, _: Output) -> Any: + transcripts = input.transcripts + return self.get_entities(transcripts) # pylint: disable=no-value-for-parameter def ner_search(self, transcript: str) -> MatchType: """ @@ -308,7 +309,7 @@ def transform(self, training_data: pd.DataFrame) -> pd.DataFrame: logger.disable("dialogy") for i, row in tqdm(training_data.iterrows(), total=len(training_data)): transcripts = self.make_transform_values(row[self.input_column]) - entities = self.utility(transcripts) + entities = self.get_entities(transcripts) is_empty_series = isinstance(row[self.output_column], pd.Series) and ( row[self.output_column].isnull() ) diff --git a/dialogy/plugins/text/list_search_plugin/__init__.py b/dialogy/plugins/text/list_search_plugin/__init__.py index e4912d29..b4df47bc 100644 --- a/dialogy/plugins/text/list_search_plugin/__init__.py +++ b/dialogy/plugins/text/list_search_plugin/__init__.py @@ -9,18 +9,15 @@ all other entities. So that their :code:`from_dict(...)` methods are pristine and involve no shape hacking. """ import re -from pprint import pformat from typing import Any, Dict, List, Optional, Tuple -import pandas as pd import stanza from loguru import logger from thefuzz import fuzz -from tqdm import tqdm from dialogy import constants as const +from dialogy.base import Guard, Input, Output, Plugin from dialogy.base.entity_extractor import EntityScoringMixin -from dialogy.base.plugin import Plugin, PluginFn from dialogy.types import BaseEntity, KeywordEntity Text = str @@ -63,8 +60,8 @@ def __init__( self, fuzzy_dp_config: Dict[Any, Any], # parsed yaml file threshold: Optional[float] = None, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, input_column: str = const.ALTERNATIVES, output_column: Optional[str] = None, use_transform: bool = True, @@ -73,8 +70,8 @@ def __init__( fuzzy_threshold: Optional[float] = 0.1, ): super().__init__( - access=access, - mutate=mutate, + dest=dest, + guards=guards, debug=debug, input_column=input_column, output_column=output_column, @@ -124,48 +121,6 @@ def fuzzy_init(self) -> None: lang=lang_code, tokenize_pretokenized=True ) - ''' - def _parse(self, candidates: Optional[Dict[str, Dict[str, List[Any]]]]) -> None: - """ - Pre compile regex patterns to speed up runtime evaluation. - - This method's search will still be slow depending on the list of patterns. - - :param candidates: A map for entity types and their pattern list. - :type candidates: Optional[Dict[str, List[str]]] - :return: None - :rtype: NoneType - """ - logger.debug( - pformat( - { - "style": self.style, - "candidates": candidates, - } - ) - ) - if not isinstance(candidates, dict): - raise TypeError( - 'Expected "candidates" to be a Dict[str, List[str]]' - f" but {type(candidates)} was found." - ) - - if not candidates: - raise ValueError( - 'Expected "candidates" to be a Dict[str, List[str]]' - f" but {candidates} was found." - ) - - if self.style not in self.__style_search_map: - raise ValueError( - f"Expected style to be one of {list(self.__style_search_map.keys())}" - f' but "{self.style}" was found.' 
- ) - - logger.debug("compiled patterns") - logger.debug(self.compiled_patterns) - ''' - def _search(self, transcripts: List[str], lang: str) -> List[MatchType]: """ Search for tokens in a list of strings. @@ -263,7 +218,6 @@ def get_fuzzy_dp_search(self, transcript: str, lang: str = "") -> MatchType: """ match = [] query = transcript - # regex variables entity_patterns = {} entity_match_dict = {} @@ -291,8 +245,6 @@ def get_fuzzy_dp_search(self, transcript: str, lang: str = "") -> MatchType: return match - # return [(value, self.entity_type, "", (0, 0), float(0))] - def get_entities(self, transcripts: List[str], lang: str) -> List[BaseEntity]: """ Parse entities using regex and spacy ner. @@ -338,5 +290,7 @@ def get_entities(self, transcripts: List[str], lang: str) -> List[BaseEntity]: aggregated_entities = self.entity_consensus(entities, len(transcripts)) return self.apply_filters(aggregated_entities) - def utility(self, *args: Any) -> Any: - return self.get_entities(*args) # pylint: disable=no-value-for-parameter + def utility(self, input_: Input, _: Output) -> Any: + return self.get_entities( + input_.transcripts, input_.lang + ) # pylint: disable=no-value-for-parameter diff --git a/dialogy/plugins/text/merge_asr_output/__init__.py b/dialogy/plugins/text/merge_asr_output/__init__.py index d59e204b..fd0e3389 100644 --- a/dialogy/plugins/text/merge_asr_output/__init__.py +++ b/dialogy/plugins/text/merge_asr_output/__init__.py @@ -26,7 +26,7 @@ from tqdm import tqdm import dialogy.constants as const -from dialogy.base.plugin import Plugin, PluginFn +from dialogy.base import Guard, Input, Output, Plugin from dialogy.utils import normalize @@ -93,24 +93,24 @@ class MergeASROutputPlugin(Plugin): def __init__( self, - access: Optional[PluginFn], - mutate: Optional[PluginFn], input_column: str = const.ALTERNATIVES, output_column: Optional[str] = None, use_transform: bool = False, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, debug: bool = False, ) -> None: super().__init__( - access=access, - mutate=mutate, - debug=debug, + dest=dest, + guards=guards, input_column=input_column, output_column=output_column, use_transform=use_transform, + debug=debug, ) - def utility(self, *args: Any) -> Any: - return merge_asr_output(*args) + def utility(self, input: Input, _: Output) -> Any: + return merge_asr_output(input.utterances) def transform(self, training_data: pd.DataFrame) -> pd.DataFrame: if not self.use_transform: diff --git a/dialogy/plugins/text/slot_filler/rule_slot_filler.py b/dialogy/plugins/text/slot_filler/rule_slot_filler.py index 19874a05..70c93039 100644 --- a/dialogy/plugins/text/slot_filler/rule_slot_filler.py +++ b/dialogy/plugins/text/slot_filler/rule_slot_filler.py @@ -5,10 +5,9 @@ """ from typing import Any, List, Optional -from dialogy.base.plugin import Plugin +from dialogy.base import Guard, Input, Output, Plugin from dialogy.types.entity import BaseEntity from dialogy.types.intent import Intent -from dialogy.types.plugin import PluginFn from dialogy.types.slots import Rule from dialogy.utils.logger import logger @@ -108,9 +107,9 @@ class RuleBasedSlotFillerPlugin(Plugin): def __init__( self, rules: Rule, - fill_multiple: bool = False, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, + fill_multiple: bool = True, debug: bool = False, ) -> None: """ @@ -125,7 +124,7 @@ def __init__( # ``` # rules = {"intent": {"slot_name": "entity_type"}} # ``` - 
super().__init__(access=access, mutate=mutate, debug=debug) + super().__init__(dest=dest, guards=guards, debug=debug) self.rules: Rule = rules or {} # fill_multiple @@ -133,21 +132,19 @@ def __init__( # same entity type within a slot. self.fill_multiple = fill_multiple - def fill(self, intents: List[Intent], entities: List[BaseEntity]) -> None: - if not intents: - return + def fill(self, intents: List[Intent], entities: List[BaseEntity]) -> List[Intent]: + if not isinstance(intents, list) or not intents: + return intents - if not (isinstance(intents, list) and isinstance(intents[0], Intent)): - return + intent, *rest = intents - intent = intents[0] intent.apply(self.rules) for entity in entities: intent.fill_slot(entity, fill_multiple=self.fill_multiple) intent.cleanup() - logger.debug(f"intent after slot-filling: {intent}") + return [intent, *rest] - def utility(self, *args: Any) -> Any: - return self.fill(*args) # pylint: disable=no-value-for-parameter + def utility(self, _: Input, output: Output) -> List[Intent]: + return self.fill(output.intents, output.entities) diff --git a/dialogy/plugins/text/voting/intent_voting.py b/dialogy/plugins/text/voting/intent_voting.py index b7c7b5b2..88a2c56a 100644 --- a/dialogy/plugins/text/voting/intent_voting.py +++ b/dialogy/plugins/text/voting/intent_voting.py @@ -8,10 +8,9 @@ import pydash as py_ from dialogy import constants as const -from dialogy.base.plugin import Plugin +from dialogy.base import Guard, Input, Output, Plugin from dialogy.types import Signal from dialogy.types.intent import Intent -from dialogy.types.plugin import PluginFn from dialogy.utils.logger import logger @@ -169,10 +168,10 @@ def update_intent(w, intent): def __init__( self, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, threshold: float = 0.6, consensus: float = 0.2, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, representation: float = 0.3, fallback_intent: str = const.S_INTENT_OOS, aggregate_fn: Any = np.mean, @@ -181,14 +180,14 @@ def __init__( """ constructor """ - super().__init__(access=access, mutate=mutate, debug=debug) + super().__init__(dest=dest, guards=guards, debug=debug) self.threshold: float = threshold self.consensus: float = consensus self.representation: float = representation self.fallback_intent: str = fallback_intent self.aggregate_fn: Any = aggregate_fn - def vote_signal(self, intents: List[Intent], trials: int) -> Intent: + def vote_signal(self, intents: List[Intent], trials: int) -> List[Intent]: """ Reduce a list of intents. @@ -204,7 +203,7 @@ def vote_signal(self, intents: List[Intent], trials: int) -> Intent: :return: Voted signal or fallback in case of no consensus. 
:rtype: Intent """ - fallback = Intent(name=self.fallback_intent, score=1) + fallback = [Intent(name=self.fallback_intent, score=1)] if not intents: return fallback @@ -235,11 +234,13 @@ def vote_signal(self, intents: List[Intent], trials: int) -> Intent: logger.debug(f"strong signal: {strong_signal}") if (consensus_achieved or representative_signal) and strong_signal: - return Intent( - name=main_intent[const.SIGNAL.NAME], # type: ignore - score=main_intent[const.SIGNAL.STRENGTH], # type: ignore - ) - return Intent(name=self.fallback_intent, score=1) - - def utility(self, *args: Any) -> Any: - return self.vote_signal(*args) # pylint: disable=no-value-for-parameter + return [ + Intent( + name=main_intent[const.SIGNAL.NAME], # type: ignore + score=main_intent[const.SIGNAL.STRENGTH], # type: ignore + ) + ] + return [Intent(name=self.fallback_intent, score=1)] + + def utility(self, input_: Input, output: Output) -> Any: + return self.vote_signal(output.intents, len(input_.transcripts)) diff --git a/dialogy/types/__init__.py b/dialogy/types/__init__.py index 97c20117..9dbcf421 100644 --- a/dialogy/types/__init__.py +++ b/dialogy/types/__init__.py @@ -38,4 +38,5 @@ from dialogy.types.intent import Intent from dialogy.types.plugin import PluginFn from dialogy.types.signal.signal import Signal +from dialogy.types.slots import Slot from dialogy.types.utterances import Alternative, Transcript, Utterance diff --git a/dialogy/types/entity/base_entity.py b/dialogy/types/entity/base_entity.py index c9da0d48..61537206 100644 --- a/dialogy/types/entity/base_entity.py +++ b/dialogy/types/entity/base_entity.py @@ -10,6 +10,8 @@ - BaseEntity """ +from __future__ import annotations + import copy from typing import Any, Dict, List, Optional, Union @@ -32,22 +34,22 @@ class BaseEntity: # **range** # # is the character range in the alternative where the entity is parsed. - range = attr.ib(type=Dict[str, int], repr=False) - - # **type** - # - # is same as dimension or `dim` for now. We may deprecate `dim` and retain only `type`. - type = attr.ib(type=str, validator=attr.validators.instance_of(str)) + range = attr.ib(type=Dict[str, int], repr=False, order=False, kw_only=True) # **body** # # is the string from which the entity is extracted. - body = attr.ib(type=str, validator=attr.validators.instance_of(str)) + body = attr.ib(type=str, validator=attr.validators.instance_of(str), order=False) # **dim** # # is influenced from Duckling's convention of categorization. - dim = attr.ib(type=Optional[str], default=None, repr=False) + dim = attr.ib(type=Optional[str], default=None, repr=False, order=False) + + # **type** + # + # is same as dimension or `dim` for now. We may deprecate `dim` and retain only `type`. + type = attr.ib(type=str, default="value", order=False, kw_only=True) # **parsers** # @@ -56,6 +58,7 @@ class BaseEntity: type=List[str], default=attr.Factory(list), validator=attr.validators.instance_of(list), + order=False, ) # **score** @@ -66,19 +69,21 @@ class BaseEntity: # **slot_names** # # Entities have awareness of the slots they should fill. - slot_names = attr.ib(type=List[str], default=attr.Factory(list), repr=False) + slot_names = attr.ib( + type=List[str], default=attr.Factory(list), repr=False, order=False + ) # **alternative_index** # # is the index of transcript within the ASR output: `List[Utterances]` # from which this entity was picked up. This may be None. 
- alternative_index = attr.ib(type=Optional[int], default=None) - alternative_indices = attr.ib(type=Optional[List[int]], default=None) + alternative_index = attr.ib(type=Optional[int], default=None, order=False) + alternative_indices = attr.ib(type=Optional[List[int]], default=None, order=False) # **latent** # # Duckling influenced attribute, tells if there is less evidence for an entity if latent is True. - latent = attr.ib(type=bool, default=False) + latent = attr.ib(type=bool, default=False, order=False) # **values** # @@ -88,23 +93,21 @@ class BaseEntity: default=attr.Factory(list), validator=attr.validators.instance_of(List), repr=False, + order=False, ) # **values** # # A single value interpretation from values. - value: Any = attr.ib(default=None) + value: Any = attr.ib(default=None, order=False) # **entity_type** # # Mirrors type, to be deprecated. - entity_type: Optional[str] = attr.ib(default=None, repr=False) + entity_type: Optional[str] = attr.ib(default=None, repr=False, order=False) __properties_map = const.BASE_ENTITY_PROPS - def __attrs_post_init__(self) -> None: - self.entity_type = self.type - @classmethod def validate(cls, dict_: Dict[str, Any]) -> None: """ @@ -133,7 +136,7 @@ def reshape(cls, dict_: Dict[str, Any]) -> Dict[str, Any]: # ['body', 'start', 'value', 'end', 'dim', 'latent'] # **type** of an entity is same as its **dimension**. - dict_[const.EntityKeys.TYPE] = dict_[const.EntityKeys.DIM] + dict_[const.EntityKeys.ENTITY_TYPE] = dict_[const.EntityKeys.DIM] # This piece is a preparation for multiple entity values. # So, even though we are confident of the value found, we are still keeping the @@ -151,7 +154,7 @@ def reshape(cls, dict_: Dict[str, Any]) -> Dict[str, Any]: return dict_ @classmethod - def from_dict(cls, dict_: Dict[str, Any]) -> "BaseEntity": + def from_dict(cls, dict_: Dict[str, Any]) -> BaseEntity: """ Create an instance of a given class `cls` from a `dict` that complies with attributes of `cls` through its keys and values. @@ -208,7 +211,7 @@ def get_value(self, reference: Any = None) -> Any: else: return reference.get(const.VALUE) - def copy(self) -> "BaseEntity": + def copy(self) -> BaseEntity: """ Create a deep copy of the instance and return. diff --git a/dialogy/types/entity/duration_entity.py b/dialogy/types/entity/duration_entity.py index 031f96f7..f0693613 100644 --- a/dialogy/types/entity/duration_entity.py +++ b/dialogy/types/entity/duration_entity.py @@ -44,7 +44,7 @@ def reshape(cls, dict_: Dict[str, Any]) -> Dict[str, Any]: } # ['body', 'start', 'value', 'end', 'dim', 'latent'] - dict_[const.EntityKeys.TYPE] = dict_[const.EntityKeys.DIM] + dict_[const.EntityKeys.ENTITY_TYPE] = dict_[const.EntityKeys.DIM] # This piece is a preparation for multiple entity values. 
# So, even though we are confident of the value found, we are still keeping the diff --git a/dialogy/types/entity/location_entity.py b/dialogy/types/entity/location_entity.py index 339c428c..69e74dfc 100644 --- a/dialogy/types/entity/location_entity.py +++ b/dialogy/types/entity/location_entity.py @@ -5,7 +5,7 @@ Import classes: - LocationEntity """ -from typing import Dict +from typing import Dict, Optional import attr @@ -20,3 +20,4 @@ class LocationEntity(BaseEntity): """ _meta = attr.ib(type=Dict[str, str], default=attr.Factory(Dict)) + entity_type: Optional[str] = attr.ib(default="location", repr=False, order=False) diff --git a/dialogy/types/entity/numerical_entity.py b/dialogy/types/entity/numerical_entity.py index d703d2dc..3a03017a 100644 --- a/dialogy/types/entity/numerical_entity.py +++ b/dialogy/types/entity/numerical_entity.py @@ -5,6 +5,8 @@ Import classes: - NumericalEntity """ +from typing import Optional + import attr from dialogy import constants as const @@ -30,4 +32,5 @@ class NumericalEntity(BaseEntity): origin = attr.ib( type=str, default="value", validator=attr.validators.instance_of(str) ) + entity_type: Optional[str] = attr.ib(default="number", repr=False, order=False) __properties_map = const.BASE_ENTITY_PROPS diff --git a/dialogy/types/entity/people_entity.py b/dialogy/types/entity/people_entity.py index 181625cd..4e5d2e3a 100644 --- a/dialogy/types/entity/people_entity.py +++ b/dialogy/types/entity/people_entity.py @@ -5,6 +5,8 @@ Import classes: - PeopleEntity """ +from typing import Optional + import attr from dialogy import constants @@ -23,4 +25,5 @@ class PeopleEntity(NumericalEntity): """ unit = attr.ib(type=str, default="", validator=attr.validators.instance_of(str)) + entity_type: Optional[str] = attr.ib(default="people", repr=False, order=False) __properties_map = constants.PEOPLE_ENTITY_PROPS diff --git a/dialogy/types/entity/time_entity.py b/dialogy/types/entity/time_entity.py index 539efeea..bea25c72 100644 --- a/dialogy/types/entity/time_entity.py +++ b/dialogy/types/entity/time_entity.py @@ -198,8 +198,6 @@ def __attrs_post_init__(self) -> None: self.post_init() def post_init(self) -> None: - grain_: Optional[str] = None if isinstance(self.values, list) and self.values: self.grain = self.values[0].get("grain") or self.grain self.entity_type = self.set_entity_type() - self.type = self.entity_type diff --git a/dialogy/types/entity/time_interval_entity.py b/dialogy/types/entity/time_interval_entity.py index e1d98db5..ca2aaba8 100644 --- a/dialogy/types/entity/time_interval_entity.py +++ b/dialogy/types/entity/time_interval_entity.py @@ -29,6 +29,7 @@ class TimeIntervalEntity(TimeEntity): origin = "interval" dim = "time" + type = attr.ib(type=str, default="value", order=False) __properties_map = const.TIME_ENTITY_PROPS @classmethod diff --git a/dialogy/types/intent/__init__.py b/dialogy/types/intent/__init__.py index e4f13cff..70a14793 100644 --- a/dialogy/types/intent/__init__.py +++ b/dialogy/types/intent/__init__.py @@ -6,57 +6,42 @@ - Intent """ +from __future__ import annotations from typing import Any, Dict, List, Optional +import attr + from dialogy.types.entity import BaseEntity from dialogy.types.slots import Rule, Slot from dialogy.utils.logger import logger +@attr.s class Intent: """ An instance of this class contains the name of the action associated with a body of text. 
""" - def __init__( - self, - name: str, - score: float, - parsers: Optional[List[str]] = None, - alternative_index: Optional[int] = 0, - slots: Optional[Dict[str, Slot]] = None, - ) -> None: - """ - - :param name: The name of the intent. A class label - :type name: str - :param score: The confidence score from a model. - :type score: float - :param parsers: Items that alter the attributes. - :type parsers: Optional[List[str]] - :param alternative_index: Out of a list of transcripts, this points at the index which lead to prediction. - :type alternative_index: Optional[int] - :param slots: A map of slot names and :ref:`Slot` - :type slots: Optional[Dict[str, Slot]] - """ - # The name of the intent to be used. - self.name = name + # The name of the intent to be used. + name: str = attr.ib(kw_only=True, order=False) - # The confidence of this intent being present in the utterance. - self.score = score + # The confidence of this intent being present in the utterance. + score: float = attr.ib(kw_only=True, default=0.0, order=True) - # Trail of functions that modify the attributes of an instance. - self.parsers = parsers or [] + # In case of an ASR, `alternative_index` points at one of the nth + # alternatives that help in predictions. + alternative_index: Optional[int] = attr.ib(kw_only=True, default=None, order=False) - # In case of an ASR, `alternative_index` points at one of the nth - # alternatives that help in predictions. - self.alternative_index = alternative_index + # Trail of functions that modify the attributes of an instance. + parsers: List[str] = attr.ib(kw_only=True, factory=list, order=False, repr=False) - # Container for holding `List[BaseEntity]`. - self.slots: Dict[str, Slot] = slots or {} + # Container for holding `List[BaseEntity]`. + slots: Dict[str, Slot] = attr.ib( + kw_only=True, factory=dict, order=False, repr=False + ) - def apply(self, rules: Rule) -> "Intent": + def apply(self, rules: Rule) -> Intent: """ Create slots using :ref:`rules`. @@ -91,7 +76,7 @@ def apply(self, rules: Rule) -> "Intent": return self - def add_parser(self, plugin: Any) -> "Intent": + def add_parser(self, plugin: Any) -> Intent: """ Update parsers with the plugin name @@ -107,7 +92,7 @@ def add_parser(self, plugin: Any) -> "Intent": self.parsers.append(plugin_name) return self - def fill_slot(self, entity: BaseEntity, fill_multiple: bool = False) -> "Intent": + def fill_slot(self, entity: BaseEntity, fill_multiple: bool = False) -> Intent: """ Update :code:`slots[slot_type].values` with a single entity. 
@@ -134,10 +119,11 @@ def fill_slot(self, entity: BaseEntity, fill_multiple: bool = False) -> "Intent" f"slot type: {slot.types}", ) logger.debug( - f"entity type: {entity.type}", + f"entity type: {entity.entity_type}", ) - if entity.type in slot.types: + if entity.entity_type in slot.types: if fill_multiple: + logger.debug(f"filling {entity} into {self.name}.") self.slots[slot_name].add(entity) return self @@ -176,11 +162,8 @@ def json(self) -> Dict[str, Any]: """ return { "name": self.name, - "score": self.score, "alternative_index": self.alternative_index, - "slots": [slot.json() for slot in self.slots.values()], + "score": self.score, "parsers": self.parsers, + "slots": {slot_name: slot.json() for slot_name, slot in self.slots.items()}, } - - def __repr__(self) -> str: - return f"Intent(name={self.name}, score={self.score}, slots={self.slots})" diff --git a/dialogy/types/slots/__init__.py b/dialogy/types/slots/__init__.py index 15d43790..74828a8e 100644 --- a/dialogy/types/slots/__init__.py +++ b/dialogy/types/slots/__init__.py @@ -7,13 +7,17 @@ - Slot """ +from __future__ import annotations from typing import Any, Dict, List +import attr + import dialogy.constants as const from dialogy.types.entity import BaseEntity +@attr.s class Slot: """Slot Type @@ -23,19 +27,18 @@ class Slot: - `values` list of entities extracted """ - def __init__(self, name: str, types: List[str], values: List[BaseEntity]) -> None: - self.name = name - self.types = types - self.values = values + name: str = attr.ib(kw_only=True, order=True) + types: List[str] = attr.ib(kw_only=True, factory=list, order=False) + values: List[BaseEntity] = attr.ib(kw_only=True, factory=list, order=False) - def add(self, entity: BaseEntity) -> "Slot": + def add(self, entity: BaseEntity) -> Slot: """ Insert the `BaseEntity` within the current `Slot` instance. """ self.values.append(entity) return self - def clear(self) -> "Slot": + def clear(self) -> Slot: """ Remove all `BaseEntity` within the current `Slot` instance. 
""" @@ -49,16 +52,11 @@ def json(self) -> Dict[str, Any]: Returns: Dict[str, Any] """ - entities_json = [entity.json() for entity in self.values] - slot_json = { + return { "name": self.name, - "type": self.types, - const.EntityKeys.VALUES: entities_json, + "types": self.types, + "values": [entity.json() for entity in self.values], } - return slot_json - - def __repr__(self) -> str: - return f"Slot(name={self.name}, types={self.types}, values={self.values})" Rule = Dict[str, Dict[str, Any]] diff --git a/dialogy/utils/__init__.py b/dialogy/utils/__init__.py index 8df83432..b15fbe08 100644 --- a/dialogy/utils/__init__.py +++ b/dialogy/utils/__init__.py @@ -1,5 +1,6 @@ +from dialogy.utils.datetime import dt2timestamp, is_unix_ts, make_unix_ts from dialogy.utils.file_handler import create_timestamps_path, load_file, save_file from dialogy.utils.logger import logger -from dialogy.utils.misc import dt2timestamp, traverse_dict, validate_type +from dialogy.utils.misc import traverse_dict, validate_type from dialogy.utils.naive_lang_detect import lang_detect_from_text -from dialogy.utils.normalize_utterance import normalize +from dialogy.utils.normalize_utterance import is_utterance, normalize diff --git a/dialogy/utils/datetime.py b/dialogy/utils/datetime.py index f0e2630b..7e17f092 100644 --- a/dialogy/utils/datetime.py +++ b/dialogy/utils/datetime.py @@ -1,10 +1,37 @@ -import math from datetime import datetime -from typing import Callable +from typing import Callable, Union import pytz +def is_unix_ts(ts: int) -> bool: + """ + Check if the input is a unix timestamp. + + :param ts: A unix timestamp (13-digit). + :type ts: int + :return: True if :code:`ts` is a unix timestamp, else False. + :rtype: bool + """ + try: + datetime.fromtimestamp(ts / 1000) + return True + except ValueError: + return False + + +def dt2timestamp(date_time: datetime) -> int: + """ + Converts a python datetime object to unix-timestamp. + + :param date_time: An instance of datetime. + :type date_time: datetime + :return: Unix timestamp integer. + :rtype: int + """ + return int(date_time.timestamp() * 1000) + + def make_unix_ts(tz: str = "UTC") -> Callable[[str], int]: """ Convert date in ISO 8601 format to unix ms timestamp. @@ -15,7 +42,7 @@ def make_unix_ts(tz: str = "UTC") -> Callable[[str], int]: :rtype: Callable[[str], int] """ - def make_tz_aware(date_string: str) -> int: + def make_tz_aware(date_string: Union[str, int]) -> int: """ Convert date in ISO 8601 format to unix ms timestamp. @@ -32,22 +59,6 @@ def make_tz_aware(date_string: str) -> int: if dt.tzinfo is None: dt = dt.replace(tzinfo=pytz.timezone(tz)) - return int(dt.timestamp() * 1000) + return dt2timestamp(dt) return make_tz_aware - - -def is_unix_ts(ts: int) -> bool: - """ - Check if the input is a unix timestamp. - - :param ts: A unix timestamp (13-digit). - :type ts: int - :return: True if :code:`ts` is a unix timestamp, else False. - :rtype: bool - """ - try: - datetime.fromtimestamp(ts / 1000) - return True - except ValueError: - return False diff --git a/dialogy/utils/misc.py b/dialogy/utils/misc.py index f3e3aae4..8c3370ef 100644 --- a/dialogy/utils/misc.py +++ b/dialogy/utils/misc.py @@ -79,15 +79,3 @@ def validate_type(obj: Any, obj_type: Union[type, Tuple[type]]) -> None: """ if not isinstance(obj, obj_type): raise TypeError(f"{obj} should be a {obj_type}") - - -def dt2timestamp(date_time: datetime) -> int: - """ - Converts a python datetime object to unix-timestamp. - - :param date_time: An instance of datetime. 
- :type date_time: datetime - :return: Unix timestamp integer. - :rtype: int - """ - return int(date_time.timestamp() * 1000) diff --git a/dialogy/utils/normalize_utterance.py b/dialogy/utils/normalize_utterance.py index 54f298e4..214357c2 100644 --- a/dialogy/utils/normalize_utterance.py +++ b/dialogy/utils/normalize_utterance.py @@ -175,7 +175,7 @@ def utterance2alternatives( Convert a list of utterances to a list of alternatives. """ return [ - " ".join([alternative[key] for alternative in alternatives]).lower() + " ".join([alternative[key] for alternative in alternatives]) for alternatives in itertools.product(*utterances) ] @@ -214,13 +214,13 @@ def normalize(maybe_utterance: Any, key: str = const.TRANSCRIPT) -> List[str]: return utterance2alternatives(maybe_utterance) if is_unsqueezed_utterance(maybe_utterance): - return [alternative[key].lower() for alternative in maybe_utterance] + return [alternative[key] for alternative in maybe_utterance] if is_list(maybe_utterance) and is_list_of_string(maybe_utterance): - return [utterance.lower() for utterance in maybe_utterance] + return [utterance for utterance in maybe_utterance] if is_string(maybe_utterance): - return [maybe_utterance.lower()] + return [maybe_utterance] else: raise TypeError( diff --git a/dialogy/workflow/workflow.py b/dialogy/workflow/workflow.py index 40c5ba6b..f378f30f 100644 --- a/dialogy/workflow/workflow.py +++ b/dialogy/workflow/workflow.py @@ -48,15 +48,19 @@ abstract class. There are some design considerations which make that a bad choice. We want methods to be overridden to offer flexibility of use. """ +from __future__ import annotations + import copy import time from threading import Lock -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple import attr import pandas as pd from dialogy import constants as const +from dialogy.base.input import Input +from dialogy.base.output import Output from dialogy.base.plugin import Plugin from dialogy.utils.logger import logger @@ -85,8 +89,16 @@ class Workflow: type=List[Plugin], validator=attr.validators.instance_of(list), ) - input: Dict[str, Any] = attr.ib(factory=dict, kw_only=True) - output: Dict[str, Any] = attr.ib(factory=dict, kw_only=True) + input: Optional[Input] = attr.ib( + default=None, + kw_only=True, + validator=attr.validators.optional(attr.validators.instance_of(Input)), + ) + output: Optional[Output] = attr.ib( + default=None, + kw_only=True, + validator=attr.validators.optional(attr.validators.instance_of(Output)), + ) debug = attr.ib( type=bool, default=False, validator=attr.validators.instance_of(bool) ) @@ -97,29 +109,46 @@ def __attrs_post_init__(self) -> None: """ Post init hook. """ - self.set_io() + self.__reset() self.lock = Lock() - def set_io(self) -> None: + def __reset(self) -> None: """ Use this method to keep workflow-io in the same format as expected. """ - self.input: Dict[str, Any] = {} - self.output: Dict[str, Any] = {const.INTENTS: [], const.ENTITIES: []} + self.input = None + self.output = Output() - def execute(self) -> None: + def set(self, path: str, value: Any) -> Workflow: + """ + Set attribute path with value. + + :param path: A '.' separated attribute path. + :type path: str + :param value: A value to set. 
+        :type value: Any
+        :return: This instance
+        :rtype: Workflow
+        """
+        dest, attribute = path.split(".")
+
+        if dest == "input":
+            self.input = Input.from_dict({attribute: value}, reference=self.input)
+        elif dest == "output" and isinstance(value, list):
+            self.output = Output.from_dict({attribute: value}, reference=self.output)
+        elif dest == "output":
+            raise ValueError(f"{value=} should be a List[Intent] or List[BaseEntity].")
+        else:
+            raise ValueError(f"{path} is not a valid path.")
+        return self
+
+    def execute(self) -> Workflow:
         """
         Update input, output attributes.

         We iterate through pre/post processing functions and update the input and
         output attributes of the class. It is expected that pre-processing functions
         would modify the input, and post-processing functions would modify the output.
-
-        Args:
-            processors (`List`): The list of preprocess or postprocess functions.
-
-        Raises:
-            `TypeError`: If any element in processors list is not a Callable.
         """
         history = {}
         for plugin in self.plugins:
@@ -147,8 +176,9 @@ def execute(self) -> None:
             history["perf"] = round(end - start, 4)
         if history:
             logger.debug(history)
+        return self

-    def run(self, input_: Any) -> Any:
+    def run(self, input_: Input) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
         .. _workflow_run:

@@ -156,26 +186,21 @@ def run(self, input_: Any) -> Any:

         The current workflow exhibits the following simple procedure:
         pre-processing -> inference -> post-processing.
-
-        Args:
-            input_ (`Any`): This function receives any arbitrary input. Subclasses may enforce
-            a stronger check.
-
-        Returns:
-            (`Any`): This function can return any arbitrary value. Subclasses may enforce a stronger check.
         """
         with self.lock:
             self.input = input_
-            self.execute()
-            output = copy.copy(self.output)
-            self.flush()
-            return output
+            return self.execute().flush()

-    def flush(self) -> None:
+    def flush(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
         Reset :code:`workflow.input` and :code:`workflow.output`.
         """
-        self.set_io()
+        if self.input is None or self.output is None:
+            return {}, {}
+        input_ = copy.deepcopy(self.input.json())
+        output = copy.deepcopy(self.output.json())
+        self.__reset()
+        return input_, output

     def json(self) -> Dict[str, Any]:
         """
@@ -189,7 +214,7 @@ def json(self) -> Dict[str, Any]:
                 not in self.NON_SERIALIZABLE_FIELDS,
         )

-    def train(self, training_data: pd.DataFrame) -> None:
+    def train(self, training_data: pd.DataFrame) -> Workflow:
         """
         Train all the plugins in the workflow.

@@ -205,3 +230,4 @@ def train(self, training_data: pd.DataFrame) -> None:
             transformed_data = plugin.transform(training_data)
             if transformed_data is not None:
                 training_data = transformed_data
+        return self
diff --git a/mypy.ini b/mypy.ini
index 046183dd..94fe7c33 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -15,7 +15,7 @@ warn_no_return = True
 strict_optional = True
 no_implicit_optional = True
 warn_redundant_casts = True
-warn_unused_ignores = True
+warn_unused_ignores = False

 # Display the codes needed for
 # type: ignore[code] annotations.
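
For a concrete picture of the reworked `Workflow` surface, a hedged sketch (the intent name and score are invented; a plugin-less `Workflow()` is assumed valid, as in the updated tests further below):

```python
from dialogy.base import Input
from dialogy.types import Intent
from dialogy.workflow import Workflow

workflow = Workflow()

# `set` routes a value into the frozen Input/Output containers by path.
workflow.set("input.utterances", [[{"transcript": "hello"}]])
workflow.set("output.intents", [Intent(name="_greeting_", score=1.0)])

# `flush` serializes both containers to plain dicts and resets the workflow;
# this pair of dicts is also what `run` returns once `execute` has applied
# each plugin.
input_json, output_json = workflow.flush()
assert output_json["intents"][0]["name"] == "_greeting_"
assert workflow.input is None
```
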
show_error_codes = True diff --git a/tests/base/test_entity_extractor.py b/tests/base/test_entity_extractor.py index 0b23c7f4..93092ed0 100644 --- a/tests/base/test_entity_extractor.py +++ b/tests/base/test_entity_extractor.py @@ -3,6 +3,7 @@ import pandas as pd import pydash as py_ import pytest +from torch import threshold from dialogy.base.entity_extractor import EntityScoringMixin from dialogy.plugins import DucklingPlugin @@ -11,17 +12,8 @@ def make_entity_object(entity_items): - reference_time = 1622640071000 - - def access(workflow): - return workflow.input, reference_time, "en_IN" - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - duckling_plugin = DucklingPlugin( - access=access, - mutate=mutate, + dest="output.entities", dimensions=["date", "time", "duration", "number", "people"], timezone="Asia/Kolkata", debug=False, @@ -35,6 +27,24 @@ def mutate(workflow, entities): ) +def test_remove_low_scoring_entities(): + entity_extractor = EntityScoringMixin() + entity_extractor.threshold = 0.5 + body = "test" + entities = [ + KeywordEntity( + body=body, + value=body, + range={ + "start": 0, + "end": len(body), + }, + entity_type=body, + ) + ] + assert entity_extractor.remove_low_scoring_entities(entities) == entities + + @pytest.mark.parametrize("payload", load_tests("entity_extractor", __file__)) def test_entity_extractor_for_thresholding(payload) -> None: """ diff --git a/tests/base/test_entity_extractor.yaml b/tests/base/test_entity_extractor.yaml index 75701bb9..e8202c7d 100644 --- a/tests/base/test_entity_extractor.yaml +++ b/tests/base/test_entity_extractor.yaml @@ -1,4 +1,4 @@ -- description: "" +- description: "Check addition on thresholding" input_size: 10 threshold: 0.1 mock_entities: [[{'body': '3rd of July', @@ -167,7 +167,7 @@ 'dim': 'time', 'latent': False}]] expected: [{'range': {'start': 0, 'end': 11}, - 'type': 'date', + 'type': 'value', 'body': '3rd of July', 'parsers': ["DucklingPlugin"], 'score': 0.2, @@ -176,7 +176,7 @@ 'entity_type': 'date', 'grain': 'day'}, {'range': {'start': 0, 'end': 4}, - 'type': 'date', + 'type': 'value', 'body': 'July', 'parsers': ["DucklingPlugin"], 'score': 0.6, @@ -185,7 +185,7 @@ 'entity_type': 'date', 'grain': 'month'}, {'range': {'start': 0, 'end': 1}, - 'type': 'number', + 'type': 'value', 'body': '3', 'parsers': ["DucklingPlugin"], 'score': 0.2, diff --git a/tests/base/test_io.py b/tests/base/test_io.py new file mode 100644 index 00000000..a2863255 --- /dev/null +++ b/tests/base/test_io.py @@ -0,0 +1,22 @@ +import pytest + +from dialogy.base import Input, Output +from dialogy.types import Intent + + +def test_invalid_reftime(): + with pytest.raises(ValueError): + Input(utterances="test", reference_time=18 ** 15) + + +def test_input_extension(): + instance = Input(utterances="test", reference_time=1644238676772) + extended = Input.from_dict({"utterances": "test", "reference_time": 1644238676772}) + assert instance == extended + + +def test_output_extension(): + intent = Intent(name="test", score=0.5) + instance = Output(intents=[intent]) + extended = Output.from_dict({"intents": [intent]}) + assert instance == extended diff --git a/tests/plugin/test_plugins.py b/tests/plugin/test_plugins.py index f011cb36..e791a5d3 100644 --- a/tests/plugin/test_plugins.py +++ b/tests/plugin/test_plugins.py @@ -2,29 +2,14 @@ This is a tutorial on creating and using Plugins with Workflows. 
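
Before diving into the test, the shape of the new plugin contract in one sketch (`PinnedIntentPlugin` is a made-up name; the guard relies on `Input.current_state`, which the guard tests below exercise):

```python
from typing import Any

from dialogy.base import Input, Output, Plugin
from dialogy.types import Intent
from dialogy.workflow import Workflow


class PinnedIntentPlugin(Plugin):
    """Hypothetical plugin that always emits a single fixed intent."""

    def utility(self, input: Input, output: Output) -> Any:
        # Whatever `utility` returns is written to the path named by `dest`.
        return [Intent(name="_pinned_", score=1.0)]


plugin = PinnedIntentPlugin(
    dest="output.intents",
    # Guards receive (Input, Output); returning True skips the plugin.
    guards=[lambda i, _: i.current_state == "COF"],
)
workflow = Workflow([plugin])
_, output = workflow.run(Input(utterances=[[{"transcript": "hello"}]]))
assert output["intents"][0]["name"] == "_pinned_"
```
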
""" import re -from typing import Any, Optional +from typing import Any, List, Optional -from dialogy.base.plugin import Plugin -from dialogy.types.plugin import PluginFn +from dialogy import workflow +from dialogy.base import Guard, Input, Output, Plugin +from dialogy.types import Intent from dialogy.workflow import Workflow -def access(workflow: Workflow) -> Any: - """ - This function would be provided by the - workflow implementer. - """ - return workflow.input - - -def mutate(workflow: Workflow, value: Any) -> Any: - """ - This function would be provided by the - workflow implementer. - """ - workflow.output = value - - # == ArbitraryPlugin == class ArbitraryPlugin(Plugin): """ @@ -44,43 +29,17 @@ class ArbitraryPlugin(Plugin): def __init__( self, - access: Optional[PluginFn] = None, - mutate: Optional[PluginFn] = None, + dest: Optional[str] = None, + guards: Optional[List[Guard]] = None, use_transform: bool = False, debug=False, ): super().__init__( - access=access, mutate=mutate, debug=debug, use_transform=use_transform + dest=dest, guards=guards, debug=debug, use_transform=use_transform ) - def utility(self, numbers, words) -> Any: - """ - Expects a tuple from `access(workflow)` that contains numbers and words. - - Where - - - numbers: A list of numbers. - - words: A list of strings. - - This plugin will: - - - Increase each number in `numbers` by 2. - - Concatenate " world" after each word in `words`. - - The plugin method is the place for implementing - plugin logic. - - - If values need to be persisted? store them in class attributes. - - If A set of validation, API calls etc are dependencies to run the implementation? - create separate methods and call them here. - - Args: - - - workflow (Workflow): The workflow we possibly want to modify. - """ - numbers = [number + 2 for number in numbers] - words = [word + " world" for word in words] - return numbers, words + def utility(self, _: Input, __: Output) -> Any: + return [Intent(name="_greeting_", score=0.9)] # == Plugin as a class with workflow == @@ -89,59 +48,89 @@ def test_arbitrary_plugin() -> None: We will test how an arbitrary-class-based plugin works with a workflow. """ # create an instance of `ArbitraryPlugin`. - arbitrary_plugin = ArbitraryPlugin(access=access, mutate=mutate) + arbitrary_plugin = ArbitraryPlugin(dest="output.intents") # create an instance of a `Workflow`. # we are calling the `arbitrary_plugin` to get the `plugin` de method. workflow = Workflow([arbitrary_plugin]) + input_ = Input(utterances=[[{"transcript": "hello"}]]) # This runs all the `preprocessors` and `postprocessors` provided previously. # we can expect our `arbitrary_plugin` will also be used. - output = workflow.run(([2, 5], ["hello", "hi"])) - - numbers, words = output # pylint: disable=unpacking-non-sequence + _, output = workflow.run(input_) + first_intent, *rest = output["intents"] # This test would pass only if our plugin works correctly! - assert numbers == [4, 7] - assert words == ["hello world", "hi world"] + assert first_intent["name"] == "_greeting_" + assert rest == [] -def test_arbitrary_plugin_with_debug_mode() -> None: - """ - We will test how an arbitrary-class-based plugin works with a workflow. - """ - # create an instance of `ArbitraryPlugin`. - arbitrary_plugin = ArbitraryPlugin(access=access, mutate=mutate, debug=False) +# def test_arbitrary_plugin_with_debug_mode() -> None: +# """ +# We will test how an arbitrary-class-based plugin works with a workflow. +# """ +# # create an instance of `ArbitraryPlugin`. 
+# arbitrary_plugin = ArbitraryPlugin(dest="output.intents", debug=False) - # create an instance of a `Workflow`. - # we are calling the `arbitrary_plugin` to get the `plugin` de method. - workflow = Workflow([arbitrary_plugin]) +# # create an instance of a `Workflow`. +# # we are calling the `arbitrary_plugin` to get the `plugin` de method. +# workflow = Workflow([arbitrary_plugin]) - # This runs all the `preprocessors` and `postprocessors` provided previously. - # we can expect our `arbitrary_plugin` will also be used. - output = workflow.run(([2, 5], ["hello", "hi"])) +# # This runs all the `preprocessors` and `postprocessors` provided previously. +# # we can expect our `arbitrary_plugin` will also be used. +# output = workflow.run(([2, 5], ["hello", "hi"])) - numbers, words = output # pylint: disable=unpacking-non-sequence +# numbers, words = output # pylint: disable=unpacking-non-sequence - # This test would pass only if our plugin works correctly! - assert numbers == [4, 7] - assert words == ["hello world", "hi world"] +# # This test would pass only if our plugin works correctly! +# assert numbers == [4, 7] +# assert words == ["hello world", "hi world"] def test_plugin_train() -> None: - arbitrary_plugin = ArbitraryPlugin(access=access, mutate=mutate, debug=False) + arbitrary_plugin = ArbitraryPlugin(dest="output.intents") assert arbitrary_plugin.train([]) is None def test_plugin_transform_not_use_transform() -> None: - arbitrary_plugin = ArbitraryPlugin( - access=access, mutate=mutate, debug=False, use_transform=False - ) + arbitrary_plugin = ArbitraryPlugin(dest="output.intents", use_transform=False) assert arbitrary_plugin.transform([]) == [] def test_plugin_transform() -> None: arbitrary_plugin = ArbitraryPlugin( - access=access, mutate=mutate, debug=False, use_transform=True + dest="output.intents", debug=False, use_transform=True ) assert arbitrary_plugin.transform([{"a": 1}]) == [{"a": 1}] + + +def test_plugin_guards() -> None: + arbitrary_plugin = ArbitraryPlugin( + dest="output.intents", + guards=[lambda i, _: i.current_state == "COF"], + ) + workflow = ( + Workflow().set("input.utterances", ["hello"]).set("input.current_state", "COF") + ) + assert arbitrary_plugin.prevent(workflow.input, workflow.output) is True + assert arbitrary_plugin(workflow) is None + + +def test_plugin_no_set_on_invalid_input(): + arbitrary_plugin = ArbitraryPlugin( + dest="output.intents", + guards=[lambda i, _: i.current_state == "COF"], + ) + workflow = Workflow() + assert arbitrary_plugin(workflow) is None + + +def test_plugin_no_set_on_invalid_output(): + arbitrary_plugin = ArbitraryPlugin( + dest="output.intents", + guards=[lambda i, _: i.current_state == "COF"], + ) + workflow = Workflow() + workflow.input = Input(utterances="hello") + workflow.output = None + assert arbitrary_plugin(workflow) is None diff --git a/tests/plugin/text/canonicalization/test_canonicalization.py b/tests/plugin/text/canonicalization/test_canonicalization.py index cee6948d..bc78675e 100644 --- a/tests/plugin/text/canonicalization/test_canonicalization.py +++ b/tests/plugin/text/canonicalization/test_canonicalization.py @@ -2,27 +2,19 @@ import pandas as pd +from dialogy.base import Input from dialogy.plugins.text.canonicalization import CanonicalizationPlugin from dialogy.plugins.text.list_entity_plugin import ListEntityPlugin from dialogy.types import KeywordEntity from dialogy.workflow import Workflow - -def canon_access(w): - return w.output["entities"], w.input["classification_input"] - - -def canon_mutate(w, v): 
- w.output["classification_input"] = v - - canonicalization = CanonicalizationPlugin( mask_tokens=["hello"], input_column="data", use_transform=True, threshold=0.1, - access=canon_access, - mutate=canon_mutate, + debug=False, + dest="input.clf_feature", ) @@ -30,25 +22,17 @@ def canon_mutate(w, v): mask_tokens=["hello"], input_column="data", threshold=0.1, - access=canon_access, - mutate=canon_mutate, + dest="input.clf_feature", + debug=False, use_transform=False, ) -def entity_access(w): - return (w.input["ner_input"],) - - -def entity_mutate(w, v): - w.output["entities"] = v - - list_entity_plugin = ListEntityPlugin( candidates={"fruits": {"apple": ["apple", "apples"]}}, style="regex", - access=entity_access, - mutate=entity_mutate, + dest="output.entities", + debug=False, ) workflow = Workflow([list_entity_plugin, canonicalization]) @@ -59,7 +43,7 @@ def entity_mutate(w, v): "data": json.dumps(["hello apple", "hello orange"]), "entities": [ KeywordEntity( - type="fruits", + entity_type="fruits", body="apple", parsers=["ListEntityPlugin"], score=1.0, @@ -69,7 +53,7 @@ def entity_mutate(w, v): range={"start": 6, "end": 11}, ), KeywordEntity( - type="colour", + entity_type="colour", body="orange", parsers=["ListEntityPlugin"], score=0.0, @@ -79,7 +63,7 @@ def entity_mutate(w, v): range={"start": 6, "end": 12}, ), KeywordEntity( - type="colour", + entity_type="colour", body="orange", parsers=["ListEntityPlugin"], score=0.5, @@ -102,7 +86,7 @@ def entity_mutate(w, v): "data": json.dumps(["hello apple", "hello orange"]), "entities": [ KeywordEntity( - type="colour", + entity_type="colour", body="orange", parsers=["ListEntityPlugin"], score=0.5, @@ -116,10 +100,9 @@ def entity_mutate(w, v): def test_canonicalization_utility(): - output = workflow.run( - input_={"classification_input": ["hello apple"], "ner_input": ["hello apple"]} - ) - assert output["classification_input"] == ["MASK "] + input_ = Input(utterances=[[{"transcript": "hello apple"}]]) + input_, _ = workflow.run(input_) + assert input_["clf_feature"] == ["MASK "] def test_canonicalization_transform(): diff --git a/tests/plugin/text/classification/test_mlp.py b/tests/plugin/text/classification/test_mlp.py index 7846d4e8..da109a8f 100644 --- a/tests/plugin/text/classification/test_mlp.py +++ b/tests/plugin/text/classification/test_mlp.py @@ -9,6 +9,7 @@ import sklearn import dialogy.constants as const +from dialogy.base import Guard, Input, Plugin from dialogy.plugins import MergeASROutputPlugin, MLPMultiClass from dialogy.utils import load_file from dialogy.workflow import Workflow @@ -17,6 +18,7 @@ class MockMLPClassifier: def __init__(self, model_dir, args=None, **kwargs): + super().__init__(**kwargs) self.model_dir = model_dir self.args = args or {} self.kwargs = kwargs @@ -30,20 +32,8 @@ def train(self, training_data: pd.DataFrame): return -def write_intent_to_workflow(w, v): - w.output[const.INTENTS] = v - - -def update_input(w, v): - w.input[const.CLASSIFICATION_INPUT] = v - - def test_mlp_plugin_when_no_mlpmodel_saved(): - mlp_clf = MLPMultiClass( - model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, - ) + mlp_clf = MLPMultiClass(model_dir=".", dest="output.intents") assert isinstance(mlp_clf, MLPMultiClass) assert mlp_clf.model_pipeline is None @@ -56,8 +46,7 @@ def test_mlp_plugin_when_mlpmodel_EOFError(capsys): with capsys.disabled(): mlp_plugin = MLPMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - 
mutate=write_intent_to_workflow, + dest="output.intents", debug=False, ) assert mlp_plugin.model_pipeline is None @@ -68,8 +57,8 @@ def test_mlp_plugin_when_mlpmodel_EOFError(capsys): def test_mlp_init_mock(): mlp_clf = MLPMultiClass( model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, ) mlp_clf.init_model() assert isinstance(mlp_clf.model_pipeline, sklearn.pipeline.Pipeline) @@ -79,8 +68,8 @@ def test_mlp_invalid_argsmap(): with pytest.raises(ValueError): MLPMultiClass( model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, args_map={"invalid": "value"}, ) @@ -108,8 +97,8 @@ def test_mlp_gridsearch_argsmap(): xlmr_clf = MLPMultiClass( model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, args_map=fake_args, ) xlmr_clf.init_model() @@ -137,8 +126,8 @@ def test_mlp_gridsearch_argsmap(): with pytest.raises(ValueError): xlmr_clf_invalid = MLPMultiClass( model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, args_map=fake_invalid_args, ) xlmr_clf_invalid.init_model() @@ -150,8 +139,8 @@ def test_train_mlp_mock(): mlp_clf = MLPMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, ) train_df = pd.DataFrame( @@ -170,8 +159,8 @@ def test_train_mlp_mock(): # So this instance would have read the saved mlp model. mlp_clf_copy = MLPMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, ) mlp_clf_copy.load() assert isinstance(mlp_clf_copy.model_pipeline, sklearn.pipeline.Pipeline) @@ -204,8 +193,8 @@ def test_train_mlp_gridsearch_mock(): mlp_clf = MLPMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, args_map=fake_args, ) @@ -225,8 +214,8 @@ def test_train_mlp_gridsearch_mock(): # So this instance would have read the saved mlp model. 
mlp_clf_copy = MLPMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, ) mlp_clf_copy.load() assert mlp_clf_copy.valid_mlpmodel is True @@ -242,8 +231,8 @@ def test_invalid_operations(): mlp_clf = MLPMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, ) mlp_clf.init_model() @@ -309,20 +298,19 @@ def test_inference(payload): const.PRODUCTION: {}, } - text = payload.get("input") + transcripts = payload.get("input") intent = payload["expected"]["label"] mlp_clf = MLPMultiClass( model_dir=directory, - access=lambda w: (w.input[const.CLASSIFICATION_INPUT],), - mutate=write_intent_to_workflow, + dest="output.intents", args_map=fake_args, debug=False, ) merge_asr_output_plugin = MergeASROutputPlugin( - access=lambda w: (w.input[const.CLASSIFICATION_INPUT],), - mutate=update_input, + dest="input.clf_feature", + debug=False, ) workflow = Workflow([merge_asr_output_plugin, mlp_clf]) @@ -357,8 +345,10 @@ def test_inference(payload): ) workflow.train(train_df) - output = workflow.run(input_={const.CLASSIFICATION_INPUT: text}) - assert output[const.INTENTS][0].name == intent - assert output[const.INTENTS][0].score > 0.5 + _, output = workflow.run( + Input(utterances=[[{"transcript": transcript} for transcript in transcripts]]) + ) + assert output[const.INTENTS][0]["name"] == intent + assert output[const.INTENTS][0]["score"] > 0.5 if os.path.exists(file_path): os.remove(file_path) diff --git a/tests/plugin/text/classification/test_xlmr.py b/tests/plugin/text/classification/test_xlmr.py index 2ca5809f..f3629a69 100644 --- a/tests/plugin/text/classification/test_xlmr.py +++ b/tests/plugin/text/classification/test_xlmr.py @@ -9,6 +9,7 @@ import pytest import dialogy.constants as const +from dialogy.base import Input from dialogy.plugins import MergeASROutputPlugin, XLMRMultiClass from dialogy.utils import load_file from dialogy.workflow import Workflow @@ -61,11 +62,7 @@ def test_xlmr_plugin_no_module_error(): const.XLMR_MODULE = "this-module-doesn't-exist" with pytest.raises(ModuleNotFoundError): - XLMRMultiClass( - model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, - ) + XLMRMultiClass(model_dir=".", dest="output.intents", debug=False) const.XLMR_MODULE = save_val @@ -75,11 +72,7 @@ def test_xlmr_plugin_when_no_labelencoder_saved(): const.XLMR_MODULE = "tests.plugin.text.classification.test_xlmr" const.XLMR_MULTI_CLASS_MODEL = "MockClassifier" - xlmr_clf = XLMRMultiClass( - model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, - ) + xlmr_clf = XLMRMultiClass(model_dir=".", dest="output.intents", debug=False) assert isinstance(xlmr_clf, XLMRMultiClass) assert xlmr_clf.model is None const.XLMR_MODULE = save_module_name @@ -98,8 +91,7 @@ def test_xlmr_plugin_when_labelencoder_EOFError(capsys): with capsys.disabled(): xlmr_plugin = XLMRMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", debug=False, ) assert xlmr_plugin.model is None @@ -115,11 +107,7 @@ def test_xlmr_init_mock(): const.XLMR_MODULE = "tests.plugin.text.classification.test_xlmr" const.XLMR_MULTI_CLASS_MODEL = "MockClassifier" - xlmr_clf = XLMRMultiClass( - model_dir=".", - access=lambda w: 
w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, - ) + xlmr_clf = XLMRMultiClass(model_dir=".", dest="output.intents", debug=False) xlmr_clf.init_model(5) assert xlmr_clf.model is not None @@ -136,8 +124,8 @@ def test_xlmr_init_mock(): with pytest.raises(ValueError): XLMRMultiClass( model_dir=".", - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, args_map={"invalid": "value"}, ) const.XLMR_MODULE = save_module_name @@ -152,16 +140,12 @@ def test_train_xlmr_mock(): directory = "/tmp" file_path = os.path.join(directory, const.LABELENCODER_FILE) - xlmr_clf = XLMRMultiClass( - model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, - ) + xlmr_clf = XLMRMultiClass(model_dir=directory, dest="output.intents", debug=False) xlmr_clf_state = XLMRMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, use_state=True, ) @@ -189,9 +173,7 @@ def test_train_xlmr_mock(): # This copy loads from the same directory that was trained previously. # So this instance would have read the labelencoder saved. xlmr_clf_copy = XLMRMultiClass( - model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + model_dir=directory, dest="output.intents", debug=False ) assert len(xlmr_clf_copy.labelencoder.classes_) == 2 @@ -212,15 +194,11 @@ def test_invalid_operations(): if os.path.exists(file_path): os.remove(file_path) - xlmr_clf = XLMRMultiClass( - model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, - ) + xlmr_clf = XLMRMultiClass(model_dir=directory, dest="output.intents", debug=False) xlmr_clf_state = XLMRMultiClass( model_dir=directory, - access=lambda w: w.input[const.CLASSIFICATION_INPUT], - mutate=write_intent_to_workflow, + dest="output.intents", + debug=False, use_state=True, ) @@ -274,19 +252,17 @@ def test_inference(payload): if os.path.exists(file_path): os.remove(file_path) - text = payload.get("input") + transcripts = payload.get("input") intent = payload["expected"]["label"] xlmr_clf = XLMRMultiClass( model_dir=directory, - access=lambda w: (w.input[const.CLASSIFICATION_INPUT],), - mutate=write_intent_to_workflow, + dest="output.intents", debug=False, ) merge_asr_output_plugin = MergeASROutputPlugin( - access=lambda w: (w.input[const.CLASSIFICATION_INPUT],), - mutate=update_input, + dest="input.clf_feature", debug=False ) workflow = Workflow([merge_asr_output_plugin, xlmr_clf]) @@ -317,9 +293,13 @@ def test_inference(payload): xlmr_clf.model, MockClassifier ), "model should be a MockClassifier after training." 
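
A note on the payload shape used below: the ASR alternatives of one utterance are wrapped in a nested list, and `Input` derives the flat transcript strings at construction time. A small sketch (transcripts invented):

```python
from dialogy.base import Input

transcripts = ["book a flight", "book a fight"]

# One utterance, several ASR alternatives; __attrs_post_init__ runs
# normalize() and exposes the plain strings as `transcripts`.
input_ = Input(utterances=[[{"transcript": t} for t in transcripts]])
assert input_.transcripts == transcripts
```
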
- output = workflow.run(input_={const.CLASSIFICATION_INPUT: text}) - assert output[const.INTENTS][0].name == intent - assert output[const.INTENTS][0].score > 0.9 + _, output = workflow.run( + input_=Input( + utterances=[[{"transcript": transcript} for transcript in transcripts]] + ) + ) + assert output[const.INTENTS][0]["name"] == intent + assert output[const.INTENTS][0]["score"] > 0.9 if os.path.exists(file_path): os.remove(file_path) diff --git a/tests/plugin/text/combine_date_time/test_combine_date_time.py b/tests/plugin/text/combine_date_time/test_combine_date_time.py index fba93c0b..0c417483 100644 --- a/tests/plugin/text/combine_date_time/test_combine_date_time.py +++ b/tests/plugin/text/combine_date_time/test_combine_date_time.py @@ -3,6 +3,7 @@ import pytest from dialogy import constants as const +from dialogy.base import Input, Output from dialogy.plugins import CombineDateTimeOverSlots, DucklingPlugin from dialogy.workflow import Workflow from tests import EXCEPTIONS, load_tests @@ -17,7 +18,7 @@ def test_plugin_cases(payload) -> None: tracker = payload.get("inputs", {}).get("tracker", []) expected = payload.get("expected", {}) duckling_plugin = DucklingPlugin( - dimensions=["date", "time"], timezone="Asia/Kolkata", access=lambda x: x + dimensions=["date", "time"], timezone="Asia/Kolkata", dest="output.entities" ) for i, entity in enumerate(entities): @@ -25,13 +26,13 @@ def test_plugin_cases(payload) -> None: combine_date_time_plugin = CombineDateTimeOverSlots( trigger_intents=["_callback_"], - access=lambda w: (tracker, w.output[const.ENTITIES]), + dest="output.entities", ) workflow = Workflow(plugins=[combine_date_time_plugin]) - workflow.output = {const.ENTITIES: current_turn_entities} - output = workflow.run(input_=[]) - entity_values = [entity.get_value() for entity in output[const.ENTITIES]] + workflow.output = Output(entities=current_turn_entities) + _, output = workflow.run(Input(utterances=[""], slot_tracker=tracker)) + entity_values = [entity["value"] for entity in output[const.ENTITIES]] if len(entity_values) != len(expected): pytest.fail( @@ -40,28 +41,28 @@ def test_plugin_cases(payload) -> None: for entity_value, expected_value in zip(entity_values, expected): try: - assert entity_value == datetime.fromisoformat( - expected_value - ), f"Expected {datetime.fromisoformat(expected_value)} but got {entity_value}" + expected = datetime.fromisoformat(expected_value) + generated = datetime.fromisoformat(entity_value) + assert generated == expected, f"Expected {expected} but got {generated}" except (ValueError, TypeError): assert entity_value == expected_value def test_plugin_exit_at_missing_trigger_intents(): combine_date_time_plugin = CombineDateTimeOverSlots( - trigger_intents=[], access=lambda w: ([], w.output[const.ENTITIES]) + trigger_intents=[], dest="output.entities" ) workflow = Workflow(plugins=[combine_date_time_plugin]) - output = workflow.run(input_=[]) + _, output = workflow.run(Input(utterances=[""])) assert output[const.ENTITIES] == [] def test_plugin_exit_at_missing_tracker(): combine_date_time_plugin = CombineDateTimeOverSlots( - trigger_intents=["_callback_"], access=lambda w: ([], w.output[const.ENTITIES]) + trigger_intents=["_callback_"], dest="output.entities" ) workflow = Workflow(plugins=[combine_date_time_plugin]) - output = workflow.run(input_=[]) + _, output = workflow.run(Input(utterances=[""])) assert output[const.ENTITIES] == [] diff --git a/tests/plugin/text/slot_filler/test_rule_slot_filler.py 
b/tests/plugin/text/slot_filler/test_rule_slot_filler.py index 6ada4c58..6deb2fae 100644 --- a/tests/plugin/text/slot_filler/test_rule_slot_filler.py +++ b/tests/plugin/text/slot_filler/test_rule_slot_filler.py @@ -1,11 +1,13 @@ """ This is a tutorial for understanding the use of `RuleBasedSlotFillerPlugin`. """ +from textwrap import fill from typing import Any import pytest import dialogy.constants as const +from dialogy.base import Input, Output from dialogy.plugins import RuleBasedSlotFillerPlugin from dialogy.types.entity import BaseEntity from dialogy.types.intent import Intent @@ -29,14 +31,9 @@ def test_slot_filling() -> None: This test case covers a trivial usage of a slot-filler. We have `rules` that demonstrate association of intents with entities and their respective slot-configuration. """ - - def access(workflow: Workflow) -> Any: - return (workflow.output[const.INTENTS], workflow.output[const.ENTITIES]) - intent_name = "intent_1" - # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). - slot_filler = RuleBasedSlotFillerPlugin(rules=rules, access=access) + slot_filler = RuleBasedSlotFillerPlugin(rules=rules, dest="output.intents") # Create a mock `workflow` workflow = Workflow([slot_filler]) @@ -50,17 +47,19 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_1", + entity_type="entity_1", values=[{"key": "value"}], ) # The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`. - workflow.output = {const.INTENTS: [intent], const.ENTITIES: [entity]} - output = workflow.run(body) + workflow.set("output.intents", [intent]).set("output.entities", [entity]) + + _, output = workflow.run(Input(utterances=body)) + intent, *_ = output[const.INTENTS] # `workflow.output[0]` is the `Intent` we created. # so we are checking if the `entity_1_slot` is filled by our mock entity. - assert output[const.INTENTS][0].slots["entity_1_slot"].values[0] == entity + assert intent["slots"]["entity_1_slot"]["values"][0] == entity.json() def test_slot_no_fill() -> None: @@ -68,14 +67,10 @@ def test_slot_no_fill() -> None: Here, we will see that an entity will not fill an intent unless the intent has a slot for it. `intent_1` doesn't have a slot for an entity of type `entity_2`. """ - - def access(workflow: Workflow) -> Any: - return (workflow.output[const.INTENTS], workflow.output[const.ENTITIES]) - intent_name = "intent_1" # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). - slot_filler = RuleBasedSlotFillerPlugin(rules=rules, access=access) + slot_filler = RuleBasedSlotFillerPlugin(rules=rules, dest="output.intents") # Create a mock `workflow` workflow = Workflow([slot_filler]) @@ -89,17 +84,18 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_2", + entity_type="entity_2", values=[{"key": "value"}], ) # The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`. - workflow.output = {const.INTENTS: [intent], const.ENTITIES: [entity]} - output = workflow.run(body) + workflow.set("output.intents", [intent]).set("output.entities", [entity]) + + _, output = workflow.run(Input(utterances=body)) # `workflow.output[0]` is the `Intent` we created. # we can see that the `entity_2_slot` is not filled by our mock entity. 
- assert "entity_1_slot" not in output[const.INTENTS][0].slots + assert "entity_1_slot" not in output[const.INTENTS][0]["slots"] def test_slot_invalid_intent() -> None: @@ -107,14 +103,12 @@ def test_slot_invalid_intent() -> None: Here, we will see that an entity will not fill an intent unless the intent has a slot for it. `intent_1` doesn't have a slot for an entity of type `entity_2`. """ - - def access(workflow: Workflow) -> Any: - return (workflow.output[const.INTENTS], workflow.output[const.ENTITIES]) - intent_name = "intent_1" + # ... a mock `Intent` + intent = Intent(name=intent_name, score=0.8) # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). - slot_filler = RuleBasedSlotFillerPlugin(rules=rules, access=access) + slot_filler = RuleBasedSlotFillerPlugin(rules=rules, dest="output.intents") # Create a mock `workflow` workflow = Workflow([slot_filler]) @@ -125,17 +119,15 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_2", + entity_type="entity_1", values=[{"key": "value"}], ) # The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`. - workflow.output = {const.INTENTS: [1], const.ENTITIES: [entity]} - output = workflow.run(body) + workflow.set("output.intents", [1]).set("output.entities", [entity]) - # `workflow.output[0]` is the `Intent` we created. - # we can see that the `entity_2_slot` is not filled by our mock entity. - assert output[const.INTENTS] == [1] + with pytest.raises(AttributeError): + workflow.run(Input(utterances=body)) def test_slot_invalid_intents() -> None: @@ -150,7 +142,7 @@ def access(workflow: Workflow) -> Any: intent_name = "intent_1" # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). - slot_filler = RuleBasedSlotFillerPlugin(rules=rules, access=access) + slot_filler = RuleBasedSlotFillerPlugin(rules=rules, dest="output.intents") # Create a mock `workflow` workflow = Workflow([slot_filler]) @@ -161,13 +153,13 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_2", + entity_type="entity_1", values=[{"key": "value"}], ) # The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`. - workflow.output = {const.INTENTS: [], const.ENTITIES: [entity]} - output = workflow.run(body) + workflow.set("output.intents", []).set("output.entities", [entity]) + _, output = workflow.run(Input(utterances=body)) # `workflow.output[0]` is the `Intent` we created. # we can see that the `entity_2_slot` is not filled by our mock entity. @@ -186,7 +178,7 @@ def access(workflow: Workflow) -> Any: intent_name = "intent_2" # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). 
- slot_filler = RuleBasedSlotFillerPlugin(rules=rules, access=access) + slot_filler = RuleBasedSlotFillerPlugin(rules=rules, dest="output.intents") # Create a mock `workflow` workflow = Workflow([slot_filler]) @@ -200,7 +192,7 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_1", + entity_type="entity_1", values=[{"key": "value"}], ) @@ -208,18 +200,24 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_2", + entity_type="entity_2", values=[{"key": "value"}], ) # The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`. - workflow.output = {const.INTENTS: [intent], const.ENTITIES: [entity_1, entity_2]} - output = workflow.run(body) + workflow.set("output.intents", [intent]).set( + "output.entities", [entity_1, entity_2] + ) + _, output = workflow.run(Input(utterances=body)) # `workflow.output[0]` is the `Intent` we created. # The `entity_1_slot` and `entity_2_slot` are filled. - assert output[const.INTENTS][0].slots["entity_1_slot"].values == [entity_1] - assert output[const.INTENTS][0].slots["entity_2_slot"].values == [entity_2] + assert output[const.INTENTS][0]["slots"]["entity_1_slot"]["values"] == [ + entity_1.json() + ] + assert output[const.INTENTS][0]["slots"]["entity_2_slot"]["values"] == [ + entity_2.json() + ] def test_slot_filling_multiple() -> None: @@ -227,15 +225,11 @@ def test_slot_filling_multiple() -> None: Let's try filling both the slots this time with fill_multiple=True! `intent_2` supports both `entity_1` and `entity_2`. """ - - def access(workflow: Workflow) -> Any: - return (workflow.output[const.INTENTS], workflow.output[const.ENTITIES]) - intent_name = "intent_2" # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). slot_filler = RuleBasedSlotFillerPlugin( - rules=rules, access=access, fill_multiple=True + rules=rules, dest="output.intents", fill_multiple=True ) # Create a mock `workflow` @@ -250,7 +244,7 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_1", + entity_type="entity_1", values=[{"key": "value"}], ) @@ -258,32 +252,36 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_2", + entity_type="entity_2", values=[{"key": "value"}], ) # The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`. - workflow.output = {const.INTENTS: [intent], const.ENTITIES: [entity_1, entity_2]} - output = workflow.run(body) + workflow.set("output.intents", [intent]).set( + "output.entities", [entity_1, entity_2] + ) + _, output = workflow.run(Input(utterances=body)) # `workflow.output[0]` is the `Intent` we created. # The `entity_1_slot` and `entity_2_slot` are filled. - assert output[const.INTENTS][0].slots["entity_1_slot"].values == [entity_1] - assert output[const.INTENTS][0].slots["entity_2_slot"].values == [entity_2] + assert output[const.INTENTS][0]["slots"]["entity_1_slot"]["values"] == [ + entity_1.json() + ] + assert output[const.INTENTS][0]["slots"]["entity_2_slot"]["values"] == [ + entity_2.json() + ] -def test_slot_competition() -> None: +def test_slot_competition_fill_multiple() -> None: """ What happens when we have two entities of the same type but different value? 
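
    Sketched out by hand (the slot is wired up directly here instead of via the plugin and `rules`; entity values are invented):

```python
from dialogy.types.entity import BaseEntity
from dialogy.types.intent import Intent
from dialogy.types.slots import Slot


def make_entity(key: str) -> BaseEntity:
    body = "12th december"
    return BaseEntity(
        range={"from": 0, "to": len(body)},
        body=body,
        dim="default",
        entity_type="entity_1",
        values=[{"key": key}],
    )


intent = Intent(name="intent_1", score=0.8)
# Hand-wire the slot that the plugin would normally derive from `rules`.
intent.slots = {"entity_1_slot": Slot(name="entity_1_slot", types=["entity_1"])}

entity_1, entity_2 = make_entity("value_1"), make_entity("value_2")

# With fill_multiple=True both competitors sit side by side in the slot;
# with fill_multiple=False the filler refuses to pick a winner and the
# slot is dropped from the serialized intent (see the asserts below).
intent.fill_slot(entity_1, fill_multiple=True)
intent.fill_slot(entity_2, fill_multiple=True)
assert intent.slots["entity_1_slot"].values == [entity_1, entity_2]
```
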
""" - - def access(workflow: Workflow) -> Any: - return (workflow.output[const.INTENTS], workflow.output[const.ENTITIES]) - intent_name = "intent_1" # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). - slot_filler = RuleBasedSlotFillerPlugin(rules=rules, access=access) + slot_filler = RuleBasedSlotFillerPlugin( + rules=rules, dest="output.intents", fill_multiple=True + ) # Create a mock `workflow` workflow = Workflow([slot_filler]) @@ -297,7 +295,7 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_1", + entity_type="entity_1", values=[{"key": "value_1"}], ) @@ -305,64 +303,65 @@ def access(workflow: Workflow) -> Any: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="entity_1", + entity_type="entity_1", values=[{"key": "value_2"}], ) - # The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`. - workflow.output = {const.INTENTS: [intent], const.ENTITIES: [entity_1, entity_2]} - output = workflow.run(body) + workflow.set("output.intents", [intent]).set( + "output.entities", [entity_1, entity_2] + ) + _, output = workflow.run(Input(utterances=body)) # `workflow.output[0]` is the `Intent` we created. # The `entity_1_slot` and `entity_2_slot` are filled. - assert "entity_1_slot" not in output[const.INTENTS][0].slots + assert output[const.INTENTS][0]["slots"]["entity_1_slot"]["values"] == [ + entity_1.json(), + entity_2.json(), + ] -def test_incorrect_access_fn() -> None: + +def test_slot_competition_fill_one() -> None: """ - This test shows that the plugin needs `access` function to be a `PluginFn`, - or else it throws a `TypeError`. + What happens when we have two entities of the same type but different value? """ - rules = {"basic": {"slot_name": "basic_slot", "entity_type": "basic"}} - access = 5 + intent_name = "intent_1" + + # Setting up the slot-filler, both instantiation and plugin is created. (notice two calls). + slot_filler = RuleBasedSlotFillerPlugin( + rules=rules, dest="output.intents", fill_multiple=False + ) - slot_filler = RuleBasedSlotFillerPlugin(rules=rules, access=access) + # Create a mock `workflow` workflow = Workflow([slot_filler]) - intent = Intent(name="intent", score=0.8) + # ... a mock `Intent` + intent = Intent(name=intent_name, score=0.8) + + # Here we have two entities which compete for the same slot but have different values. body = "12th december" - entity = BaseEntity( + entity_1 = BaseEntity( range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", - values=[{"key": "value"}], + entity_type="entity_1", + values=[{"key": "value_1"}], ) - workflow.output = {const.INTENTS: [intent], const.ENTITIES: [entity]} - - with pytest.raises(TypeError): - workflow.run("") - - -def test_missing_access_fn() -> None: - """ - This test shows that the plugin needs an `access` provided or else it raises a type error. 
- """ - slot_filler = RuleBasedSlotFillerPlugin(rules=rules) - workflow = Workflow([slot_filler]) - intent = Intent(name="intent", score=0.8) - - body = "12th december" - entity = BaseEntity( + entity_2 = BaseEntity( range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", - values=[{"key": "value"}], + entity_type="entity_1", + values=[{"key": "value_2"}], ) - workflow.output = {const.INTENTS: [intent], const.ENTITIES: [entity]} + workflow.set("output.intents", [intent]).set( + "output.entities", [entity_1, entity_2] + ) + _, output = workflow.run(Input(utterances=body)) + + # `workflow.output[0]` is the `Intent` we created. + # The `entity_1_slot` and `entity_2_slot` are filled. - with pytest.raises(TypeError): - workflow.run("") + assert "entity_1_slot" not in output[const.INTENTS][0]["slots"] diff --git a/tests/plugin/text/test_calibration/test_calibration_filter.py b/tests/plugin/text/test_calibration/test_calibration_filter.py index b7ba8e4d..996d631a 100644 --- a/tests/plugin/text/test_calibration/test_calibration_filter.py +++ b/tests/plugin/text/test_calibration/test_calibration_filter.py @@ -6,13 +6,12 @@ import numpy as np import pandas as pd -import pytest from scipy import sparse -from dialogy import constants as const +from dialogy.base import Input, Output from dialogy.plugins.text.calibration.xgb import CalibrationModel -from dialogy.workflow.workflow import Workflow -from tests import EXCEPTIONS, load_tests +from dialogy.types import utterances +from tests import load_tests json_data = load_tests("df", __file__, ext=".json") df = pd.DataFrame(json_data, columns=["conv_id", "data", "tag", "value", "time"]) @@ -47,8 +46,7 @@ def predict(self, X): classifier = MyClassifier() calibration_model = CalibrationModel( - access=access, - mutate=mutate, + dest="input.transcripts", threshold=float("inf"), input_column="data", model_name="temp.pkl", @@ -102,21 +100,20 @@ def test_calibration_model_validation(): def test_calibration_model_utility(): - assert calibration_model.utility( - [[{"transcript": "hello", "am_score": -100, "lm_score": -200}]] - ) == ["hello"] + input_ = Input( + utterances=[[{"transcript": "hello", "am_score": -100, "lm_score": -200}]] + ) + assert calibration_model.utility(input_, Output()) == ["hello"] calibration_model.threshold = float("inf") - assert ( - calibration_model.utility( + input_ = Input( + utterances=[ [ - [ - { - "transcript": "hello world hello world", - "am_score": -100, - "lm_score": -200, - } - ] + { + "transcript": "hello world hello world", + "am_score": -100, + "lm_score": -200, + } ] - ) - == ["hello world hello world"] + ] ) + assert calibration_model.utility(input_, Output()) == ["hello world hello world"] diff --git a/tests/plugin/text/test_duckling_lb_plugin/test_cases.yaml b/tests/plugin/text/test_duckling_lb_plugin/test_cases.yaml index 1043ae07..d7ac3168 100644 --- a/tests/plugin/text/test_duckling_lb_plugin/test_cases.yaml +++ b/tests/plugin/text/test_duckling_lb_plugin/test_cases.yaml @@ -15,7 +15,7 @@ 'end': 15, 'dim': 'time', 'latent': False}] - expected: [{"entity": "TimeEntity"}] + expected: [{"entity_type": "date"}] - description: "test time entity." input: ["27th next month", "28th next month"] @@ -45,7 +45,7 @@ 'end': 15, 'dim': 'time', 'latent': False}] - expected: [{"entity": "TimeEntity"}] + expected: [{"entity_type": "date"}] - description: "test time entity." 
input: "" @@ -96,7 +96,7 @@ dimensions: ["people", "time", "date", "duration"] locale: "en_IN" timezone: "Asia/Kolkata" - expected: [{"entity": "PeopleEntity"}, {"entity": "TimeEntity"}, {"entity": "TimeEntity"}] + expected: [{"entity_type": "people"}, {"entity_type": "date"}, {"entity_type": "time"}] - description: "numerical entity test" input: "4 items" @@ -110,4 +110,4 @@ 'end': 4, 'dim': 'number', 'latent': False}] - expected: [{"entity": "NumericalEntity"}] \ No newline at end of file + expected: [{"entity_type": "number"}] diff --git a/tests/plugin/text/test_duckling_lb_plugin/test_duckling_lb_plugin.py b/tests/plugin/text/test_duckling_lb_plugin/test_duckling_lb_plugin.py index 7f729d92..e4dc68bc 100644 --- a/tests/plugin/text/test_duckling_lb_plugin/test_duckling_lb_plugin.py +++ b/tests/plugin/text/test_duckling_lb_plugin/test_duckling_lb_plugin.py @@ -3,6 +3,7 @@ import httpretty import pytest +from dialogy.base import Input, Output from dialogy.plugins import DucklingPluginLB from dialogy.workflow import Workflow from tests import EXCEPTIONS, load_tests, request_builder @@ -24,13 +25,7 @@ def test_plugin_working_cases(payload) -> None: reference_time = payload.get("reference_time") use_latent = payload.get("use_latent") - def access(workflow): - return workflow.input, reference_time, locale, use_latent - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - - duckling_plugin = DucklingPluginLB(access=access, mutate=mutate, **duckling_args) + duckling_plugin = DucklingPluginLB(dest="output.entities", **duckling_args) request_callback = request_builder(mock_entity_json, response_code=response_code) httpretty.register_uri( @@ -40,15 +35,17 @@ def mutate(workflow, entities): workflow = Workflow([duckling_plugin]) if expected_types is not None: - output = workflow.run(body) - module = importlib.import_module("dialogy.types.entity") - - if not output["entities"]: - assert output["entities"] == [] + _, output = workflow.run( + Input( + utterances=body, + locale=locale, + reference_time=reference_time, + latent_entities=use_latent, + ) + ) for i, entity in enumerate(output["entities"]): - class_name = expected_types[i]["entity"] - assert isinstance(entity, getattr(module, class_name)) + assert entity["entity_type"] == expected_types[i]["entity_type"] else: with pytest.raises(EXCEPTIONS[exception]): workflow.run(body) diff --git a/tests/plugin/text/test_duckling_plugin/test_cases.yaml b/tests/plugin/text/test_duckling_plugin/test_cases.yaml index fbd4af65..e0b537bb 100644 --- a/tests/plugin/text/test_duckling_plugin/test_cases.yaml +++ b/tests/plugin/text/test_duckling_plugin/test_cases.yaml @@ -38,7 +38,7 @@ dimensions: ["people", "time", "date", "duration"] locale: "en_IN" timezone: "Asia/Kolkata" - expected: [{"entity": "PeopleEntity"}, {"entity": "TimeEntity"}, {"entity": "TimeEntity"}] + expected: [{"entity_type": "people"}, {"entity_type": "date"}, {"entity_type": "time"}] - description: "test time interval." 
input: "between 2 to 4 am" @@ -84,7 +84,7 @@ 'end': 4, 'dim': 'number', 'latent': False}] - expected: [{"entity": "NumericalEntity"}] + expected: [{"entity_type": "number"}] - description: "numerical entity test with a list of strings" input: ["4 items", "four items"] @@ -98,7 +98,7 @@ 'end': 4, 'dim': 'number', 'latent': False}] - expected: [{"entity": "NumericalEntity"}, {"entity": "NumericalEntity"}] + expected: [{"entity_type": "number"}, {"entity_type": "number"}] - description: "plastic money entity" input: ["my card 4111111111111111"] @@ -112,7 +112,7 @@ 'end': 24, 'dim': 'credit-card-number', 'latent': False}] - expected: [{"entity": "PlasticCurrencyEntity"}] + expected: [{"entity_type": "credit-card-number"}] - description: "test time entity." input: "27th next month" @@ -131,7 +131,7 @@ 'end': 15, 'dim': 'time', 'latent': False}] - expected: [{"entity": "TimeEntity"}] + expected: [{"entity_type": "date"}] - description: "test time interval." input: "between 2 to 4 am" @@ -157,7 +157,7 @@ 'end': 17, 'dim': 'time', 'latent': False}] - expected: [{"entity": "TimeIntervalEntity"}] + expected: [{"entity_type": "time"}] - description: "Test time interval entity having inconsistent values." input: "11 से 13 तक 11 तारीख" @@ -182,7 +182,7 @@ 'end': 20, 'dim': 'time', 'latent': False}] - expected: [{"entity": "TimeIntervalEntity"}] + expected: [{"entity_type": "datetime"}] - description: "time interval entity with only from value." input: "from 4 am" @@ -204,7 +204,7 @@ 'end': 9, 'dim': 'time', 'latent': False}] - expected: [{"entity": "TimeIntervalEntity"}] + expected: [{"entity_type": "time"}] - description: "time interval entity with only to value" input: "till 2 pm" @@ -226,7 +226,7 @@ 'end': 9, 'dim': 'time', 'latent': False}] - expected: [{"entity":"TimeIntervalEntity"}] + expected: [{"entity_type":"time"}] - description: "duration entity" input: "2 hours" @@ -244,7 +244,7 @@ 'end': 7, 'dim': 'duration', 'latent': False}] - expected: [{"entity": "DurationEntity"}] + expected: [{"entity_type": "duration"}] - description: "no entity found." input: "there is no spoon" @@ -371,7 +371,7 @@ 'end': 7, 'dim': 'time', 'latent': False}] - expected: [{"entity": "TimeEntity"}] + expected: [{"entity_type": "date"}] - description: "Filter out future dates only." input: "" @@ -415,19 +415,19 @@ } }] reference_time: 1622640071000 - expected: [{"entity": "TimeEntity"}] + expected: [{"entity_type": "time"}] -- description: "Filter out future dates only but preserve other entities." +- description: "Filter out future dates using 'ge' (greater than or equals) filter." input: "" duckling: locale: "hi_IN" - dimensions: ["time", "date", "numer"] + dimensions: ["time", "date"] timezone: "Asia/Kolkata" - datetime_filters: "future" + datetime_filters: "ge" debug: False mock_entity_json: [{ - 'body': 'परसो शाम 5 बजे 2 प्लेट चहिये', + 'body': 'परसो शाम 5 बजे', 'dim': 'time', 'latent': False, 'end': 23, @@ -458,29 +458,21 @@ 'type': 'value', 'value': '2021-06-04T17:00:00.000+05:30' } - }, { - 'body': '2', - 'start': 15, - 'value': {'value': 2, 'type': 'value'}, - 'end': 16, - 'dim': 'number', - 'latent': False }] reference_time: 1622640071000 - expected: [{"entity": "TimeEntity"}, {"entity": "NumericalEntity"}] + expected: [{"entity_type": "time"}] - -- description: "Exceptions due to incorrect datetime filter value." +- description: "Filter out future dates only but preserve other entities." 
input: "" duckling: locale: "hi_IN" - dimensions: ["time", "date"] + dimensions: ["time", "date", "numer"] timezone: "Asia/Kolkata" - datetime_filters: "latest" + datetime_filters: "future" debug: False mock_entity_json: [{ - 'body': 'परसो शाम 5 बजे', + 'body': 'परसो शाम 5 बजे 2 प्लेट चहिये', 'dim': 'time', 'latent': False, 'end': 23, @@ -511,17 +503,25 @@ 'type': 'value', 'value': '2021-06-04T17:00:00.000+05:30' } + }, { + 'body': '2', + 'start': 15, + 'value': {'value': 2, 'type': 'value'}, + 'end': 16, + 'dim': 'number', + 'latent': False }] reference_time: 1622640071000 - exception: "ValueError" + expected: [{"entity_type": "time"}, {"entity_type": "number"}] -- description: "Exceptions due to incorrect datetime filter type." + +- description: "Exceptions due to incorrect datetime filter value." input: "" duckling: locale: "hi_IN" dimensions: ["time", "date"] timezone: "Asia/Kolkata" - datetime_filters: 1425 + datetime_filters: 116 debug: False mock_entity_json: [{ 'body': 'परसो शाम 5 बजे', @@ -559,13 +559,13 @@ reference_time: 1622640071000 exception: "TypeError" -- description: "Exceptions due to incorrect reftime." +- description: "Exceptions due to incorrect datetime filter type." input: "" duckling: locale: "hi_IN" dimensions: ["time", "date"] timezone: "Asia/Kolkata" - datetime_filters: "latest" + datetime_filters: 1425 debug: False mock_entity_json: [{ 'body': 'परसो शाम 5 बजे', @@ -600,16 +600,16 @@ 'value': '2021-06-04T17:00:00.000+05:30' } }] - reference_time: '2021-06-02T18:51:11.000+05:30' + reference_time: '2021-06-04T17:00:00.000+05:30' exception: "TypeError" -- description: "Exceptions due to incorrect date filter type." +- description: "Exceptions due to incorrect reftime." input: "" duckling: locale: "hi_IN" dimensions: ["time", "date"] timezone: "Asia/Kolkata" - datetime_filters: 12 + datetime_filters: "future" debug: False mock_entity_json: [{ 'body': 'परसो शाम 5 बजे', @@ -644,7 +644,7 @@ 'value': '2021-06-04T17:00:00.000+05:30' } }] - reference_time: '2021-06-02T18:51:11.000+05:30' + reference_time: 4854163.585 exception: "TypeError" - description: "Testing aggregation on ASR transcripts." @@ -753,10 +753,10 @@ 'dim': 'time', 'latent': False}] reference_time: 1622640071000 - expected: [{"entity": "NumericalEntity"}, - {"entity": "TimeEntity"}, - {"entity": "TimeEntity"}, - {"entity": "TimeEntity"}] + expected: [{"entity": "number"}, + {"entity_type": "time"}, + {"entity_type": "time"}, + {"entity_type": "time"}] - description: "Testing currency entity." 
input: "300 dollars" @@ -772,7 +772,7 @@ 'dim': 'amount-of-money', 'latent': False}] reference_time: 1622640071000 - expected: [{"entity": "CurrencyEntity"}] + expected: [{"entity_type": "amount-of-money"}] - description: "Latent entity present" input: "on 2nd" @@ -799,4 +799,4 @@ 'end': 6, 'dim': 'time', 'latent': True}] - expected: [{"entity": "TimeEntity"}] + expected: [{"entity_type": "date"}] diff --git a/tests/plugin/text/test_duckling_plugin/test_duckling_plugin.py b/tests/plugin/text/test_duckling_plugin/test_duckling_plugin.py index c505eab4..b6935baa 100644 --- a/tests/plugin/text/test_duckling_plugin/test_duckling_plugin.py +++ b/tests/plugin/text/test_duckling_plugin/test_duckling_plugin.py @@ -1,13 +1,14 @@ -import importlib +import operator import time import httpretty import pandas as pd import pytest -import requests +from dialogy.base import Input from dialogy.plugins import DucklingPlugin from dialogy.types import BaseEntity, KeywordEntity, TimeEntity +from dialogy.utils import make_unix_ts from dialogy.workflow import Workflow from tests import EXCEPTIONS, load_tests, request_builder @@ -27,49 +28,6 @@ def test_plugin_with_custom_entity_map() -> None: assert duckling_plugin.dimension_entity_map["number"]["value"] == BaseEntity -def test_plugin_io_missing() -> None: - """ - Here we are checking if the plugin has access to workflow. - Since we haven't provided `access`, `mutate` to `DucklingPlugin` - we will receive a `TypeError`. - """ - duckling_plugin = DucklingPlugin( - locale="en_IN", timezone="Asia/Kolkata", dimensions=["time"] - ) - - workflow = Workflow([duckling_plugin]) - with pytest.raises(TypeError): - workflow.run("") - - -# == Test invalid i/o == -@pytest.mark.parametrize( - "access,mutate", - [ - (1, 1), - (lambda x: x, 1), - (1, lambda x: x), - ], -) -def test_plugin_io_type_mismatch(access, mutate) -> None: - """ - Here we are chcking if the plugin has access to workflow. - Since we have provided `access`, `mutate` of incorrect types to `DucklingPlugin` - we will receive a `TypeError`. 
- """ - duckling_plugin = DucklingPlugin( - access=access, - mutate=mutate, - locale="en_IN", - dimensions=["time"], - timezone="Asia/Kolkata", - ) - - workflow = Workflow([duckling_plugin]) - with pytest.raises(TypeError): - workflow.run("") - - def test_remove_low_scoring_entities_works_only_if_threshold_is_not_none(): duckling_plugin = DucklingPlugin( locale="en_IN", @@ -91,6 +49,41 @@ def test_remove_low_scoring_entities_works_only_if_threshold_is_not_none(): assert duckling_plugin.remove_low_scoring_entities([entity]) == [entity] +def test_duckling_get_operator_happy_case(): + duckling_plugin = DucklingPlugin( + locale="en_IN", + dimensions=["time"], + timezone="Asia/Kolkata", + threshold=0.2, + datetime_filters="future", + ) + assert duckling_plugin.get_operator("lt") == operator.lt + + +def test_duckling_get_operator_exception(): + duckling_plugin = DucklingPlugin( + locale="en_IN", + dimensions=["time"], + timezone="Asia/Kolkata", + threshold=0.2, + datetime_filters="future", + ) + with pytest.raises(ValueError): + duckling_plugin.get_operator("invalid") + + +def test_duckling_reftime(): + duckling_plugin = DucklingPlugin( + locale="en_IN", + dimensions=["time"], + timezone="Asia/Kolkata", + threshold=0.2, + datetime_filters="future", + ) + with pytest.raises(TypeError): + duckling_plugin.validate("test", None) + + def test_remove_low_scoring_entities_doesnt_remove_unscored_entities(): duckling_plugin = DucklingPlugin( locale="en_IN", dimensions=["time"], timezone="Asia/Kolkata", threshold=0.2 @@ -104,6 +97,7 @@ def test_remove_low_scoring_entities_doesnt_remove_unscored_entities(): type="basic", dim="default", values=[], + score=0.0, ) entity_B = BaseEntity( range={"from": 0, "to": len(body)}, @@ -115,7 +109,6 @@ def test_remove_low_scoring_entities_doesnt_remove_unscored_entities(): ) assert duckling_plugin.remove_low_scoring_entities([entity_A, entity_B]) == [ - entity_A, entity_B, ] @@ -135,29 +128,22 @@ def raise_timeout(_, __, headers): time.sleep(wait_time) return 200, headers, "received" - def access(workflow): - return workflow.input, None, locale, False - - def mutate(workflow, entities): - workflow.output = {"entities": entities} + httpretty.register_uri( + httpretty.POST, "http://0.0.0.0:8000/parse", body=raise_timeout + ) duckling_plugin = DucklingPlugin( locale=locale, dimensions=["time"], timezone="Asia/Kolkata", - access=access, - mutate=mutate, threshold=0.2, timeout=0.01, - ) - - httpretty.register_uri( - httpretty.POST, "http://0.0.0.0:8000/parse", body=raise_timeout + dest="output.entities", ) workflow = Workflow([duckling_plugin]) - workflow.run("test") - assert workflow.output["entities"] == [] + _, output = workflow.run(Input(utterances="test")) + assert output["entities"] == [] def test_duckling_connection_error() -> None: @@ -169,25 +155,18 @@ def test_duckling_connection_error() -> None: """ locale = "en_IN" - def access(workflow): - return workflow.input, None, locale, False - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - duckling_plugin = DucklingPlugin( locale=locale, dimensions=["time"], timezone="Asia/Kolkata", - access=access, - mutate=mutate, + dest="output.entities", threshold=0.2, timeout=0.01, url="https://duckling/parse", ) workflow = Workflow([duckling_plugin]) - output = workflow.run("test") + _, output = workflow.run(Input(utterances="test", locale=locale)) assert output["entities"] == [] @@ -206,11 +185,8 @@ def test_max_workers_greater_than_zero() -> None: """ locale = "en_IN" - def access(workflow): - 
return workflow.input, None, locale, False - duckling_plugin = DucklingPlugin( - access=access, + dest="output.entities", dimensions=["time"], timezone="Asia/Kolkata", url="https://duckling/parse", @@ -219,7 +195,7 @@ def access(workflow): workflow = Workflow([duckling_plugin]) alternatives = [] # When ASR returns empty transcriptions. try: - workflow.run(alternatives) + workflow.run(Input(utterances=alternatives, locale=locale)) except ValueError as exc: pytest.fail(f"{exc}") @@ -240,13 +216,7 @@ def test_plugin_working_cases(payload) -> None: reference_time = payload.get("reference_time") use_latent = payload.get("use_latent") - def access(workflow): - return workflow.input, reference_time, locale, use_latent - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - - duckling_plugin = DucklingPlugin(access=access, mutate=mutate, **duckling_args) + duckling_plugin = DucklingPlugin(dest="output.entities", **duckling_args) request_callback = request_builder(mock_entity_json, response_code=response_code) httpretty.register_uri( @@ -254,20 +224,33 @@ def mutate(workflow, entities): ) workflow = Workflow([duckling_plugin]) + if isinstance(reference_time, str): + reference_time = make_unix_ts("Asia/Kolkata")(reference_time) if expected_types is not None: - output = workflow.run(body) - module = importlib.import_module("dialogy.types.entity") + input_ = Input( + utterances=body, + locale=locale, + reference_time=reference_time, + latent_entities=use_latent, + ) + _, output = workflow.run(input_) if not output["entities"]: assert output["entities"] == [] for i, entity in enumerate(output["entities"]): - class_name = expected_types[i]["entity"] - assert isinstance(entity, getattr(module, class_name)) + expected_entity_type = expected_types[i]["entity_type"] + assert entity["entity_type"] == expected_entity_type else: with pytest.raises(EXCEPTIONS[exception]): - workflow.run(body) + input_ = Input( + utterances=body, + locale=locale, + reference_time=reference_time, + latent_entities=use_latent, + ) + workflow.run(input_) @httpretty.activate @@ -298,12 +281,6 @@ def test_plugin_no_transform(): httpretty.POST, "http://0.0.0.0:8000/parse", body=request_callback ) - def access(workflow): - return workflow.input, None, "en_IN", False - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - df = pd.DataFrame( [ { @@ -322,8 +299,7 @@ def mutate(workflow, entities): locale="en_IN", dimensions=["time"], timezone="Asia/Kolkata", - access=access, - mutate=mutate, + dest="output.entities", threshold=0.2, timeout=0.01, use_transform=False, @@ -331,19 +307,6 @@ def mutate(workflow, entities): output_column="entities", ) - today = TimeEntity( - type="date", - body="today", - parsers=["DucklingPlugin"], - range={"start": 0, "end": 5}, - score=1.0, - alternative_index=0, - latent=False, - value="2021-09-14T00:00:00.000+05:30", - origin="value", - grain="day", - ) - df_ = duckling_plugin.transform(df) assert "entities" not in df_.columns @@ -376,12 +339,6 @@ def test_plugin_transform(): httpretty.POST, "http://0.0.0.0:8000/parse", body=request_callback ) - def access(workflow): - return workflow.input, None, "en_IN", False - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - df = pd.DataFrame( [ { @@ -400,8 +357,7 @@ def mutate(workflow, entities): locale="en_IN", dimensions=["time"], timezone="Asia/Kolkata", - access=access, - mutate=mutate, + dest="output.entities", threshold=0.1, timeout=0.01, use_transform=True, @@ -410,7 +366,7 @@ def 
mutate(workflow, entities): ) today = TimeEntity( - type="date", + entity_type="date", body="today", parsers=["DucklingPlugin"], range={"start": 0, "end": 5}, @@ -456,12 +412,6 @@ def test_plugin_transform_type_error(): httpretty.POST, "http://0.0.0.0:8000/parse", body=request_callback ) - def access(workflow): - return workflow.input, None, "en_IN" - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - df = pd.DataFrame( [ { @@ -476,8 +426,7 @@ def mutate(workflow, entities): locale="en_IN", dimensions=["time"], timezone="Asia/Kolkata", - access=access, - mutate=mutate, + dest="output.entities", threshold=0.2, timeout=0.01, use_transform=True, @@ -517,12 +466,6 @@ def test_plugin_transform_existing_entity(): httpretty.POST, "http://0.0.0.0:8000/parse", body=request_callback ) - def access_fn(workflow): - return workflow.input, None, "en_IN", False - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - df = pd.DataFrame( [ { @@ -536,7 +479,7 @@ def mutate(workflow, entities): KeywordEntity( range={"start": 0, "end": 0}, value="apple", - type="fruits", + entity_type="fruits", body="apple", ) ], @@ -548,7 +491,7 @@ def mutate(workflow, entities): KeywordEntity( range={"start": 0, "end": 0}, value="apple", - type="fruits", + entity_type="fruits", body="apple", ) ], @@ -561,8 +504,7 @@ def mutate(workflow, entities): locale="en_IN", dimensions=["time"], timezone="Asia/Kolkata", - access=access_fn, - mutate=mutate, + dest="output.entities", threshold=0.2, timeout=0.01, use_transform=True, @@ -572,7 +514,7 @@ def mutate(workflow, entities): df_ = duckling_plugin.transform(df) today = TimeEntity( - type="date", + entity_type="date", body="today", parsers=["DucklingPlugin"], range={"start": 0, "end": 5}, diff --git a/tests/plugin/text/test_list_entity_plugin/test_list_entity_plugin.py b/tests/plugin/text/test_list_entity_plugin/test_list_entity_plugin.py index ba184dcb..2b09dff8 100644 --- a/tests/plugin/text/test_list_entity_plugin/test_list_entity_plugin.py +++ b/tests/plugin/text/test_list_entity_plugin/test_list_entity_plugin.py @@ -1,6 +1,7 @@ import pandas as pd import pytest +from dialogy.base import Input, Output from dialogy.plugins import ListEntityPlugin from dialogy.types import KeywordEntity from dialogy.workflow import Workflow @@ -33,17 +34,14 @@ def mutate(w, v): def test_value_error_if_incorrect_style(): with pytest.raises(ValueError): - l = ListEntityPlugin( - access=lambda w: (w.input,), mutate=mutate, style="unknown" - ) + l = ListEntityPlugin(dest="output.entities", style="unknown") l._parse({"location": ["..."]}) def test_value_error_if_spacy_missing(): with pytest.raises(ValueError): l = ListEntityPlugin( - access=lambda w: (w.input,), - mutate=mutate, + dest="output.entities", style="spacy", spacy_nlp=None, ) @@ -53,8 +51,7 @@ def test_value_error_if_spacy_missing(): def test_type_error_if_compiled_patterns_missing(): with pytest.raises(TypeError): l = ListEntityPlugin( - access=lambda w: (w.input,), - mutate=mutate, + dest="output.entities", style="spacy", spacy_nlp=None, ) @@ -63,8 +60,7 @@ def test_type_error_if_compiled_patterns_missing(): def test_entity_extractor_transform(): entity_extractor = ListEntityPlugin( - access=lambda x: x, - mutate=lambda y: y, + dest="output.entities", input_column="data", output_column="entities", use_transform=True, @@ -100,8 +96,7 @@ def test_entity_extractor_transform(): def test_entity_extractor_no_transform(): entity_extractor = ListEntityPlugin( - access=lambda x: x, - 
mutate=lambda y: y, + dest="output.entities", input_column="data", output_column="entities", use_transform=False, @@ -134,8 +129,7 @@ def test_entity_extractor_no_transform(): def test_entity_extractor_transform_no_existing_entity(): entity_extractor = ListEntityPlugin( - access=lambda x: x, - mutate=lambda y: y, + dest="output.entities", input_column="data", output_column="entities", use_transform=True, @@ -174,29 +168,27 @@ def test_get_list_entities(payload): if expected: list_entity_plugin = ListEntityPlugin( - access=lambda w: (w.input,), mutate=mutate, spacy_nlp=spacy_mocker, **config + dest="output.entities", spacy_nlp=spacy_mocker, **config ) workflow = Workflow([list_entity_plugin]) - output = workflow.run(input_=transcripts) - entities = output + print(transcripts) + _, output = workflow.run(input_=Input(utterances=transcripts)) + entities = output["entities"] if not entities and expected: pytest.fail("No entities found!") for i, entity in enumerate(entities): - assert entity.value == expected[i]["value"] - assert entity.type == expected[i]["type"] + assert entity["value"] == expected[i]["value"] + assert entity["type"] == expected[i]["type"] if "score" in expected[i]: - assert entity.score == expected[i]["score"] + assert entity["score"] == expected[i]["score"] else: with pytest.raises(EXCEPTIONS.get(exception)): list_entity_plugin = ListEntityPlugin( - access=lambda w: (w.input,), - mutate=mutate, - spacy_nlp=spacy_mocker, - **config + dest="output.entities", spacy_nlp=spacy_mocker, **config ) workflow = Workflow([list_entity_plugin]) - workflow.run(input_=input_) + _, output = workflow.run(input_=Input(utterances=transcripts)) diff --git a/tests/plugin/text/test_list_search_plugin/test_list_search_plugin.py b/tests/plugin/text/test_list_search_plugin/test_list_search_plugin.py index 420668bd..994873f9 100644 --- a/tests/plugin/text/test_list_search_plugin/test_list_search_plugin.py +++ b/tests/plugin/text/test_list_search_plugin/test_list_search_plugin.py @@ -1,8 +1,8 @@ import pandas as pd import pytest +from dialogy.base import Input, Output from dialogy.plugins import ListSearchPlugin -from dialogy.types import KeywordEntity from dialogy.workflow import Workflow from tests import EXCEPTIONS, load_tests @@ -20,29 +20,23 @@ def __call__(self, transcript): return self -def mutate(w, v): - w.output = v - - def test_not_supported_lang(): with pytest.raises(ValueError): l = ListSearchPlugin( - access=lambda w: (w.input,), - mutate=mutate, + dest="output.entities", fuzzy_threshold=0.3, fuzzy_dp_config={"te": {"channel": {"hello": "hello"}}}, ) - l.utility(".........", "te") + l.get_entities(["........."], "te") def test_entity_not_found(): l = ListSearchPlugin( - access=lambda w: (w.input,), - mutate=mutate, + dest="output.entities", fuzzy_threshold=0.4, fuzzy_dp_config={"en": {"location": {"delhi": "Delhi"}}}, ) - assert l.utility(["I live in punjab"], "en") == [] + assert l.get_entities(["I live in punjab"], "en") == [] @pytest.mark.parametrize("payload", load_tests("cases", __file__)) @@ -55,29 +49,23 @@ def test_get_list_entities(payload): transcripts = [expectation["text"] for expectation in input_] if expected: - list_entity_plugin = ListSearchPlugin( - access=lambda w: (w.input["alternatives"], w.input["lang"]), - mutate=mutate, - **config - ) + list_entity_plugin = ListSearchPlugin(dest="output.entities", **config) workflow = Workflow([list_entity_plugin]) - output = workflow.run(input_={"alternatives": transcripts, "lang": lang_}) - entities = output + _, output = 
workflow.run(Input(utterances=transcripts, lang=lang_)) + entities = output["entities"] if not entities and expected: pytest.fail("No entities found!") for i, entity in enumerate(entities): - assert entity.value == expected[i]["value"] - assert entity.type == expected[i]["type"] + assert entity["value"] == expected[i]["value"] + assert entity["type"] == expected[i]["type"] if "score" in expected[i]: - assert entity.score == expected[i]["score"] + assert entity["score"] == expected[i]["score"] else: with pytest.raises(EXCEPTIONS.get(exception)): - list_entity_plugin = ListSearchPlugin( - access=lambda w: (w.input,), mutate=mutate, **config - ) + list_entity_plugin = ListSearchPlugin(dest="output.entities", **config) workflow = Workflow([list_entity_plugin]) - workflow.run(input_=input_) + workflow.run(Input(utterances=transcripts, lang=lang_)) diff --git a/tests/plugin/text/test_merge_asr_output.py b/tests/plugin/text/test_merge_asr_output.py index 8bc3e635..e01cc499 100644 --- a/tests/plugin/text/test_merge_asr_output.py +++ b/tests/plugin/text/test_merge_asr_output.py @@ -3,20 +3,12 @@ import pandas as pd import pytest +from dialogy.base import Input from dialogy.plugins import MergeASROutputPlugin from dialogy.workflow import Workflow - -def access(workflow): - return workflow.input - - -def mutate(workflow, value): - workflow.output = value - - merge_asr_output_plugin = MergeASROutputPlugin( - access=access, mutate=mutate, use_transform=True, input_column="data" + dest="input.clf_feature", use_transform=True, input_column="data" ) @@ -26,9 +18,10 @@ def test_merge_asr_output() -> None: """ workflow = Workflow([merge_asr_output_plugin]) + input_ = Input(utterances=[[{"transcript": "hello world", "confidence": None}]]) - output = workflow.run([[{"transcript": "hello world", "confidence": None}]]) - assert output == [" hello world "] + input_, _ = workflow.run(input_) + assert input_["clf_feature"] == [" hello world "] def test_merge_longer_asr_output() -> None: @@ -36,9 +29,8 @@ def test_merge_longer_asr_output() -> None: This case shows the merge in case there are multiple options. """ workflow = Workflow([merge_asr_output_plugin]) - - output = workflow.run( - [ + input_ = Input( + utterances=[ [ {"transcript": "hello world", "confidence": None}, {"transcript": "hello word", "confidence": None}, @@ -46,7 +38,11 @@ def test_merge_longer_asr_output() -> None: ] ] ) - assert output == [" hello world hello word jello world "] + + input_, _ = workflow.run(input_) + assert input_["clf_feature"] == [ + " hello world hello word jello world " + ] def test_merge_keyerror_on_missing_transcript() -> None: @@ -56,9 +52,10 @@ def test_merge_keyerror_on_missing_transcript() -> None: """ workflow = Workflow([merge_asr_output_plugin]) + input_ = Input(utterances=[[{"not_transcript": "hello world", "confidence": None}]]) with pytest.raises(TypeError): - workflow.run([[{"not_transcript": "hello world", "confidence": None}]]) + workflow.run(input_) def test_invalid_data() -> None: diff --git a/tests/plugin/text/voting/test_intent_voting.py b/tests/plugin/text/voting/test_intent_voting.py index 477ce431..5f1fba58 100644 --- a/tests/plugin/text/voting/test_intent_voting.py +++ b/tests/plugin/text/voting/test_intent_voting.py @@ -1,11 +1,12 @@ """ This is a tutorial for understanding the use of `VotePlugin`. 
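Under the reworked plugin API, VotePlugin no longer takes access/mutate
callables; it writes the winning intent to the workflow path named by `dest`.
A minimal sketch of the pattern every test below follows:

    vote_plugin = VotePlugin(dest="output.intents")
    workflow = Workflow([vote_plugin])
    workflow.output = Output(intents=intents)  # seed the candidate intents
    _, output = workflow.run(Input(utterances=["some text"]))
    output["intents"][0]["name"]  # name of the intent that won the vote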
""" -from typing import Any, List +from typing import List import pytest from dialogy import constants as const +from dialogy.base import Input, Output from dialogy.plugins import VotePlugin from dialogy.types.intent import Intent from dialogy.workflow import Workflow @@ -22,11 +23,11 @@ def test_voting_0_intents(): have a test to see if it takes care of division 0. """ intents: List[Intent] = [] - vote_plugin = VotePlugin(access=lambda w: (w.output[0], 0), mutate=update_intent) + vote_plugin = VotePlugin(dest="output.intents") workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - intent, _ = workflow.run(input_="") - assert intent.name == const.S_INTENT_OOS + workflow.output = Output(intents=intents) + _, output = workflow.run(Input(utterances=["some text"])) + assert output["intents"][0]["name"] == const.S_INTENT_OOS def test_voting_n_intents(): @@ -40,12 +41,13 @@ def test_voting_n_intents(): Intent(name="a", score=1), ] vote_plugin = VotePlugin( - debug=False, access=lambda w: (w.output[0], len(intents)), mutate=update_intent + debug=False, + dest="output.intents", ) workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - intent, _ = workflow.run(input_="") - assert intent.name == "a" + workflow.output = Output(intents=intents) + _, output = workflow.run(Input(utterances=["some text"])) + assert output["intents"][0]["name"] == "a" def test_voting_on_conflicts(): @@ -58,13 +60,11 @@ def test_voting_on_conflicts(): Intent(name="b", score=1), Intent(name="b", score=1), ] - vote_plugin = VotePlugin( - access=lambda w: (w.output[0], len(intents)), mutate=update_intent - ) + vote_plugin = VotePlugin(dest="output.intents") workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - intent, _ = workflow.run(input_="") - assert intent.name == "_oos_" + workflow.output = Output(intents=intents) + _, output = workflow.run(Input(utterances=["some text"])) + assert output["intents"][0]["name"] == "_oos_" def test_voting_on_weak_signals(): @@ -77,43 +77,11 @@ def test_voting_on_weak_signals(): Intent(name="b", score=0.1), Intent(name="b", score=0.1), ] - vote_plugin = VotePlugin( - access=lambda w: (w.output[0], len(intents)), mutate=update_intent - ) - workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - intent, _ = workflow.run(input_="") - assert intent.name == "_oos_" - - -def test_missing_access(): - intents = [ - Intent(name="a", score=0.3), - Intent(name="a", score=0.2), - Intent(name="b", score=0.1), - Intent(name="b", score=0.1), - ] - - vote_plugin = VotePlugin(mutate=update_intent) - workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - with pytest.raises(TypeError): - intent, _ = workflow.run(input_="") - - -def test_missing_mutate(): - intents = [ - Intent(name="a", score=0.3), - Intent(name="a", score=0.2), - Intent(name="b", score=0.1), - Intent(name="b", score=0.1), - ] - - vote_plugin = VotePlugin(access=lambda w: w.output[0]) + vote_plugin = VotePlugin(dest="output.intents") workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - with pytest.raises(TypeError): - intent, _ = workflow.run(input_="") + workflow.output = Output(intents=intents) + _, output = workflow.run(Input(utterances=["some text"])) + assert output["intents"][0]["name"] == "_oos_" def test_representation_oos(): @@ -125,13 +93,11 @@ def test_representation_oos(): Intent(name="d", score=0.44), ] - vote_plugin = VotePlugin( - access=lambda w: (w.output[0], len(intents)), mutate=update_intent - ) + vote_plugin = 
VotePlugin(dest="output.intents") workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - intent, _ = workflow.run(input_="") - assert intent.name == "_oos_" + workflow.output = Output(intents=intents) + _, output = workflow.run(Input(utterances=["some text"])) + assert output["intents"][0]["name"] == "_oos_" def test_representation_intent(): @@ -144,13 +110,11 @@ def test_representation_intent(): Intent(name="d", score=0.44), ] - vote_plugin = VotePlugin( - access=lambda w: (w.output[0], len(intents)), mutate=update_intent - ) + vote_plugin = VotePlugin(dest="output.intents") workflow = Workflow([vote_plugin]) - workflow.output = intents, [] - intent, _ = workflow.run(input_="") - assert intent.name == "a" + workflow.output = Output(intents=intents) + _, output = workflow.run(Input(utterances=["some text"])) + assert output["intents"][0]["name"] == "a" def test_aggregate_fn_incorrect(): @@ -164,13 +128,12 @@ def test_aggregate_fn_incorrect(): ] vote_plugin = VotePlugin( - access=lambda w: (w.output[0], len(intents)), - mutate=update_intent, + dest="output.intents", aggregate_fn=5, ) workflow = Workflow([vote_plugin]) - workflow.output = intents, [] + workflow.output = Output(intents=intents) with pytest.raises(TypeError): - intent, _ = workflow.run(input_="") - assert intent.name == "a" + _, output = workflow.run(Input(utterances=[""])) + assert output["intents"][0]["name"] == "a" diff --git a/tests/types/entity/test_cases.yaml b/tests/types/entity/test_cases.yaml index 154f1ae3..e2e528e7 100644 --- a/tests/types/entity/test_cases.yaml +++ b/tests/types/entity/test_cases.yaml @@ -31,7 +31,7 @@ "latent": false } ] - expected: [{"type": "date", "entity": "TimeEntity"}] + expected: [{"entity_type": "date", "entity": "TimeEntity"}] - description: "time interval time type test" input: "between 2 to 4 am" @@ -90,7 +90,7 @@ "latent": false } ] - expected: [{"type": "time", "entity": "TimeIntervalEntity"}] + expected: [{"entity_type": "time", "entity": "TimeIntervalEntity"}] - description: "datetime time type test" input: "Monday 9 pm" @@ -125,7 +125,7 @@ "latent": false } ] - expected: [{"type": "datetime", "entity": "TimeEntity"}] + expected: [{"entity_type": "datetime", "entity": "TimeEntity"}] - description: "time interval with neither `from` nor `to` keys" input: "between 2 to 4 am" diff --git a/tests/types/entity/test_entities.py b/tests/types/entity/test_entities.py index 05c3c80a..6a878caf 100644 --- a/tests/types/entity/test_entities.py +++ b/tests/types/entity/test_entities.py @@ -7,7 +7,7 @@ import httpretty import pytest -from dialogy.base.plugin import Plugin +from dialogy.base import Input, Plugin from dialogy.plugins import DucklingPlugin from dialogy.types.entity import ( BaseEntity, @@ -66,7 +66,7 @@ def test_entity_parser(): range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=[{"value": 0}], ) entity.add_parser(MockPlugin()) @@ -81,7 +81,7 @@ def test_entity_values_index_error(): range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=[], ) with pytest.raises(IndexError): @@ -93,7 +93,7 @@ def test_entity_deep_copy(): entity = BaseEntity( range={"from": 0, "to": len(body)}, body=body, - type="basic", + entity_type="basic", dim="default", values=[], ) @@ -110,7 +110,7 @@ def test_base_entity_value_setter(): entity = BaseEntity( range={"from": 0, "to": len(body)}, body=body, - type="basic", + entity_type="basic", dim="default", values=[], ) @@ -125,7 +125,7 @@ 
def test_entity_synthesis(): range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=[], ) synthetic_entity = entity_synthesis(entity, "body", "12th november") @@ -145,7 +145,7 @@ def test_entity_values_key_error(): range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=[{"key": "value"}], ) with pytest.raises(KeyError): @@ -161,7 +161,7 @@ def test_people_entity_unit_not_str_error(): body = "12 people" with pytest.raises(TypeError): _ = PeopleEntity( - range={"from": 0, "to": len(body)}, body=body, type="people", unit=0 + range={"from": 0, "to": len(body)}, body=body, entity_type="people", unit=0 ) @@ -169,7 +169,7 @@ def test_time_entity_grain_not_str_error(): body = "12 pm" with pytest.raises(TypeError): _ = TimeEntity( - range={"from": 0, "to": len(body)}, body=body, type="time", grain=0 + range={"from": 0, "to": len(body)}, body=body, entity_type="time", grain=0 ) @@ -177,7 +177,7 @@ def test_time_interval_entity_value_not_dict_error(): body = "from 4 pm to 12 am" with pytest.raises(TypeError): _ = TimeIntervalEntity( - range={"from": 0, "to": len(body)}, body=body, type="time", grain=0 + range={"from": 0, "to": len(body)}, body=body, entity_type="time", grain=0 ) @@ -185,7 +185,7 @@ def test_location_entity_value_not_int_error(): body = "bangalore" with pytest.raises(TypeError): _ = LocationEntity( - range={"from": 0, "to": len(body)}, body=body, type="location" + range={"from": 0, "to": len(body)}, body=body, entity_type="location" ) @@ -195,7 +195,7 @@ def test_entity_set_value_values_present(): range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=[{"value": 4}], ) entity.set_value() @@ -208,7 +208,7 @@ def test_entity_set_value_values_missing(): range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", ) entity.set_value(value=4) assert entity.value == 4 @@ -223,7 +223,7 @@ def test_interval_entity_set_value_values_missing() -> None: entity = TimeIntervalEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[value], ) @@ -239,7 +239,7 @@ def test_entity_jsonify() -> None: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=values, ) entity.set_value(value) @@ -256,7 +256,7 @@ def test_entity_jsonify_unrestricted() -> None: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=values, ) entity_json = entity.json(add=["dim", "values"]) @@ -273,7 +273,7 @@ def test_entity_jsonify_skip() -> None: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=values, ) entity_json = entity.json(skip=["values"]) @@ -286,10 +286,10 @@ def test_both_entity_type_attributes_match() -> None: entity = BaseEntity( range={"from": 0, "to": len(body)}, body=body, - type="base", + entity_type="base", values=[value], ) - assert entity.type == entity.entity_type + assert "base" == entity.entity_type def test_interval_entity_only_from() -> None: @@ -300,7 +300,7 @@ def test_interval_entity_only_from() -> None: entity = TimeIntervalEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[value], ) @@ -316,7 +316,7 @@ def test_interval_entity_only_to() -> None: entity = TimeIntervalEntity( range={"from": 0, "to": len(body)}, body=body, - 
type="time", + entity_type="time", grain="hour", values=[value], ) @@ -330,7 +330,7 @@ def test_bad_interval_entity_neither_from_nor_to() -> None: entity = TimeIntervalEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[value], ) @@ -344,7 +344,7 @@ def test_bad_time_entity_invalid_value() -> None: entity = TimeEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[value], ) @@ -358,7 +358,7 @@ def test_bad_time_entity_no_value() -> None: entity = TimeEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[], ) @@ -373,7 +373,7 @@ def test_time_interval_entity_value_without_range() -> None: entity = TimeIntervalEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[value], value=value, @@ -421,7 +421,7 @@ def test_time_interval_entity_get_value() -> None: entity = TimeIntervalEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[value], value=value, @@ -439,7 +439,7 @@ def test_time_interval_entity_no_value() -> None: entity = TimeIntervalEntity( range={"from": 0, "to": len(body)}, body=body, - type="time", + entity_type="time", grain="hour", values=[value], value=value, @@ -460,15 +460,8 @@ def test_entity_type(payload) -> None: expected = payload.get("expected") exception = payload.get("exception") - def access(workflow): - return workflow.input, None, None, False - - def mutate(workflow, entities): - workflow.output = {"entities": entities} - duckling_plugin = DucklingPlugin( - access=access, - mutate=mutate, + dest="output.entities", dimensions=["people", "time", "date", "duration"], locale="en_IN", timezone="Asia/Kolkata", @@ -482,14 +475,10 @@ def mutate(workflow, entities): workflow = Workflow([duckling_plugin]) if expected: - workflow.run(body) - module = importlib.import_module("dialogy.types.entity") - - for i, entity in enumerate(workflow.output["entities"]): - class_name = expected[i]["entity"] - assert entity.type == expected[i]["type"] - assert entity.entity_type == expected[i]["type"] - assert isinstance(entity, getattr(module, class_name)) + _, output = workflow.run(Input(utterances=body)) + entities = output["entities"] + for i, entity in enumerate(entities): + assert entity["entity_type"] == expected[i]["entity_type"] elif exception: with pytest.raises(EXCEPTIONS[exception]): - workflow.run(body) + workflow.run(Input(utterances=body)) diff --git a/tests/types/intents/test_intents.py b/tests/types/intents/test_intents.py index ac226981..70857df0 100644 --- a/tests/types/intents/test_intents.py +++ b/tests/types/intents/test_intents.py @@ -74,7 +74,7 @@ def test_slot_filling() -> None: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=[{"key": "value"}], slot_names=["basic_slot"], ) @@ -103,7 +103,7 @@ def test_slot_filling_prop_removal() -> None: range={"from": 0, "to": len(body)}, body=body, dim="default", - type="basic", + entity_type="basic", values=[{"key": "value"}], slot_names=["basic_slot"], ) @@ -113,7 +113,7 @@ def test_slot_filling_prop_removal() -> None: intent.fill_slot(entity) intent_json = intent.json() - assert "dim" not in intent_json["slots"][0] + assert "dim" not in intent_json["slots"]["basic_slot"]["values"][0] def test_rule_with_multiple_types() -> None: @@ -121,7 +121,7 @@ def 
test_rule_with_multiple_types() -> None:
 range={"from": 0, "to": 15},
 body="12th december",
 dim="default",
- type="ordinal",
+ entity_type="ordinal",
 values=[{"key": "12th"}],
 slot_names=["basic_slot"],
 )
@@ -129,7 +129,7 @@
 range={"from": 0, "to": 15},
 body="12 december",
 dim="default",
- type="number",
+ entity_type="number",
 values=[{"key": "12"}],
 slot_names=["basic_slot"],
 )
diff --git a/tests/utils/test_datetime_utils.py b/tests/utils/test_datetime_utils.py
new file mode 100644
index 00000000..3124636a
--- /dev/null
+++ b/tests/utils/test_datetime_utils.py
@@ -0,0 +1,23 @@
+import pytest
+
+from dialogy.utils import make_unix_ts
+
+
+def test_tz_aware_ts():
+ assert (
+ make_unix_ts("Asia/Kolkata")("2022-02-07T19:36:37.188396+05:30")
+ == 1644242797188
+ )
+
+
+def test_tz_unaware_ts():
+ assert make_unix_ts("Asia/Kolkata")("2022-02-07T19:39:39.537827") == 1644241599537
+
+
+def test_incorrect_tz():
+ with pytest.raises(ValueError):
+ make_unix_ts(None)("2022-02-07T19:36:37.188396")
+
+
+def test_int_returned_as_is():
+ assert make_unix_ts(None)(1644241599537) == 1644241599537
diff --git a/tests/workflow/test_workflow.py b/tests/workflow/test_workflow.py
index ce7b0707..983bd96f 100644
--- a/tests/workflow/test_workflow.py
+++ b/tests/workflow/test_workflow.py
@@ -1,14 +1,8 @@
-import json
-from typing import Any, Optional
-
-import pandas as pd
 import pytest
-from sklearn.metrics import f1_score
 import dialogy.constants as const
-from dialogy.base.plugin import Plugin, PluginFn
+from dialogy.base import Input, Output
 from dialogy.plugins import MergeASROutputPlugin
-from dialogy.types import Intent
 from dialogy.workflow import Workflow
@@ -17,7 +11,7 @@ def test_workflow_get_input() -> None:
 """
 Basic initialization.
 """
 workflow = Workflow([])
- assert workflow.input == {}, "workflow.get_input() is a dict()."
+ assert workflow.input == None, "workflow.input should be None."
def test_workflow_set_output() -> None:
@@ -48,19 +42,14 @@ def test_workflow_history_logs() -> None:
 """
 We can execute the workflow.
 """
-
- def m(w, v):
- w.output = v
-
 workflow = Workflow(
- [MergeASROutputPlugin(access=lambda w: w.input, mutate=m, debug=True)],
+ [MergeASROutputPlugin(dest="input.clf_feature", debug=True)],
 debug=True,
 )
- output = workflow.run(input_=["apples"])
- assert output == [" apples "], "workflow.output should == 'apples'."
- workflow.flush()
- assert workflow.input == {}
- assert workflow.output == {const.INTENTS: [], const.ENTITIES: []}
+ input_, _ = workflow.run(Input(utterances=["apples"]))
+ assert input_["clf_feature"] == [" apples "]
+ assert workflow.input == None
+ assert workflow.output == Output()
def test_workflow_as_dict():
@@ -69,6 +58,31 @@
 """
 workflow = Workflow()
 assert workflow.json() == {
- "input": {},
+ "input": None,
 "output": {const.INTENTS: [], const.ENTITIES: []},
 }
+
+
+def test_workflow_invalid_set_path():
+ """
+ We can't set values at an invalid path in the workflow.
+ """
+ workflow = Workflow()
+ with pytest.raises(ValueError):
+ workflow.set("invalid.path", [])
+
+
+def test_workflow_invalid_set_value():
+ """
+ We can't set invalid values in workflow.
+ """
+ workflow = Workflow()
+ with pytest.raises(ValueError):
+ workflow.set("output.intents", 10)
+
+
+def test_safe_flush():
+ workflow = Workflow()
+ i, o = workflow.flush()
+ assert i == {}
+ assert o == {}
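Taken together, the migrated tests settle on a single idiom: a plugin names
its destination path (dest="input.clf_feature", dest="output.entities", and
so on), Workflow.run consumes an Input, and returns the serialized
(input, output) pair, resetting its state afterwards per
test_workflow_history_logs. A condensed sketch of that idiom, restating only
calls the tests above already make:

    from dialogy.base import Input
    from dialogy.plugins import MergeASROutputPlugin
    from dialogy.utils import make_unix_ts
    from dialogy.workflow import Workflow

    # dest replaces the old mutate callable: the plugin writes its feature
    # string into Input.clf_feature by itself.
    merge_asr_output_plugin = MergeASROutputPlugin(
        dest="input.clf_feature", use_transform=True, input_column="data"
    )
    workflow = Workflow([merge_asr_output_plugin])

    # run() returns the serialized pair and then flushes the workflow.
    input_, output = workflow.run(
        Input(utterances=[[{"transcript": "hello world", "confidence": None}]])
    )
    assert input_["clf_feature"] == [" hello world "]

    # ISO-8601 reference times convert to unix timestamps in milliseconds.
    assert (
        make_unix_ts("Asia/Kolkata")("2022-02-07T19:36:37.188396+05:30")
        == 1644242797188
    )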