From 5c642648b97e6710700202925f4d78eee114d112 Mon Sep 17 00:00:00 2001
From: Diondra Peck
Date: Fri, 19 Jul 2024 19:44:41 -0700
Subject: [PATCH] Fix content safety unit tests

---
 .../promptflow/evals/evaluators/_coherence/_coherence.py  | 7 ++-----
 .../_content_safety/_content_safety_sub_evaluator_base.py | 8 +++-----
 .../evals/evaluators/_content_safety/_hate_unfairness.py  | 7 +------
 .../_content_safety/common/evaluate_with_rai_service.py   | 8 ++++----
 4 files changed, 10 insertions(+), 20 deletions(-)

diff --git a/src/promptflow-evals/promptflow/evals/evaluators/_coherence/_coherence.py b/src/promptflow-evals/promptflow/evals/evaluators/_coherence/_coherence.py
index f7a3135d4e8..74aaebe1377 100644
--- a/src/promptflow-evals/promptflow/evals/evaluators/_coherence/_coherence.py
+++ b/src/promptflow-evals/promptflow/evals/evaluators/_coherence/_coherence.py
@@ -4,7 +4,6 @@
 
 import os
 import re
-from typing import Dict
 
 import numpy as np
 
@@ -24,9 +23,7 @@ def __init__(self, model_config: AzureOpenAIModelConfiguration):
         prompty_model_config = {"configuration": model_config}
 
         if USER_AGENT and isinstance(model_config, AzureOpenAIModelConfiguration):
-            prompty_model_config.update(
-                {"parameters": {"extra_headers": {"x-ms-useragent": USER_AGENT}}}
-            )
+            prompty_model_config.update({"parameters": {"extra_headers": {"x-ms-useragent": USER_AGENT}}})
         current_dir = os.path.dirname(__file__)
         prompty_path = os.path.join(current_dir, "coherence.prompty")
         self._flow = AsyncPrompty.load(source=prompty_path, model=prompty_model_config)
@@ -79,7 +76,7 @@ class CoherenceEvaluator:
     def __init__(self, model_config: AzureOpenAIModelConfiguration):
         self._async_evaluator = _AsyncCoherenceEvaluator(model_config)
 
-    def __call__(self, *, question: str, answer: str, **kwargs) -> Dict[str, float]:
+    def __call__(self, *, question: str, answer: str, **kwargs):
         """
         Evaluate coherence.
 
diff --git a/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_content_safety_sub_evaluator_base.py b/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_content_safety_sub_evaluator_base.py
index 0a699c7a13d..36a4bdbd768 100644
--- a/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_content_safety_sub_evaluator_base.py
+++ b/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_content_safety_sub_evaluator_base.py
@@ -2,7 +2,6 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 from abc import ABC
-from typing import Dict, List
 
 try:
     from .common.constants import EvaluationMetrics
@@ -24,7 +23,7 @@ class ContentSafetySubEvaluatorBase(ABC):
     :type metric: ~promptflow.evals.evaluators._content_safety.flow.constants.EvaluationMetrics
     :param project_scope: The scope of the Azure AI project.
         It contains subscription id, resource group, and project name.
-    :type project_scope: dict
+    :type project_scope: Dict
    :param credential: The credential for connecting to Azure AI project.
    :type credential: TokenCredential
    """
@@ -34,7 +33,7 @@ def __init__(self, metric: EvaluationMetrics, project_scope: dict, credential=No
         self._project_scope = project_scope
         self._credential = credential
 
-    def __call__(self, *, question: str, answer: str, **kwargs) -> List[List[Dict]]:
+    def __call__(self, *, question: str, answer: str, **kwargs):
         """
         Evaluates content according to this evaluator's metric.
 
@@ -43,13 +42,12 @@ def __call__(self, *, question: str, answer: str, **kwargs) -> List[List[Dict]]:
         :keyword answer: The answer to be evaluated.
         :paramtype answer: str
         :return: The evaluation score computation based on the Content Safety metric (self.metric).
-        :rtype: List[List[Dict]]
+        :rtype: Any
         """
         # Validate inputs
         # Raises value error if failed, so execution alone signifies success.
         _ = validate_inputs(question=question, answer=answer)
 
-        # question: str, answer: str, metric_name: str, project_scope: dict, credential: TokenCredential
         # Run score computation based on supplied metric.
         result = evaluate_with_rai_service(
             metric_name=self._metric,
diff --git a/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_hate_unfairness.py b/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_hate_unfairness.py
index be0f3396d46..9799d8c247b 100644
--- a/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_hate_unfairness.py
+++ b/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/_hate_unfairness.py
@@ -1,9 +1,6 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-from typing import Dict, Union
-
-from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
 
 try:
     from ._content_safety_sub_evaluator_base import ContentSafetySubEvaluatorBase
@@ -46,9 +43,7 @@ class HateUnfairnessEvaluator(ContentSafetySubEvaluatorBase):
         }
     """
 
-    def __init__(
-        self, project_scope: Dict, credential: Union[DefaultAzureCredential, ManagedIdentityCredential] = None
-    ) -> None:
+    def __init__(self, project_scope: dict, credential=None) -> None:
         # Hate_fairness is the actual backend metric name. Which, uh, doesn't sound great.
         # so invert the name.
         super().__init__(
diff --git a/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/common/evaluate_with_rai_service.py b/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/common/evaluate_with_rai_service.py
index a7c0b31f376..00c231163c2 100644
--- a/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/common/evaluate_with_rai_service.py
+++ b/src/promptflow-evals/promptflow/evals/evaluators/_content_safety/common/evaluate_with_rai_service.py
@@ -215,7 +215,7 @@ def parse_response(  # pylint: disable=too-many-branches,too-many-statements
     return result
 
 
-def _get_service_discovery_url(azure_ai_project: Dict, token: str) -> str:
+def _get_service_discovery_url(azure_ai_project: dict, token: str) -> str:
     """Get the discovery service URL for the Azure AI project
 
     :param azure_ai_project: The Azure AI project details.
@@ -240,7 +240,7 @@ def _get_service_discovery_url(azure_ai_project: Dict, token: str) -> str:
     return f"{base_url.scheme}://{base_url.netloc}"
 
 
-def get_rai_svc_url(project_scope: Dict, token: str) -> str:
+def get_rai_svc_url(project_scope: dict, token: str) -> str:
     """Get the Responsible AI service URL
 
     :param project_scope: The Azure AI project scope details.
@@ -294,8 +294,8 @@ def fetch_or_reuse_token(credential: TokenCredential, token: str = None) -> str:
 
 
 def evaluate_with_rai_service(
-    question: str, answer: str, metric_name: str, project_scope: Dict, credential: TokenCredential
-) -> List[List[Dict]]:
+    question: str, answer: str, metric_name: str, project_scope: dict, credential: TokenCredential
+):
     """ "Evaluate the content safety of the answer using Responsible AI service
 
     :param question: The question to evaluate.
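For reference, a minimal sketch of how the relaxed signatures above might be exercised. It assumes HateUnfairnessEvaluator is exported from promptflow.evals.evaluators and that the project scope dict uses the subscription/resource-group/project keys the RAI service helpers expect; the scope values are placeholders.

# Illustrative sketch, not part of the patch: exercising the simplified
# __init__(project_scope: dict, credential=None) and keyword-only __call__.
from promptflow.evals.evaluators import HateUnfairnessEvaluator

# Placeholder scope values; key names are assumed, not taken from this patch.
project_scope = {
    "subscription_id": "<subscription-id>",
    "resource_group_name": "<resource-group>",
    "project_name": "<project-name>",
}

evaluator = HateUnfairnessEvaluator(project_scope)  # credential now defaults to None
result = evaluator(question="What is the capital of France?", answer="Paris.")
print(result)  # dict of hate/unfairness scores; see the class docstring's output example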