Skip to content

Commit

Permalink
feat:common_query_decorator (#315)
Browse files Browse the repository at this point in the history
* feat:common_query_decorator

register common query handlers via decorators instead of requiring a subclass with magic method to override

* single handler

* single handler

* fix

* fix

* fix

* deprecation warning
  • Loading branch information
JarbasAl authored Dec 31, 2024
1 parent 71b43cd commit 9809329
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 4 deletions.
21 changes: 20 additions & 1 deletion ovos_workshop/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import wraps
from typing import Optional
from typing import Optional, Callable
from ovos_utils.log import log_deprecation

from ovos_workshop.decorators.killable import killable_intent, killable_event
Expand Down Expand Up @@ -118,6 +118,25 @@ def skill_api_method(func: callable):
return func


# utterance, answer, lang
CQCallback = Callable[[Optional[str], Optional[str], Optional[str]], None]


def common_query(callback: Optional[CQCallback] = None):
"""
Decorator for adding a method as an intent handler.
"""

def real_decorator(func):
# mark the method as a common_query handler
func.common_query = True
func.cq_callback = callback
return func

return real_decorator



def converse_handler(func):
"""
Decorator for aliasing a method as the converse method
Expand Down
1 change: 1 addition & 0 deletions ovos_workshop/skills/common_query_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class CommonQuerySkill(OVOSSkill):
"""

def __init__(self, *args, **kwargs):
log_deprecation("'CommonQuerySkill' class has been deprecated, use @common_query decorator with regular OVOSSkill instead", "4.0.0")
# these should probably be configurable
self.level_confidence = {
CQSMatchLevel.EXACT: 0.9,
Expand Down
91 changes: 88 additions & 3 deletions ovos_workshop/skills/ovos.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import binascii
import datetime
import json
import os
Expand All @@ -14,7 +15,6 @@
from threading import Event, RLock
from typing import Dict, Callable, List, Optional, Union

import binascii
from json_database import JsonStorage
from langcodes import closest_match
from ovos_bus_client import MessageBusClient
Expand Down Expand Up @@ -157,6 +157,9 @@ def __init__(self, name: Optional[str] = None,
# Skill Public API
self.public_api: Dict[str, dict] = {}

self._cq_handler = None
self._cq_callback = None

self._original_converse = self.converse # for get_response

self.__responses = {}
Expand Down Expand Up @@ -1005,6 +1008,13 @@ def _register_decorated(self):
if hasattr(method, 'converse'):
self.converse = method

# TODO support for multiple common query handlers (?)
if hasattr(method, 'common_query'):
self._cq_handler = method
self._cq_callback = method.cq_callback
LOG.debug(f"Registering common query handler for: {self.skill_id} - callback: {self._cq_callback}")
self.__handle_common_query_ping(Message("ovos.common_query.ping"))

if hasattr(method, 'converse_intents'):
for intent_file in getattr(method, 'converse_intents'):
self.register_converse_intent(intent_file, method)
Expand All @@ -1026,6 +1036,75 @@ def bind(self, bus: MessageBusClient):
self.audio_service = OCPInterface(self.bus)
self.private_settings = PrivateSettings(self.skill_id)

def __handle_common_query_ping(self, message):
if self._cq_handler:
# announce skill to common query pipeline
self.bus.emit(message.reply("ovos.common_query.pong",
{"skill_id": self.skill_id},
{"skill_id": self.skill_id}))

def __handle_query_action(self, message: Message):
"""
If this skill's response was spoken to the user, this method is called.
@param message: `question:action` message
"""
if not self._cq_callback or message.data["skill_id"] != self.skill_id:
# Not for this skill!
return
LOG.debug(f"common query callback for: {self.skill_id}")
lang = get_message_lang(message)
answer = message.data.get("answer") or message.data.get("callback_data", {}).get("answer")

# Inspect the callback signature
callback_signature = signature(self._cq_callback)
params = callback_signature.parameters

# Check if the first parameter is 'self' (indicating it's an instance method)
if len(params) > 0 and list(params.keys())[0] == 'self':
# Instance method: pass 'self' as the first argument
self._cq_callback(self, message.data["phrase"], answer, lang)
else:
# Static method or function: don't pass 'self'
self._cq_callback(message.data["phrase"], answer, lang)

def __handle_question_query(self, message: Message):
"""
Handle an incoming question query.
@param message: Message with matched query 'phrase'
"""
if not self._cq_handler:
return
lang = get_message_lang(message)
search_phrase = message.data["phrase"]
message.context["skill_id"] = self.skill_id
LOG.debug(f"Common QA: {self.skill_id}")
# First, notify the requestor that we are attempting to handle
# (this extends a timeout while this skill looks for a match)
self.bus.emit(message.response({"phrase": search_phrase,
"skill_id": self.skill_id,
"searching": True}))
answer = None
confidence = 0
try:
answer, confidence = self._cq_handler(search_phrase, lang) or (None, 0)
LOG.debug(f"Common QA {self.skill_id} result: {answer}")
except:
LOG.exception(f"Failed to get answer from {self._cq_handler}")

if answer and confidence >= 0.5:
self.bus.emit(message.response({"phrase": search_phrase,
"skill_id": self.skill_id,
"answer": answer,
"callback_data": {"answer": answer}, # so we get it in callback
"conf": confidence}))
else:
# Signal we are done (can't handle it)
self.bus.emit(message.response({"phrase": search_phrase,
"skill_id": self.skill_id,
"searching": False}))

def _register_public_api(self):
"""
Find and register API methods decorated with `@api_method` and create a
Expand Down Expand Up @@ -1094,6 +1173,11 @@ def _register_system_event_handlers(self):

self.add_event(f"{self.skill_id}.converse.get_response", self.__handle_get_response, speak_errors=False)

self.add_event('question:query', self.__handle_question_query, speak_errors=False)
self.add_event("ovos.common_query.ping", self.__handle_common_query_ping, speak_errors=False)
self.add_event('question:action', self.__handle_query_action,
handler_info='mycroft.skill.handler', is_intent=True, speak_errors=False)

# homescreen might load after this skill and miss the original events
self.add_event("homescreen.metadata.get", self.handle_homescreen_loaded, speak_errors=False)

Expand Down Expand Up @@ -2162,7 +2246,8 @@ def voc_match(self, utt: str, voc_filename: str, lang: Optional[str] = None,
try:
_vocs = self.voc_list(voc_filename, lang)
except FileNotFoundError:
LOG.warning(f"{self.skill_id} failed to find voc file '{voc_filename}' for lang '{lang}' in `{self.res_dir}'")
LOG.warning(
f"{self.skill_id} failed to find voc file '{voc_filename}' for lang '{lang}' in `{self.res_dir}'")
return False

if utt and _vocs:
Expand All @@ -2172,7 +2257,7 @@ def voc_match(self, utt: str, voc_filename: str, lang: Optional[str] = None,
for i in _vocs)
else:
# Check for matches against complete words
match = any([re.match(r'.*\b' + i + r'\b.*', utt, re.IGNORECASE)
match = any([re.match(r'.*\b' + re.escape(i) + r'\b.*', utt, re.IGNORECASE)
for i in _vocs])

return match
Expand Down

0 comments on commit 9809329

Please sign in to comment.