Merge pull request #144 from GoogleCloudPlatform/feature/conversation_evaluation_tool
Add Evaluation Tool for generative conversations
Showing 3 changed files with 703 additions and 0 deletions.
333 changes: 333 additions & 0 deletions
examples/vertex_ai_conversation/evaluation_tool__numeric_score__colab.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,333 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "WpkyirmC-F33" | ||
}, | ||
"source": [ | ||
"# Vertex AI Conversation - Evaluation Tool\n", | ||
"\n", | ||
"This tool requires user input in several steps. Please run the cells one by one (Shift+Enter) to ensure all the steps are successfully completed.\n", | ||
"\n", | ||
"## Instructions:\n", | ||
"\n", | ||
"1. **Set-up**\n", | ||
" 1. First cell: install and import dependencies\n", | ||
" 2. Second cell: authentication - it requires following the steps in the pop-up window. Alternatively, it can be replaced by another [supported authentication method](https://github.com/GoogleCloudPlatform/dfcx-scrapi#authentication)\n", | ||
" 3. Third cell: enter values for project, location and agent in the right panel; then run the cell.\n", | ||
" 4. Fourth cell: run examples to validate that the set-up is correct\n", | ||
"2. **Generate Questions & Answer**\n", | ||
" 1. First cell: save a sample csv file with the correct format\n", | ||
" 2. Second cell: upload the csv file with the fields `user_query` and `ideal_answer` filled in for all examples\n", | ||
" 3. Third cell: bulk generation of `agent_answer` that includes the text and link\n", | ||
"3. **Rating**\n", | ||
" 1. First cell: download csv and add the ratings offline\n", | ||
" 2. Second cell: upload csv file with the ratings\n", | ||
"4. **Results**\n", | ||
" 1. First cell: visualize distribution of ratings\n", | ||
"\n", | ||
"This notebook calls `DetectIntent` using [dfcx-scrapi library](https://github.com/GoogleCloudPlatform/dfcx-scrapi) for Dialogflow CX.\n", | ||
"\n", | ||
"\n", | ||
"## Rating guidance:\n", | ||
"\n", | ||
"For each sample (aka row), the rater should evaluate each answer (including the link) that was generated by the agent. Each answer is rated with an integer (scalar) from -1 to 3 as follows:\n", | ||
"* **+3** : Perfect answer > fully addresses the question with correct information and polite tone\n", | ||
"* **+2** : Good answer > may contain unnecessary info, may miss some info, or may not be perfectly articulated\n", | ||
"* **+1** : Slightly good answer > some truth to the answer\n", | ||
"* **0** : Neutral answer > no answer or answer contains irrelevant info\n", | ||
"* **-1** : Hurtful answer > wrong or misleading info, or inappropriate tone\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "Afvsuux0zaWZ" | ||
}, | ||
"source": [ | ||
"## Set-up\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "PPJYRHN83bHg" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Dependencies\n", | ||
"!pip install dfcx-scrapi --quiet\n", | ||
"\n", | ||
"import io\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"from dfcx_scrapi.core.sessions import Sessions\n", | ||
"from google.auth import default\n", | ||
"from google.colab import auth\n", | ||
"from google.colab import files" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "sztyBjNlIGAw" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Authentication\n", | ||
"\n", | ||
"auth.authenticate_user()\n", | ||
"creds, _ = default()\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "mRUB0Uf-3uzS" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Agent config\n", | ||
"project_id = '' #@param{type: 'string'}\n", | ||
"location = 'global' #@param{type: 'string'}\n", | ||
"agent_id = '' #@param{type: 'string'}\n", | ||
"\n", | ||
"agent_id = f\"projects/{project_id}/locations/{location}/agents/{agent_id}\"\n", | ||
"print(agent_id)\n", | ||
"\n", | ||
"s = Sessions(agent_id=agent_id)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "OChJbblt3dt7" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Test\n", | ||
"user_query = 'Hello World!'\n", | ||
"agent_answer = s.get_agent_answer(user_query)\n", | ||
"print(f\" Q: {user_query}\\n A: {agent_answer}\")\n", | ||
"\n", | ||
"user_query = 'Which is the cheapest plan?'\n", | ||
"agent_answer = s.get_agent_answer(user_query)\n", | ||
"print(f\" Q: {user_query}\\n A: {agent_answer}\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "L2WQime-8-Dw" | ||
}, | ||
"source": [ | ||
"## Generate Questions & Answer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "q3II66B04F0j" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create sample csv\n", | ||
"\n", | ||
"sample_df = pd.DataFrame({\n", | ||
" \"user_query\": [],\n", | ||
" \"ideal_answer\": [],\n", | ||
" \"agent_answer\": [],\n", | ||
" \"rating\": [],\n", | ||
" \"comment\": []\n", | ||
"})\n", | ||
"\n", | ||
"sample_df.loc[0] = [\"Who are you?\", \"I am an assistant\", \"\", 0, \"\"]\n", | ||
"sample_df.loc[1] = [\"Which is the cheapest plan?\", \"Basic plan\", \"\", 0, \"\"]\n", | ||
"sample_df.loc[2] = [\"My device is not working\", \"Call 888-555\", \"\", 0, \"\"]\n", | ||
"\n", | ||
"# Export to local drive as csv file\n", | ||
"file_name = 'data_sample.csv'\n", | ||
"sample_df.to_csv(file_name, encoding='utf-8-sig', index=False)\n", | ||
"files.download(file_name)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "OYr4Dy77KbfL" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"input(f\"In your local drive, you can find the csv file '{file_name}'. Add the user_query and ideal_answer for each example.\\nWhen done, press 'Enter'\")\n", | ||
"print('done')\n", | ||
"\n", | ||
"# Import from local drive the csv file with the user_query and ideal_answer per examples\n", | ||
"uploaded = files.upload()\n", | ||
"file_name2 = next(iter(uploaded))\n", | ||
"df = pd.read_csv(io.BytesIO(uploaded[file_name2]))\n", | ||
"\n", | ||
"assert df.shape[0] > 0, \"The csv has zero rows\"\n", | ||
"assert set(df.columns) == set(sample_df.columns), f\"The csv must have the following columns: {sample_df.columns.values}\"\n", | ||
"\n", | ||
"df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "RmJcxpFI881j" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Generate answers for each query\n", | ||
"df['agent_answer'] = df.apply(lambda row: s.get_agent_answer(row[\"user_query\"]), axis=1)\n", | ||
"\n", | ||
"df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "yO2x7lc2BRDR" | ||
}, | ||
"source": [ | ||
"## Rating" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "ZfAMlQbS8qsy" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Export to local drive as csv file\n", | ||
"file_name = 'output.csv'\n", | ||
"df.to_csv(file_name, encoding='utf-8-sig', index=False)\n", | ||
"files.download(file_name)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "SEU44Mcy9mBU" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"input(f\"In your local drive, you can find the csv file '{file_name}'. Rate each agent_answer from -1 to 3, using ideal_answer as reference.\\nWhen done, press 'Enter'\")\n", | ||
"print('done')\n", | ||
"\n", | ||
"# Import from local drive the csv file with the ratings\n", | ||
"uploaded = files.upload()\n", | ||
"file_name2 = next(iter(uploaded))\n", | ||
"df = pd.read_csv(io.BytesIO(uploaded[file_name2]))\n", | ||
"\n", | ||
"assert df.shape[0] > 0, \"The csv has zero rows\"\n", | ||
"assert set(df.columns) == set(sample_df.columns), f\"The csv must have the following columns: {sample_df.columns.values}\"\n", | ||
"\n", | ||
"df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "W5j9yAewRmNO" | ||
}, | ||
"source": [ | ||
"## Results\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "I5209MB7VS1q" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Rating distribution\n", | ||
"#df[\"rating\"].describe()\n", | ||
"\n", | ||
"# Histogram\n", | ||
"ratings_set = [-1, 0, 1, 2, 3]\n", | ||
"ratings_values = df['rating'].values\n", | ||
"ratings_count = len(ratings_values)\n", | ||
"\n", | ||
"bar_centers = np.linspace(min(ratings_set), max(ratings_set), len(ratings_set))\n", | ||
"bar_edges = np.linspace(min(ratings_set)-0.5, max(ratings_set)+0.5, len(ratings_set)+1)\n", | ||
"bar_heights, _ = np.histogram(ratings_values, bins=bar_edges, density=True)\n", | ||
"\n", | ||
"for center, _h in zip(bar_centers, bar_heights):\n", | ||
" print(f\"{center}: count={round(_h*ratings_count):.0f}, percentage={_h*100:.2f}%\")\n", | ||
"\n", | ||
"# Plot\n", | ||
"height_sum = 100 # for percentage, use 100\n", | ||
"fig, axs = plt.subplots(1, 1, figsize=(6, 4), tight_layout=True)\n", | ||
"\n", | ||
"plt.bar(bar_centers, height_sum*bar_heights, width=0.8)\n", | ||
"ratings_mean = np.mean(ratings_values)\n", | ||
"plt.plot([ratings_mean, ratings_mean], [0, height_sum], '--', label=f\"mean={ratings_mean:.2f}\", color='red')\n", | ||
"ratings_median = np.median(ratings_values)\n", | ||
"plt.plot([ratings_median, ratings_median], [0, height_sum], '--', label=f\"median={ratings_median:.2f}\", color='green')\n", | ||
"\n", | ||
"plt.axis((min(bar_edges), max(bar_edges), 0, round(1.2*max(height_sum*bar_heights), 1)))\n", | ||
"plt.legend(loc='upper left')\n", | ||
"plt.gca().grid(axis='y')\n", | ||
"plt.xlabel('Rating')\n", | ||
"plt.ylabel('Percentage [%]')\n", | ||
"plt.title(f\"Rating distribution (count={ratings_count})\")\n", | ||
"\n", | ||
"plt.tight_layout()\n", | ||
"plt.show()\n", | ||
"\n", | ||
"fig.savefig('ratings_distribution.png', dpi=fig.dpi)\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "rYwsIZ0Ej-v9" | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
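
The rating-distribution cell in the notebook calls `np.histogram` with `density=True` over unit-width bins centred on each allowed rating, so each bar height is the fraction of ratings that fall on that value. The following is a minimal standalone sketch of that calculation; the ratings in `ratings_values` are hypothetical and chosen only for illustration, not taken from this commit.

```python
import numpy as np

# Hypothetical ratings, for illustration only (not data from the notebook).
ratings_values = np.array([3, 2, 2, 0, -1, 3, 1, 2])
ratings_set = [-1, 0, 1, 2, 3]

# Unit-width bins centred on each allowed rating: [-1.5, -0.5, ..., 3.5].
bar_edges = np.linspace(min(ratings_set) - 0.5,
                        max(ratings_set) + 0.5,
                        len(ratings_set) + 1)

# With density=True and bin width 1, each height is the fraction of samples in the bin.
bar_heights, _ = np.histogram(ratings_values, bins=bar_edges, density=True)

# Recover per-rating counts and percentages from the fractions.
counts = {r: int(round(h * len(ratings_values)))
          for r, h in zip(ratings_set, bar_heights)}
percentages = {r: float(h * 100) for r, h in zip(ratings_set, bar_heights)}

for r in ratings_set:
    print(f"{r}: count={counts[r]}, percentage={percentages[r]:.2f}%")
```

Because the bins have width 1, multiplying a height by the number of ratings recovers the count for that rating, which appears to be why the notebook can print both counts and percentages from the same array.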