Skip to content

Commit

Permalink
adding more suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
Kerem Kurban committed Nov 6, 2024
1 parent bc23ea3 commit 8b025f8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
16 changes: 8 additions & 8 deletions src/neuroagent/scripts/avalidate_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


async def fetch_tool_call(
session: aiohttp.ClientSession, test_case: Dict[str, Any], base_url: str
session: aiohttp.ClientSession, query: Dict[str, Any], base_url: str
) -> Dict[str, Any]:
"""
Fetch the tool call results for a given test case.
Expand All @@ -27,7 +27,7 @@ async def fetch_tool_call(
Args:
----
session (aiohttp.ClientSession): The aiohttp session used to make the HTTP request.
test_case (dict): A dictionary containing the test case data, including the prompt,
query (dict): A dictionary containing the test case data, including the prompt,
expected tools, optional tools, and forbidden tools.
base_url (str): The base URL of the API.
Expand All @@ -36,10 +36,10 @@ async def fetch_tool_call(
dict: A dictionary containing the prompt, actual tool calls, expected tool calls,
and whether the actual calls match the expected ones.
"""
prompt = test_case["prompt"]
expected_tool_calls = test_case["expected_tools"]
optional_tools = test_case["optional_tools"]
forbidden_tools = test_case["forbidden_tools"]
prompt = query["prompt"]
expected_tool_calls = query["expected_tools"]
optional_tools = query["optional_tools"]
forbidden_tools = query["forbidden_tools"]

logging.info(f"Testing prompt: {prompt}")

Expand Down Expand Up @@ -117,8 +117,8 @@ async def validate_tool_calls_async(

async with aiohttp.ClientSession() as session:
tasks = [
fetch_tool_call(session, test_case, base_url)
for test_case in tool_calls_data
fetch_tool_call(session, query, base_url)
for query in tool_calls_data
]
results_list = await asyncio.gather(*tasks)

Expand Down
17 changes: 8 additions & 9 deletions src/neuroagent/scripts/validate_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,22 @@ def validate_tool_calls(
results_list = []

# Iterate over each test case with a progress bar
for test_case in tqdm(tool_calls_data, desc="Processing test cases"):
prompt = test_case["prompt"]
expected_tool_calls = test_case["expected_tools"]
optional_tools = test_case["optional_tools"]
forbidden_tools = test_case["forbidden_tools"]
for query in tqdm(tool_calls_data, desc="Processing test cases"):
prompt = query["prompt"]
expected_tool_calls = query["expected_tools"]
optional_tools = query["optional_tools"]
forbidden_tools = query["forbidden_tools"]

logging.info(f"Testing prompt: {prompt}")

# Send a request to the API
response = requests.post(
f"{base_url}/qa/run", # Replace with the actual endpoint
f"{base_url}/qa/run",
headers={
"Content-Type": "application/json"
}, # Ensure the correct header is set
},
json={
"query": prompt, # Add the 'query' field with the prompt as its value
"messages": [{"role": "user", "content": prompt}],
"query": prompt,
},
)

Expand Down

0 comments on commit 8b025f8

Please sign in to comment.