Skip to content

Commit

Permalink
Fix content safety unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
diondrapeck committed Jul 20, 2024
1 parent eaa5d81 commit 5c64264
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import os
import re
from typing import Dict

import numpy as np

Expand All @@ -24,9 +23,7 @@ def __init__(self, model_config: AzureOpenAIModelConfiguration):

prompty_model_config = {"configuration": model_config}
if USER_AGENT and isinstance(model_config, AzureOpenAIModelConfiguration):
prompty_model_config.update(
{"parameters": {"extra_headers": {"x-ms-useragent": USER_AGENT}}}
)
prompty_model_config.update({"parameters": {"extra_headers": {"x-ms-useragent": USER_AGENT}}})
current_dir = os.path.dirname(__file__)
prompty_path = os.path.join(current_dir, "coherence.prompty")
self._flow = AsyncPrompty.load(source=prompty_path, model=prompty_model_config)
Expand Down Expand Up @@ -79,7 +76,7 @@ class CoherenceEvaluator:
def __init__(self, model_config: AzureOpenAIModelConfiguration):
self._async_evaluator = _AsyncCoherenceEvaluator(model_config)

def __call__(self, *, question: str, answer: str, **kwargs) -> Dict[str, float]:
def __call__(self, *, question: str, answer: str, **kwargs):
"""
Evaluate coherence.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from abc import ABC
from typing import Dict, List

try:
from .common.constants import EvaluationMetrics
Expand All @@ -24,7 +23,7 @@ class ContentSafetySubEvaluatorBase(ABC):
:type metric: ~promptflow.evals.evaluators._content_safety.flow.constants.EvaluationMetrics
:param project_scope: The scope of the Azure AI project.
It contains subscription id, resource group, and project name.
:type project_scope: dict
:type project_scope: Dict
:param credential: The credential for connecting to Azure AI project.
:type credential: TokenCredential
"""
Expand All @@ -34,7 +33,7 @@ def __init__(self, metric: EvaluationMetrics, project_scope: dict, credential=No
self._project_scope = project_scope
self._credential = credential

def __call__(self, *, question: str, answer: str, **kwargs) -> List[List[Dict]]:
def __call__(self, *, question: str, answer: str, **kwargs):
"""
Evaluates content according to this evaluator's metric.
Expand All @@ -43,13 +42,12 @@ def __call__(self, *, question: str, answer: str, **kwargs) -> List[List[Dict]]:
:keyword answer: The answer to be evaluated.
:paramtype answer: str
:return: The evaluation score computation based on the Content Safety metric (self.metric).
:rtype: List[List[Dict]]
:rtype: Any
"""
# Validate inputs
# Raises value error if failed, so execution alone signifies success.
_ = validate_inputs(question=question, answer=answer)

# question: str, answer: str, metric_name: str, project_scope: dict, credential: TokenCredential
# Run score computation based on supplied metric.
result = evaluate_with_rai_service(
metric_name=self._metric,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Dict, Union

from azure.identity import DefaultAzureCredential, ManagedIdentityCredential

try:
from ._content_safety_sub_evaluator_base import ContentSafetySubEvaluatorBase
Expand Down Expand Up @@ -46,9 +43,7 @@ class HateUnfairnessEvaluator(ContentSafetySubEvaluatorBase):
}
"""

def __init__(
self, project_scope: Dict, credential: Union[DefaultAzureCredential, ManagedIdentityCredential] = None
) -> None:
def __init__(self, project_scope: dict, credential=None) -> None:
# Hate_fairness is the actual backend metric name. Which, uh, doesn't sound great.
# so invert the name.
super().__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def parse_response( # pylint: disable=too-many-branches,too-many-statements
return result


def _get_service_discovery_url(azure_ai_project: Dict, token: str) -> str:
def _get_service_discovery_url(azure_ai_project: dict, token: str) -> str:
"""Get the discovery service URL for the Azure AI project
:param azure_ai_project: The Azure AI project details.
Expand All @@ -240,7 +240,7 @@ def _get_service_discovery_url(azure_ai_project: Dict, token: str) -> str:
return f"{base_url.scheme}://{base_url.netloc}"


def get_rai_svc_url(project_scope: Dict, token: str) -> str:
def get_rai_svc_url(project_scope: dict, token: str) -> str:
"""Get the Responsible AI service URL
:param project_scope: The Azure AI project scope details.
Expand Down Expand Up @@ -294,8 +294,8 @@ def fetch_or_reuse_token(credential: TokenCredential, token: str = None) -> str:


def evaluate_with_rai_service(
question: str, answer: str, metric_name: str, project_scope: Dict, credential: TokenCredential
) -> List[List[Dict]]:
question: str, answer: str, metric_name: str, project_scope: dict, credential: TokenCredential
):
""" "Evaluate the content safety of the answer using Responsible AI service
:param question: The question to evaluate.
Expand Down

0 comments on commit 5c64264

Please sign in to comment.