Make evaluation run a context manager instead of a singleton. #3529

Merged · 9 commits · Jul 11, 2024
Changes from all commits
229 changes: 133 additions & 96 deletions src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
@@ -1,14 +1,16 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import contextlib
import dataclasses
import enum
import logging
import os
import posixpath
import requests
import time
import uuid
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Set
from urllib.parse import urlparse

from requests.adapters import HTTPAdapter
@@ -52,28 +54,15 @@ def generate(run_name: Optional[str]) -> 'RunInfo':
)


class Singleton(type):
"""Singleton class, which will be used as a metaclass."""
class RunStatus(enum.Enum):
"""Run states."""
NOT_STARTED = 0
STARTED = 1
BROKEN = 2
TERMINATED = 3

_instances = {}

def __call__(cls, *args, **kwargs):
"""Redefinition of call to return one instance per type."""
if cls not in Singleton._instances:
Singleton._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return Singleton._instances[cls]

@staticmethod
def destroy(cls: Type) -> None:
"""
Destroy the singleton instance.

:param cls: The class to be destroyed.
"""
Singleton._instances.pop(cls, None)


class EvalRun(metaclass=Singleton):
class EvalRun(contextlib.AbstractContextManager):
"""
The simple singleton run class, used for accessing artifact store.

@@ -119,25 +108,18 @@ def __init__(self,
self._workspace_name: str = workspace_name
self._ml_client: Any = ml_client
self._is_promptflow_run: bool = promptflow_run is not None
self._is_broken = False
if self._tracking_uri is None:
LOGGER.warning("tracking_uri was not provided, "
"The results will be saved locally, but will not be logged to Azure.")
self._url_base = None
self._is_broken = True
self.info = RunInfo.generate(run_name)
else:
self._url_base = urlparse(self._tracking_uri).netloc
if promptflow_run is not None:
self.info = RunInfo(
promptflow_run.name,
promptflow_run._experiment_name,
promptflow_run.name
)
else:
self._is_broken = self._start_run(run_name)
self._run_name = run_name
self._promptflow_run = promptflow_run
self._status = RunStatus.NOT_STARTED

self._is_terminated = False
@property
def status(self) -> RunStatus:
"""
Return the run status.

:return: The status of the run.
"""
return self._status

def _get_scope(self) -> str:
"""
@@ -156,76 +138,97 @@ def _get_scope(self) -> str:
self._workspace_name,
)

