-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_evals.py
95 lines (79 loc) · 4.02 KB
/
run_evals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import uuid
from typing import Optional, List
from enum import Enum

from langsmith import Client
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain_experimental.llms.ollama_functions import OllamaFunctions
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_anthropic import ChatAnthropic
from langchain_anthropic.experimental import ChatAnthropicTools
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.pydantic_v1 import BaseModel, Field
from langchain_openai.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains.openai_functions import (
    convert_to_openai_function
)
# Closed set of tone labels the extractor may assign to an email (used by
# Email.tone). Because of the `str` mixin, members compare equal to their plain
# string values, e.g. ToneEnum.positive == "positive".
# Only two labels exist, so the model has no "neutral" option to choose.
# NOTE(review): deliberately no docstring here -- presumably the schema builder
# can copy an enum docstring into the tool definition sent to the model; confirm
# before adding one.
class ToneEnum(str, Enum):
    positive = "positive"
    negative = "negative"
class Email(BaseModel):
    """Relevant information about an email."""

    # Email is passed to convert_to_openai_function and with_structured_output
    # in this script; those presumably build the tool schema from the class
    # docstring and the field descriptions, so treat that text as prompt text.

    # Sender details: optional, left as None when the email does not state them.
    sender: Optional[str] = Field(
        default=None,
        description="The sender's name, if available",
    )
    sender_phone_number: Optional[str] = Field(
        default=None,
        description="The sender's phone number, if available",
    )
    sender_address: Optional[str] = Field(
        default=None,
        description="The sender's address, if available",
    )

    # Content fields: `default=...` (Ellipsis) marks each one as required.
    action_items: List[str] = Field(
        default=...,
        description="A list of action items requested by the email",
    )
    topic: str = Field(
        default=...,
        description="High level description of what the email is about",
    )
    tone: ToneEnum = Field(
        default=...,
        description="The tone of the email.",
    )
# Extraction prompt: a fixed system persona followed by a human turn that embeds
# the raw email through the `{email}` template variable (presumably the input key
# of the evaluation dataset -- confirm against the dataset schema).
_system_turn = ("system", "You are an expert researcher.")
_human_turn = (
    "human",
    "What can you tell me about the following email? Make sure to answer in the correct format. If provided with a tool, you must call it when responding: {email}",  # noqa: E501
)
prompt = ChatPromptTemplate.from_messages([_system_turn, _human_turn])
# OpenAI function-calling spec derived from the Email schema, plus bind() kwargs
# that force the model to call exactly that function. In this script only a
# commented-out `llm.bind(**llm_kwargs)` chain variant consumes these.
_email_function = convert_to_openai_function(Email)
openai_functions = [_email_function]
llm_kwargs = dict(
    functions=openai_functions,
    function_call=dict(name=_email_function["name"]),
)
# --- Model under test ---------------------------------------------------------
# Keep exactly one `llm = ...` line active to choose the provider being evaluated.
#
# Ollama's JSON mode has a bug where it generates newlines infinitely; the
# stop sequence on the Ollama line is a workaround for it.
# llm = OllamaFunctions(temperature=0, model="llama2", timeout=300, stop=["\n\n\n\n"])
# llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
# llm = ChatMistralAI(model="mistral-large")
llm = ChatAnthropicTools(model="claude-3-sonnet-20240229", temperature=0)

# --- Extraction chain ---------------------------------------------------------
# Older function-calling variant, kept for reference (uses `llm_kwargs`;
# note get_openai_output_parser is not imported in this file):
# output_parser = get_openai_output_parser([Email])
# output_parser = JsonOutputFunctionsParser()
# extraction_chain = prompt | llm.bind(**llm_kwargs) | output_parser | (lambda x: { "output": x })
structured_llm = llm.with_structured_output(Email)
extraction_chain = prompt | structured_llm
# --- Grader -------------------------------------------------------------------
# Judge model for the labeled accuracy score; temperature 0 and a fixed seed are
# presumably there to keep grades reproducible between runs.
# eval_llm = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
eval_llm = ChatOpenAI(temperature=0, model="gpt-4", model_kwargs={"seed": 42})

# Rubric text sent to the grader; keep it verbatim, it is part of the eval prompt.
# NOTE(review): the Score 3 and Score 5 lines both read "more than one omission
# or major error(s)" and differ only in "partially" vs "mostly" correct, so the
# grader gets little guidance separating them -- consider tightening.
accuracy_rubric = """
Score 1: The answer is incorrect and unrelated to the question or reference document.
Score 3: The answer is partially correct but has more than one omission or major errors.
Score 5: The answer is mostly correct but has more than one omission or major error.
Score 7: The answer is mostly correct but has at most one omission or major error.
Score 9: The answer is mostly correct with no omissions and only minor errors, and aligns with the reference document.
Score 10: The answer is correct, complete, and aligns with the reference document. Extra information is acceptable if it is sensible.
If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct and not be penalized.
"""  # noqa

accuracy_evaluator = RunEvalConfig.LabeledScoreString(
    criteria={"accuracy": accuracy_rubric},
    llm=eval_llm,
    # The rubric scores on a 1-10 scale; presumably this rescales to [0, 1].
    normalize_by=10.0,
)
evaluation_config = RunEvalConfig(evaluators=[accuracy_evaluator])
# --- Run ----------------------------------------------------------------------
# Evaluate `extraction_chain` on the LangSmith dataset using the accuracy grader.
# run_on_dataset refuses a test project name that already exists, so a fixed name
# made every run after the first fail. A short random suffix keeps each run's
# project unique while preserving a recognizable, model-specific prefix.
client = Client()
run_on_dataset(
    dataset_name="Extraction Over Spam Emails",
    llm_or_chain_factory=extraction_chain,
    client=client,
    evaluation=evaluation_config,
    project_name=f"anthropic-tools-claude-3-sonnet-test-{uuid.uuid4().hex[:8]}",
    # One example at a time (presumably to stay within provider rate limits).
    concurrency_level=1,
)