Skip to content

Commit

Permalink
feat/OCP csv keyword files support (#174)
Browse files Browse the repository at this point in the history
* load_ocp_keyword_from_csv

* csv support
---------

Co-authored-by: NeonJarbas <[email protected]>
Co-authored-by: JarbasAi <[email protected]>
  • Loading branch information
3 people authored Jan 12, 2024
1 parent 2adaf46 commit 46d1a63
Showing 1 changed file with 69 additions and 31 deletions.
100 changes: 69 additions & 31 deletions ovos_workshop/skills/common_play.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import os
from inspect import signature
from threading import Event
from ovos_workshop.skills.ovos import OVOSSkill
from ovos_bus_client import Message
from ovos_utils.log import LOG
from ovos_utils import camel_case_split
from typing import List

from ovos_utils import camel_case_split
from ovos_utils.log import LOG

from ovos_bus_client import Message
from ovos_classifiers.skovos.features import KeywordFeatures
from ovos_config.locations import get_xdg_cache_save_path
from ovos_workshop.skills.ovos import OVOSSkill

# backwards compat imports, do not delete, skills import from here
from ovos_workshop.decorators.ocp import ocp_play, ocp_next, ocp_pause, ocp_resume, ocp_search, \
ocp_previous, ocp_featured_media

from ovos_utils.ocp import MediaType, MediaState, MatchConfidence, \
PlaybackType, PlaybackMode, PlayerState, LoopState, TrackState

Expand Down Expand Up @@ -79,6 +82,16 @@ def __init__(self, *args, **kwargs):
self.ocp_matchers = {}
super().__init__(*args, **kwargs)

@property
def ocp_cache_dir(self):
    """Return the directory holding cached OCP entity .csv files.

    The directory is created on access if it does not exist yet.
    It needs to be available in ovos-core.
    NB: ovos-docker needs a shared volume.
    """
    # resolve the location once so the directory that gets created is
    # guaranteed to be the very same path that is returned
    cache_dir = f"{get_xdg_cache_save_path()}/OCP"
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

def bind(self, bus):
"""Overrides the normal bind method.
Expand Down Expand Up @@ -180,7 +193,7 @@ def ocp_voc_match(self, utterance, lang=None):
matches[k] = v
return matches

def load_ocp_keyword_from_csv(self, csv_path: str, lang: str = None):
    """Load entities from a .csv file for usage with self.ocp_voc_match.

    See the ocp_entities.csv datasets for example files built from
    wikidata SPARQL queries. Rows look like:

        film_genre,spy film
        ...

    Args:
        csv_path: path of the .csv file, handed as-is to the
            ``load_entities`` method of the per-language matcher.
        lang: language whose matcher receives the entities. When None,
            the file is loaded into the matcher of every language in
            ``self.native_langs``.
    """
    # a single code path for both cases: one explicit language, or all
    # native languages when none was requested
    target_langs = self.native_langs if lang is None else [lang]
    for lang_code in target_langs:
        # matchers are created lazily, one per language
        if lang_code not in self.ocp_matchers:
            self.ocp_matchers[lang_code] = KeywordFeatures()
        self.ocp_matchers[lang_code].load_entities(csv_path)

def export_ocp_keywords_csv(self, csv_path: str = None, lang: str = None,
                            label: str = None):
    """Export registered entities to a .csv file.

    The file starts with a ``label,sample`` header followed by one
    ``<label>,<sample>`` row per unique sample.

    Args:
        csv_path: destination file, defaults to
            ``<self.ocp_cache_dir>/<self.skill_id>_<lang>.csv``.
        lang: language whose entities are exported, defaults to self.lang.
        label: when given, only the entity with this label is exported.

    Returns:
        The path of the written .csv file.

    Raises:
        RuntimeError: if no entities are registered for ``lang``.
    """
    lang = lang or self.lang
    if lang not in self.ocp_matchers:
        # nothing to export; fail without touching self.ocp_matchers
        raise RuntimeError(f"no entities registered for lang: {lang}")

    csv_path = csv_path or f"{self.ocp_cache_dir}/{self.skill_id}_{lang}.csv"
    rows = ["label,sample"]
    for ent, samples in self.ocp_matchers[lang].entities.items():
        if label is not None and label != ent:
            continue
        # dict.fromkeys drops duplicates while keeping first-seen order,
        # so exporting the same entities twice yields identical files
        rows.extend(f"{ent},{s}" for s in dict.fromkeys(samples))
    with open(csv_path, "w") as f:
        # single write; no trailing newline, matching the row layout above
        f.write("\n".join(rows))
    LOG.info(f"{self.skill_id} OCP {lang} entities exported to {csv_path}")
    return csv_path

def register_ocp_keyword(self, media_type: MediaType, label: str,
                         samples: List, langs: List[str] = None):
    """Register keyword samples under ``label`` and announce them on the bus.

    ocp keywords can be efficiently matched with the self.ocp_voc_match
    helper method, that uses the Aho–Corasick algorithm.

    Args:
        media_type: media type forwarded in the bus message.
        label: entity label the samples are registered under.
        samples: keyword samples; duplicates are dropped.
        langs: languages to register the samples for, defaults to
            ``self.native_langs``.
    """
    samples = list(set(samples))
    langs = langs or self.native_langs
    for lang_code in langs:
        if lang_code not in self.ocp_matchers:
            self.ocp_matchers[lang_code] = KeywordFeatures()
        self.ocp_matchers[lang_code].register_entity(label, samples)

    payload = {"skill_id": self.skill_id,
               "label": label,  # if in OCP_ENTITIES it influences classifier
               "media_type": media_type}
    if len(samples) >= 20:
        # NB: we send a file path, bus messages with thousands of entities
        # don't work well
        csv_path = f"{self.ocp_cache_dir}/{self.skill_id}_{label}.csv"
        # export from a language that was just registered above: self.lang
        # is preferred, but it is not guaranteed to be one of ``langs`` and
        # exporting an unregistered language raises RuntimeError
        if self.lang in langs:
            export_lang = self.lang
        else:
            export_lang = next(iter(langs), self.lang)
        self.export_ocp_keywords_csv(csv_path, lang=export_lang, label=label)
        payload["csv"] = csv_path
    else:
        payload["samples"] = samples
    self.bus.emit(Message('ovos.common_play.register_keyword', payload))

def deregister_ocp_keyword(self, media_type: MediaType, label: str,
langs: List[str] = None):
Expand All @@ -234,18 +279,11 @@ def deregister_ocp_keyword(self, media_type: MediaType, label: str,
if l in self.ocp_matchers:
self.ocp_matchers[l].deregister_entity(label)

# TODO - send bus message once Pipeline is in
# if the label is a valid OCP entity known by the classifier it will help
# the classifier disambiguate between media_types
# eg, if OCP finds a movie name in user utterances it will
# prefer to search netflix instead of spotify
# right now only used for internal matching
#self.bus.emit(
# Message('ovos.common_play.deregister_keyword',
# {"skill_id": self.skill_id,
# "label": label,
# "langs": langs,
# "media_type": media_type}))
self.bus.emit(
Message('ovos.common_play.deregister_keyword',
{"skill_id": self.skill_id,
"label": label,
"media_type": media_type}))

def _register_decorated(self):
# register search handlers
Expand Down

0 comments on commit 46d1a63

Please sign in to comment.