def _start_run(self, run_name: Optional[str]) -> bool:
def _start_run(self) -> None:
"""
Make a request to start the mlflow run. If the run will not start, it will be

marked as broken and the logging will be switched off.
:param run_name: The display name for the run.
:type run_name: Optional[str]
:returns: True if the run has started and False otherwise.
Start the run, or, if it is not applicable (for example, if tracking is not enabled), mark it as started.
"""
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/create"
body = {
"experiment_id": "0",
"user_id": "promptflow-evals",
"start_time": int(time.time() * 1000),
"tags": [{"key": "mlflow.user", "value": "promptflow-evals"}],
}
if run_name:
body["run_name"] = run_name
response = self.request_with_retry(
url=url,
method='POST',
json_dict=body
)
if response.status_code != 200:
self.info = RunInfo.generate(run_name)
LOGGER.warning(f"The run failed to start: {response.status_code}: {response.text}."
self._check_state_and_log('start run',
{v for v in RunStatus if v != RunStatus.NOT_STARTED},
True)
self._status = RunStatus.STARTED
if self._tracking_uri is None:
LOGGER.warning("tracking_uri was not provided, "
"The results will be saved locally, but will not be logged to Azure.")
return True
parsed_response = response.json()
self.info = RunInfo(
run_id=parsed_response['run']['info']['run_id'],
experiment_id=parsed_response['run']['info']['experiment_id'],
run_name=parsed_response['run']['info']['run_name']
)
return False

def end_run(self, status: str) -> None:
self._url_base = None
self._status = RunStatus.BROKEN
self.info = RunInfo.generate(self._run_name)
else:
self._url_base = urlparse(self._tracking_uri).netloc
if self._promptflow_run is not None:
self.info = RunInfo(
self._promptflow_run.name,
self._promptflow_run._experiment_name,
self._promptflow_run.name
)
else:
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/create"
body = {
"experiment_id": "0",
"user_id": "promptflow-evals",
"start_time": int(time.time() * 1000),
"tags": [{"key": "mlflow.user", "value": "promptflow-evals"}],
}
if self._run_name:
body["run_name"] = self._run_name
response = self.request_with_retry(
url=url,
method='POST',
json_dict=body
)
if response.status_code != 200:
self.info = RunInfo.generate(self._run_name)
LOGGER.warning(f"The run failed to start: {response.status_code}: {response.text}."
"The results will be saved locally, but will not be logged to Azure.")
self._status = RunStatus.BROKEN
else:
parsed_response = response.json()
self.info = RunInfo(
run_id=parsed_response['run']['info']['run_id'],
experiment_id=parsed_response['run']['info']['experiment_id'],
run_name=parsed_response['run']['info']['run_name']
)
self._status = RunStatus.STARTED

def _end_run(self, reason: str) -> None:
"""
Terminate the run.

:param status: One of "FINISHED" "FAILED" and "KILLED"
:type status: str
:param reason: One of "FINISHED", "FAILED" and "KILLED"
:type reason: str
:raises: ValueError if the run is not in ("FINISHED", "FAILED", "KILLED")
"""
if not self._check_state_and_log('stop run',
{RunStatus.BROKEN, RunStatus.NOT_STARTED, RunStatus.TERMINATED},
False):
return
if self._is_promptflow_run:
# This run is already finished, we just add artifacts/metrics to it.
Singleton.destroy(EvalRun)
self._status = RunStatus.TERMINATED
return
if status not in ("FINISHED", "FAILED", "KILLED"):
if reason not in ("FINISHED", "FAILED", "KILLED"):
raise ValueError(
f"Incorrect terminal status {status}. " 'Valid statuses are "FINISHED", "FAILED" and "KILLED".'
f"Incorrect terminal status {reason}. " 'Valid statuses are "FINISHED", "FAILED" and "KILLED".'
)
if self._is_terminated:
LOGGER.warning("Unable to stop run because it was already terminated.")
return
if self._is_broken:
LOGGER.warning("Unable to stop run because the run failed to start.")
return
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/update"
body = {
"run_uuid": self.info.run_id,
"status": status,
"status": reason,
"end_time": int(time.time() * 1000),
"run_id": self.info.run_id,
}
response = self.request_with_retry(url=url, method="POST", json_dict=body)
if response.status_code != 200:
LOGGER.warning("Unable to terminate the run.")
Singleton.destroy(EvalRun)
self._is_terminated = True
self._status = RunStatus.TERMINATED

def __enter__(self):
"""The Context Manager enter call."""
self._start_run()
return self

def __exit__(self, exc_type, exc_value, exc_tb):
"""The context manager exit call."""
self._end_run("FINISHED")

def get_run_history_uri(self) -> str:
"""
@@ -306,6 +309,33 @@ def _log_warning(self, failed_op: str, response: requests.Response) -> None:
f"{response.text=}."
)

def _check_state_and_log(
self,
action: str,
bad_states: Set[RunStatus],
should_raise: bool) -> bool:
"""
Check that the run is in the correct state and log a warning if it is not.

:param action: Action, which caused this check. For example if it is "log artifact",
the log message will start "Unable to log artifact."
:type action: str
:param bad_states: The states, considered invalid for given action.
:type bad_states: set
:param should_raise: Should we raise an error if the bad state has been encountered?
:type should_raise: bool
:raises: RuntimeError if should_raise is True and invalid state was encountered.
:return: boolean saying if run is in the correct state.
"""
if self._status in bad_states:
msg = f"Unable to {action} due to Run status={self._status}."
if should_raise:
raise RuntimeError(msg)
else:
LOGGER.warning(msg)
return False
return True

def log_artifact(self, artifact_folder: str, artifact_name: str = EVALUATION_ARTIFACT) -> None:
"""
The local implementation of mlflow-like artifact logging.
@@ -316,8 +346,7 @@ def log_artifact(self, artifact_folder: str, artifact_name: str = EVALUATION_ART
:param artifact_folder: The folder with artifacts to be uploaded.
:type artifact_folder: str
"""
if self._is_broken:
LOGGER.warning("Unable to log artifact because the run failed to start.")
if not self._check_state_and_log('log artifact', {RunStatus.BROKEN, RunStatus.NOT_STARTED}, False):
return
# Check if artifact directory is empty or does not exist.
if not os.path.isdir(artifact_folder):
@@ -404,8 +433,7 @@ def log_metric(self, key: str, value: float) -> None:
:param value: The value to be logged.
:type value: float
"""
if self._is_broken:
LOGGER.warning("Unable to log metric because the run failed to start.")
if not self._check_state_and_log('log metric', {RunStatus.BROKEN, RunStatus.NOT_STARTED}, False):
return
body = {
"run_uuid": self.info.run_id,
@@ -423,11 +451,20 @@ def log_metric(self, key: str, value: float) -> None:
if response.status_code != 200:
self._log_warning("save metrics", response)

@staticmethod
def get_instance(*args, **kwargs) -> "EvalRun":
def write_properties_to_run_history(self, properties: Dict[str, Any]) -> None:
"""
The convenience method to the the EvalRun instance.
Write properties to the RunHistory service.

:return: The EvalRun instance.
:param properties: The properties to be written to run history.
:type properties: dict
"""
return EvalRun(*args, **kwargs)
if not self._check_state_and_log('write properties', {RunStatus.BROKEN, RunStatus.NOT_STARTED}, False):
return
# update host to run history and request PATCH API
response = self.request_with_retry(
url=self.get_run_history_uri(),
method="PATCH",
json_dict={"runId": self.info.run_id, "properties": properties},
)
if response.status_code != 200:
LOGGER.error("Fail writing properties '%s' to run history: %s", properties, response.text)