Skip to content

Commit

Permalink
feat/blacklist_from_session (#492)
Browse files Browse the repository at this point in the history
* feat/blacklist_from_session

Allow skill_ids and intents to be blacklisted in the Session; blacklisted skills and intents are ignored during the matching process.

This allows fine-tuning utterances and permissions per client: the blacklist is taken into account during the matching process itself, instead of filtering before/after the handling of the utterance.

Needed for HiveMind RBAC (role-based access control).

* requirements.txt

* fix message

* tests

* tests

* ocp

* ocp blacklist tests

* fallback blacklist tests

* common_qa blacklist

* common_qa blacklist tests
  • Loading branch information
JarbasAl authored Jun 17, 2024
1 parent cda88f2 commit c074d3a
Show file tree
Hide file tree
Showing 17 changed files with 674 additions and 63 deletions.
1 change: 1 addition & 0 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
pip install ./test/end2end/skill-old-stop
pip install ./test/end2end/skill-fake-fm
pip install ./test/end2end/skill-fake-fm-legacy
pip install ./test/end2end/skill-ovos-fakewiki
pip install ./test/end2end/metadata-test-plugin
- name: Install core repo
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ jobs:
pip install ./test/end2end/skill-old-stop
pip install ./test/end2end/skill-fake-fm
pip install ./test/end2end/skill-fake-fm-legacy
pip install ./test/end2end/skill-ovos-fakewiki
pip install ./test/end2end/metadata-test-plugin
- name: Install core repo
run: |
Expand Down
6 changes: 6 additions & 0 deletions ovos_core/intent_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ def handle_utterance(self, message: Message):
for match_func in self.get_pipeline(session=sess):
match = match_func(utterances, lang, message)
if match:
if match.skill_id and match.skill_id in sess.blacklisted_skills:
LOG.debug(f"ignoring match, skill_id '{match.skill_id}' blacklisted by Session '{sess.session_id}'")
continue
if match.intent_type and match.intent_type in sess.blacklisted_intents:
LOG.debug(f"ignoring match, intent '{match.intent_type}' blacklisted by Session '{sess.session_id}'")
continue
try:
self._emit_match_message(match, message)
break
Expand Down
13 changes: 9 additions & 4 deletions ovos_core/intent_services/adapt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def match_low(self, utterances: List[str],
return match
return None

@lru_cache(maxsize=3)
@lru_cache(maxsize=3) # NOTE - message is a string because of this
def match_intent(self, utterances: Tuple[str],
lang: Optional[str] = None,
message: Optional[str] = None):
Expand All @@ -197,6 +197,11 @@ def match_intent(self, utterances: Tuple[str],
Returns:
Intent structure, or None if no match was found.
"""

if message:
message = Message.deserialize(message)
sess = SessionManager.get(message)

# we call flatten in case someone is sending the old style list of tuples
utterances = flatten_list(utterances)

Expand All @@ -215,13 +220,13 @@ def take_best(intent, utt):
nonlocal best_intent
best = best_intent.get('confidence', 0.0) if best_intent else 0.0
conf = intent.get('confidence', 0.0)
if best < conf:
skill = intent['intent_type'].split(":")[0]
if best < conf and intent["intent_type"] not in sess.blacklisted_intents \
and skill not in sess.blacklisted_skills:
best_intent = intent
# TODO - Shouldn't Adapt do this?
best_intent['utterance'] = utt

if message:
message = Message.deserialize(message)
sess = SessionManager.get(message)
for utt in utterances:
try:
Expand Down
12 changes: 11 additions & 1 deletion ovos_core/intent_services/commonqa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def handle_question(self, message: Message):
replies=[], extensions=[],
query_time=time.time(), timeout_time=time.time() + self._max_time,
responses_gathered=Event(), completed=Event(),
answered=False, queried_skills=[])
answered=False,
queried_skills=[s for s in sess.blacklisted_skills
if s in self.common_query_skills]) # dont wait for these
assert query.responses_gathered.is_set() is False
assert query.completed.is_set() is False
self.active_queries[sess.session_id] = query
Expand Down Expand Up @@ -172,6 +174,11 @@ def handle_query_response(self, message: Message):
searching = message.data.get('searching')
answer = message.data.get('answer')

sess = SessionManager.get(message)
if skill_id in sess.blacklisted_skills:
LOG.debug(f"ignoring match, skill_id '{skill_id}' blacklisted by Session '{sess.session_id}'")
return

query = self.active_queries.get(SessionManager.get(message).session_id)
if not query:
LOG.warning(f"Late answer received from {skill_id}, no active query for: {search_phrase}")
Expand Down Expand Up @@ -221,6 +228,7 @@ def _query_timeout(self, message: Message):
handler can perform any additional actions.
@param message: question:query.response Message with `phrase` data
"""
sess = SessionManager.get(message)
query = self.active_queries.get(SessionManager.get(message).session_id)
LOG.info(f'Check responses with {len(query.replies)} replies')
search_phrase = message.data.get('phrase', "")
Expand All @@ -233,6 +241,8 @@ def _query_timeout(self, message: Message):
best = None
ties = []
for response in query.replies:
if response["skill_id"] in sess.blacklisted_skills:
continue
if not best or response['conf'] > best['conf']:
best = response
ties = [response]
Expand Down
3 changes: 3 additions & 0 deletions ovos_core/intent_services/converse_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ def converse_with_skills(self, utterances, lang, message):
self._check_converse_timeout(message)
# check if any skill wants to handle utterance
for skill_id in self._collect_converse_skills(message):
if skill_id in session.blacklisted_skills:
LOG.debug(f"ignoring match, skill_id '{skill_id}' blacklisted by Session '{session.session_id}'")
continue
if self.converse(utterances, skill_id, lang, message):
state = session.utterance_states.get(skill_id, UtteranceState.INTENT)
return ovos_core.intent_services.IntentMatch(intent_service='Converse',
Expand Down
34 changes: 23 additions & 11 deletions ovos_core/intent_services/fallback_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ovos_core.intent_services
from ovos_utils import flatten_list
from ovos_utils.log import LOG
from ovos_bus_client.session import SessionManager
from ovos_workshop.skills.fallback import FallbackMode

FallbackRange = namedtuple('FallbackRange', ['start', 'stop'])
Expand Down Expand Up @@ -82,9 +83,11 @@ def _collect_fallback_skills(self, message, fb_range=FallbackRange(0, 100)):
skill_ids = [] # skill_ids that already answered to ping
fallback_skills = [] # skill_ids that want to handle fallback

sess = SessionManager.get(message)
# filter skills outside the fallback_range
in_range = [s for s, p in self.registered_fallbacks.items()
if fb_range.start < p <= fb_range.stop]
if fb_range.start < p <= fb_range.stop
and s not in sess.blacklisted_skills]
skill_ids += [s for s in self.registered_fallbacks if s not in in_range]

def handle_ack(msg):
Expand All @@ -97,18 +100,19 @@ def handle_ack(msg):
LOG.info(f"{skill_id} will NOT try to handle fallback")
skill_ids.append(skill_id)

self.bus.on("ovos.skills.fallback.pong", handle_ack)
if in_range: # no need to search if no skills available
self.bus.on("ovos.skills.fallback.pong", handle_ack)

LOG.info("checking for FallbackSkillsV2 candidates")
# wait for all skills to acknowledge they want to answer fallback queries
self.bus.emit(message.forward("ovos.skills.fallback.ping",
message.data))
start = time.time()
while not all(s in skill_ids for s in self.registered_fallbacks) \
and time.time() - start <= 0.5:
time.sleep(0.02)
LOG.info("checking for FallbackSkillsV2 candidates")
# wait for all skills to acknowledge they want to answer fallback queries
self.bus.emit(message.forward("ovos.skills.fallback.ping",
message.data))
start = time.time()
while not all(s in skill_ids for s in self.registered_fallbacks) \
and time.time() - start <= 0.5:
time.sleep(0.02)

self.bus.remove("ovos.skills.fallback.pong", handle_ack)
self.bus.remove("ovos.skills.fallback.pong", handle_ack)
return fallback_skills

def attempt_fallback(self, utterances, skill_id, lang, message):
Expand All @@ -124,6 +128,10 @@ def attempt_fallback(self, utterances, skill_id, lang, message):
Returns:
handled (bool): True if handled otherwise False.
"""
sess = SessionManager.get(message)
if skill_id in sess.blacklisted_skills:
LOG.debug(f"ignoring match, skill_id '{skill_id}' blacklisted by Session '{sess.session_id}'")
return False
if self._fallback_allowed(skill_id):
fb_msg = message.reply(f"ovos.skills.fallback.{skill_id}.request",
{"skill_id": skill_id,
Expand Down Expand Up @@ -158,11 +166,15 @@ def _fallback_range(self, utterances, lang, message, fb_range):
message.data["utterances"] = utterances # all transcripts
message.data["lang"] = lang

sess = SessionManager.get(message)
# new style bus api
fallbacks = [(k, v) for k, v in self.registered_fallbacks.items()
if k in self._collect_fallback_skills(message, fb_range)]
sorted_handlers = sorted(fallbacks, key=operator.itemgetter(1))
for skill_id, prio in sorted_handlers:
if skill_id in sess.blacklisted_skills:
LOG.debug(f"ignoring match, skill_id '{skill_id}' blacklisted by Session '{sess.session_id}'")
continue
result = self.attempt_fallback(utterances, skill_id, lang, message)
if result:
return ovos_core.intent_services.IntentMatch(intent_service='Fallback',
Expand Down
30 changes: 23 additions & 7 deletions ovos_core/intent_services/ocp_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,10 +892,12 @@ def _process_play_query(self, utterance: str, lang: str, match: dict = None,

self.speak_dialog("just.one.moment")

sess = SessionManager.get(message)
# if a skill was explicitly requested, search it first
valid_skills = [
skill_id for skill_id, samples in self.skill_aliases.items()
if any(s.lower() in utterance for s in samples)
if skill_id not in sess.blacklisted_skills and
any(s.lower() in utterance for s in samples)
]
if valid_skills:
LOG.info(f"OCP specific skill names matched: {valid_skills}")
Expand Down Expand Up @@ -935,7 +937,7 @@ def handle_search_query(self, message: Message):
media_type, prob = self.classify_media(utterance, lang)
# search common play skills
results = self._search(phrase, media_type, lang, message=message)
best = self.select_best(results)
best = self.select_best(results, message)
results = [r.as_dict if isinstance(best, (MediaEntry, Playlist)) else r
for r in results]
if isinstance(best, (MediaEntry, Playlist)):
Expand Down Expand Up @@ -979,7 +981,12 @@ def handle_play_intent(self, message: Message):
"media_type": media_type})
else:
LOG.debug(f"Playing {len(results)} results for: {query}")
best = self.select_best(results)
best = self.select_best(results, message)
if best is None:
self.speak_dialog("cant.play",
data={"phrase": query,
"media_type": media_type})
return
LOG.debug(f"OCP Best match: {best}")
results = [r for r in results if r.as_dict != best.as_dict]
results.insert(0, best)
Expand Down Expand Up @@ -1361,7 +1368,10 @@ def _execute_query(self, phrase: str,
LOG.debug(f'Returning {len(results)} search results')
return results

def select_best(self, results: list) -> MediaEntry:
def select_best(self, results: list, message: Message) -> MediaEntry:

sess = SessionManager.get(message)

# Look at any replies that arrived before the timeout
# Find response(s) with the highest confidence
best = None
Expand All @@ -1370,6 +1380,9 @@ def select_best(self, results: list) -> MediaEntry:
for res in results:
if isinstance(res, dict):
res = dict2entry(res)
if res.skill_id in sess.blacklisted_skills:
LOG.debug(f"ignoring match, skill_id '{res.skill_id}' blacklisted by Session '{sess.session_id}'")
continue
if not best or res.match_confidence > best.match_confidence:
best = res
ties = [best]
Expand All @@ -1382,8 +1395,11 @@ def select_best(self, results: list) -> MediaEntry:
# TODO: Ask user to pick between ties or do it automagically
else:
selected = best
LOG.info(f"OVOSCommonPlay selected: {selected.skill_id} - {selected.match_confidence}")
LOG.debug(str(selected))
if selected:
LOG.info(f"OVOSCommonPlay selected: {selected.skill_id} - {selected.match_confidence}")
LOG.debug(str(selected))
else:
LOG.error("No valid OCP matches")
return selected

##################
Expand Down Expand Up @@ -1476,7 +1492,7 @@ def handle_legacy_cps(self, message: Message):
utt = message.data["query"]
res = self.mycroft_cps.search(utt)
if res:
best = self.select_best([r[0] for r in res])
best = self.select_best([r[0] for r in res], message)
if best:
callback = [r[1] for r in res if r[0].uri == best.uri][0]
self.mycroft_cps.skill_play(skill_id=best.skill_id,
Expand Down
37 changes: 28 additions & 9 deletions ovos_core/intent_services/padacioso_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List, Optional

from ovos_config.config import Configuration
from ovos_bus_client.session import SessionManager, Session
from ovos_utils import flatten_list
from ovos_utils.log import LOG
from padacioso import IntentContainer as FallbackIntentContainer
Expand Down Expand Up @@ -75,7 +76,8 @@ def __init__(self, bus, config):
self.max_words = 50 # if an utterance contains more words than this, don't attempt to match
LOG.debug('Loaded Padacioso intent parser.')

def _match_level(self, utterances, limit, lang=None):
def _match_level(self, utterances, limit, lang=None,
message: Optional[Message] = None):
"""Match intent and make sure a certain level of confidence is reached.
Args:
Expand All @@ -87,7 +89,7 @@ def _match_level(self, utterances, limit, lang=None):
# call flatten in case someone is sending the old style list of tuples
utterances = flatten_list(utterances)
lang = lang or self.lang
padacioso_intent = self.calc_intent(utterances, lang)
padacioso_intent = self.calc_intent(utterances, lang, message)
if padacioso_intent is not None and padacioso_intent.conf > limit:
skill_id = padacioso_intent.name.split(':')[0]
return ovos_core.intent_services.IntentMatch(
Expand All @@ -101,7 +103,7 @@ def match_high(self, utterances, lang=None, message=None):
utterances (list of tuples): Utterances to parse, originals paired
with optional normalized version.
"""
return self._match_level(utterances, self.conf_high, lang)
return self._match_level(utterances, self.conf_high, lang, message)

def match_medium(self, utterances, lang=None, message=None):
"""Intent matcher for medium confidence.
Expand All @@ -110,7 +112,7 @@ def match_medium(self, utterances, lang=None, message=None):
utterances (list of tuples): Utterances to parse, originals paired
with optional normalized version.
"""
return self._match_level(utterances, self.conf_med, lang)
return self._match_level(utterances, self.conf_med, lang, message)

def match_low(self, utterances, lang=None, message=None):
"""Intent matcher for low confidence.
Expand All @@ -119,7 +121,7 @@ def match_low(self, utterances, lang=None, message=None):
utterances (list of tuples): Utterances to parse, originals paired
with optional normalized version.
"""
return self._match_level(utterances, self.conf_low, lang)
return self._match_level(utterances, self.conf_low, lang, message)

def __detach_intent(self, intent_name):
""" Remove an intent if it has been registered.
Expand Down Expand Up @@ -221,7 +223,8 @@ def register_entity(self, message):
self._register_object(message, 'entity',
self.containers[lang].add_entity)

def calc_intent(self, utterances: List[str], lang: str = None) -> Optional[PadaciosoIntent]:
def calc_intent(self, utterances: List[str], lang: str = None,
message: Optional[Message] = None) -> Optional[PadaciosoIntent]:
"""
Get the best intent match for the given list of utterances. Utilizes a
thread pool for overall faster execution. Note that this method is NOT
Expand All @@ -236,11 +239,14 @@ def calc_intent(self, utterances: List[str], lang: str = None) -> Optional[Padac
if not utterances:
LOG.error(f"utterance exceeds max size of {self.max_words} words, skipping padacioso match")
return None

lang = lang or self.lang
lang = lang.lower()
sess = SessionManager.get(message)
if lang in self.containers:
intent_container = self.containers.get(lang)
intents = [_calc_padacioso_intent(utt, intent_container) for utt in utterances]
intents = [_calc_padacioso_intent(utt, intent_container, sess)
for utt in utterances]
intents = [i for i in intents if i is not None]
# select best
if intents:
Expand All @@ -254,15 +260,28 @@ def shutdown(self):


@lru_cache(maxsize=3) # repeat calls under different conf levels wont re-run code
def _calc_padacioso_intent(utt, intent_container) -> \
def _calc_padacioso_intent(utt: str,
intent_container: FallbackIntentContainer,
sess: Session) -> \
Optional[PadaciosoIntent]:
"""
Try to match an utterance to an intent in an intent_container
@param args: tuple of (utterance, IntentContainer)
@return: matched PadaciosoIntent
"""
try:
intent = intent_container.calc_intent(utt)
intents = [i for i in intent_container.calc_intents(utt)
if i is not None
and i["name"] not in sess.blacklisted_intents
and i["name"].split(":")[0] not in sess.blacklisted_skills]
if len(intents) == 0:
return None
best_conf = max(x.get("conf", 0) for x in intents if x.get("name"))
ties = [i for i in intents if i.get("conf", 0) == best_conf]
if not ties:
return None
# TODO - how to disambiguate ?
intent = ties[0]
if "entities" in intent:
intent["matches"] = intent.pop("entities")
intent["sent"] = utt
Expand Down
Loading

0 comments on commit c074d3a

Please sign in to comment.