From 46d1a638360948b8ed5d2ffad625edd4e0688299 Mon Sep 17 00:00:00 2001 From: NeonJarbas <59943014+NeonJarbas@users.noreply.github.com> Date: Fri, 12 Jan 2024 02:55:16 +0000 Subject: [PATCH] feat/OCP csv keyword files support (#174) * load_ocp_keyword_from_csv * csv support --------- Co-authored-by: NeonJarbas Co-authored-by: JarbasAi --- ovos_workshop/skills/common_play.py | 100 +++++++++++++++++++--------- 1 file changed, 69 insertions(+), 31 deletions(-) diff --git a/ovos_workshop/skills/common_play.py b/ovos_workshop/skills/common_play.py index 7d7472d5..b192c7a2 100644 --- a/ovos_workshop/skills/common_play.py +++ b/ovos_workshop/skills/common_play.py @@ -1,16 +1,19 @@ +import os from inspect import signature from threading import Event -from ovos_workshop.skills.ovos import OVOSSkill -from ovos_bus_client import Message -from ovos_utils.log import LOG -from ovos_utils import camel_case_split from typing import List + +from ovos_utils import camel_case_split +from ovos_utils.log import LOG + +from ovos_bus_client import Message from ovos_classifiers.skovos.features import KeywordFeatures +from ovos_config.locations import get_xdg_cache_save_path +from ovos_workshop.skills.ovos import OVOSSkill # backwards compat imports, do not delete, skills import from here from ovos_workshop.decorators.ocp import ocp_play, ocp_next, ocp_pause, ocp_resume, ocp_search, \ ocp_previous, ocp_featured_media - from ovos_utils.ocp import MediaType, MediaState, MatchConfidence, \ PlaybackType, PlaybackMode, PlayerState, LoopState, TrackState @@ -79,6 +82,16 @@ def __init__(self, *args, **kwargs): self.ocp_matchers = {} super().__init__(*args, **kwargs) + @property + def ocp_cache_dir(self): + """path to cached .csv file with ocp entities data + this file needs to be available in ovos-core + + NB: ovos-docker needs a shared volume + """ + os.makedirs(f"{get_xdg_cache_save_path()}/OCP", exist_ok=True) + return f"{get_xdg_cache_save_path()}/OCP" + def bind(self, bus): """Overrides the normal bind method. @@ -180,7 +193,7 @@ def ocp_voc_match(self, utterance, lang=None): matches[k] = v return matches - def load_ocp_keyword_from_csv(self, csv_path: str, lang: str): + def load_ocp_keyword_from_csv(self, csv_path: str, lang: str = None): """ load entities from a .csv file for usage with self.ocp_voc_match see the ocp_entities.csv datatsets for example files built from wikidata SPARQL queries @@ -194,9 +207,33 @@ def load_ocp_keyword_from_csv(self, csv_path: str, lang: str): film_genre,spy film ... """ + if lang is None: + for lang in self.native_langs: + if lang not in self.ocp_matchers: + self.ocp_matchers[lang] = KeywordFeatures() + self.ocp_matchers[lang].load_entities(csv_path) + else: + if lang not in self.ocp_matchers: + self.ocp_matchers[lang] = KeywordFeatures() + self.ocp_matchers[lang].load_entities(csv_path) + + def export_ocp_keywords_csv(self, csv_path: str = None, lang: str = None, + label: str = None): + """ export entities to a .csv file """ + lang = lang or self.lang if lang not in self.ocp_matchers: - self.ocp_matchers[lang] = KeywordFeatures() - self.ocp_matchers[lang].load_entities(csv_path) + raise RuntimeError(f"no entities registered for lang: {lang}") + + csv_path = csv_path or f"{self.ocp_cache_dir}/{self.skill_id}_{lang}.csv" + with open(csv_path, "w") as f: + f.write("label,sample") + for ent, samples in self.ocp_matchers[lang].entities.items(): + if label is not None and label != ent: + continue + for s in set(samples): + f.write(f"\n{ent},{s}") + LOG.info(f"{self.skill_id} OCP {lang} entities exported to {csv_path}") + return csv_path def register_ocp_keyword(self, media_type: MediaType, label: str, samples: List, langs: List[str] = None): @@ -205,27 +242,35 @@ def register_ocp_keyword(self, media_type: MediaType, label: str, ocp keywords can be efficiently matched with self.ocp_match helper method that uses Aho–Corasick algorithm """ + samples = list(set(samples)) langs = langs or self.native_langs for l in langs: if l not in self.ocp_matchers: self.ocp_matchers[l] = KeywordFeatures() self.ocp_matchers[l].register_entity(label, samples) - # TODO - send bus message once Pipeline is in # if the label is a valid OCP entity known by the classifier it will help # the classifier disambiguate between media_types # eg, if OCP finds a movie name in user utterances it will # prefer to search netflix instead of spotify - # right now only used for internal matching - # NB: consider sending a file path, - # bus messages with thousands of entities dont work well - #self.bus.emit( - # Message('ovos.common_play.register_keyword', - # {"skill_id": self.skill_id, - # "label": label, # if in OCP_ENTITIES it influences classifier - # "langs": langs, - # "samples": samples, - # "media_type": media_type})) + + # NB: we send a file path, bus messages with thousands of entities dont work well + if len(samples) >= 20: + csv = f"{self.ocp_cache_dir}/{self.skill_id}_{label}.csv" + self.export_ocp_keywords_csv(csv, label=label) + self.bus.emit( + Message('ovos.common_play.register_keyword', + {"skill_id": self.skill_id, + "label": label, # if in OCP_ENTITIES it influences classifier + "csv": csv, + "media_type": media_type})) + else: + self.bus.emit( + Message('ovos.common_play.register_keyword', + {"skill_id": self.skill_id, + "label": label, # if in OCP_ENTITIES it influences classifier + "samples": samples, + "media_type": media_type})) def deregister_ocp_keyword(self, media_type: MediaType, label: str, langs: List[str] = None): @@ -234,18 +279,11 @@ def deregister_ocp_keyword(self, media_type: MediaType, label: str, if l in self.ocp_matchers: self.ocp_matchers[l].deregister_entity(label) - # TODO - send bus message once Pipeline is in - # if the label is a valid OCP entity known by the classifier it will help - # the classifier disambiguate between media_types - # eg, if OCP finds a movie name in user utterances it will - # prefer to search netflix instead of spotify - # right now only used for internal matching - #self.bus.emit( - # Message('ovos.common_play.deregister_keyword', - # {"skill_id": self.skill_id, - # "label": label, - # "langs": langs, - # "media_type": media_type})) + self.bus.emit( + Message('ovos.common_play.deregister_keyword', + {"skill_id": self.skill_id, + "label": label, + "media_type": media_type})) def _register_decorated(self): # register search handlers