From 819fccf28ac4b6013c0070964baf4648d2c4cff5 Mon Sep 17 00:00:00 2001 From: imotai Date: Sun, 12 Nov 2023 23:48:33 +0800 Subject: [PATCH 1/2] feat: move the prompt to role module --- .gitignore | 2 + agent/src/og_agent/llama_agent.py | 45 ---- agent/src/og_agent/llama_client.py | 2 +- agent/tests/llama_agent_tests.py | 288 +++++++++++++++++++++++++ agent/tests/tokenizer_test.py | 21 ++ roles/README.md | 1 + roles/setup.py | 26 +++ roles/src/og_roles/__init__.py | 0 roles/src/og_roles/code_interpreter.py | 77 +++++++ 9 files changed, 416 insertions(+), 46 deletions(-) create mode 100644 agent/tests/llama_agent_tests.py create mode 100644 agent/tests/tokenizer_test.py create mode 100644 roles/README.md create mode 100644 roles/setup.py create mode 100644 roles/src/og_roles/__init__.py create mode 100644 roles/src/og_roles/code_interpreter.py diff --git a/.gitignore b/.gitignore index 4c503a3..c982df7 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +.cosine \ No newline at end of file diff --git a/agent/src/og_agent/llama_agent.py b/agent/src/og_agent/llama_agent.py index 87c412b..176e7ad 100644 --- a/agent/src/og_agent/llama_agent.py +++ b/agent/src/og_agent/llama_agent.py @@ -36,51 +36,6 @@ def _output_exception(self): "Sorry, the LLM did return nothing, You can use a better performance model" ) - - def _format_output(self, json_response): - """ - format the response and send it to the user - """ - answer = json_response["explanation"] - if json_response["action"] == "no_action": - return answer - elif json_response["action"] == "show_sample_code": - return "" - else: - code = json_response.get("code", None) - answer_code = """%s -```%s -%s -``` -""" % ( - answer, - json_response.get("language", "python"), - code if code else "", - ) - return answer_code - - async def handle_show_sample_code( - self, json_response, queue, context, task_context - ): - code = json_response["code"] - explanation = json_response["explanation"] - saved_filenames = json_response.get("saved_filenames", []) - tool_input = json.dumps({ - "code": code, - "explanation": explanation, - "saved_filenames": saved_filenames, - "language": json_response.get("language", "text"), - }) - await queue.put( - TaskResponse( - state=task_context.to_context_state_proto(), - response_type=TaskResponse.OnStepActionStart, - on_step_action_start=OnStepActionStart( - input=tool_input, tool="show_sample_code" - ), - ) - ) - async def handle_bash_code( self, json_response, queue, context, task_context, task_opt ): diff --git a/agent/src/og_agent/llama_client.py b/agent/src/og_agent/llama_client.py index 234f247..01879c9 100644 --- a/agent/src/og_agent/llama_client.py +++ b/agent/src/og_agent/llama_client.py @@ -20,7 +20,7 @@ def __init__(self, endpoint, key, grammar): super().__init__(endpoint + "/v1/chat/completions", key) self.grammar = grammar - async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=[]): + async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=['\n']): data = { "messages": messages, "temperature": temperature, diff --git a/agent/tests/llama_agent_tests.py b/agent/tests/llama_agent_tests.py new file mode 100644 index 0000000..ae09a0d --- /dev/null +++ b/agent/tests/llama_agent_tests.py @@ -0,0 +1,288 @@ +# vim:fenc=utf-8 + +# SPDX-FileCopyrightText: 2023 imotai +# 
SPDX-FileContributor: imotai +# +# SPDX-License-Identifier: Elastic-2.0 + +""" """ + +import json +import logging +import pytest +from og_sdk.kernel_sdk import KernelSDK +from og_agent import openai_agent +from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse, ProcessTaskRequest +from openai.openai_object import OpenAIObject +import asyncio +import pytest_asyncio + +api_base = "127.0.0.1:9528" +api_key = "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH" + +logger = logging.getLogger(__name__) + +class PayloadStream: + + def __init__(self, payload): + self.payload = payload + + def __aiter__(self): + # create an iterator of the input keys + self.iter_keys = iter(self.payload) + return self + + async def __anext__(self): + try: + k = next(self.iter_keys) + obj = OpenAIObject() + delta = OpenAIObject() + content = OpenAIObject() + content.content = k + delta.delta = content + obj.choices = [delta] + return obj + except StopIteration: + # raise stopasynciteration at the end of iterator + raise StopAsyncIteration + + +class FunctionCallPayloadStream: + + def __init__(self, name, arguments): + self.name = name + self.arguments = arguments + + def __aiter__(self): + # create an iterator of the input keys + self.iter_keys = iter(self.arguments) + return self + + async def __anext__(self): + try: + k = next(self.iter_keys) + obj = OpenAIObject() + delta = OpenAIObject() + function_para = OpenAIObject() + function_para.name = self.name + function_para.arguments = k + function_call = OpenAIObject() + function_call.function_call = function_para + delta.delta = function_call + obj.choices = [delta] + return obj + except StopIteration: + # raise stopasynciteration at the end of iterator + raise StopAsyncIteration + + +class MockContext: + + def done(self): + return False + + +class MultiCallMock: + + def __init__(self, responses): + self.responses = responses + self.index = 0 + + def call(self, *args, **kwargs): + if self.index >= len(self.responses): + raise Exception("no more response") + self.index += 1 + logger.debug("call index %d", self.index) + return self.responses[self.index - 1] + + +@pytest.fixture +def kernel_sdk(): + endpoint = ( + "localhost:9527" # Replace with the actual endpoint of your test gRPC server + ) + return KernelSDK(endpoint, "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH") + + +@pytest.mark.asyncio +async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk): + kernel_sdk.connect() + arguments = { + "explanation": "the hello world in bash", + "code": "echo 'hello world'", + "saved_filenames": [], + "language": "bash", + } + stream1 = FunctionCallPayloadStream("execute", json.dumps(arguments)) + sentence = "The output 'hello world' is the result" + stream2 = PayloadStream(sentence) + call_mock = MultiCallMock([stream1, stream2]) + with mocker.patch( + "og_agent.openai_agent.openai.ChatCompletion.acreate", + side_effect=call_mock.call, + ) as mock_openai: + agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) + queue = asyncio.Queue() + task_opt = ProcessOptions( + streaming=True, + llm_name="gpt4", + input_token_limit=100000, + output_token_limit=100000, + timeout=5, + ) + request = ProcessTaskRequest( + input_files=[], + task="write a hello world in bash", + context_id="", + options=task_opt, + ) + await agent.arun(request, queue, MockContext(), task_opt) + responses = [] + while True: + try: + response = await queue.get() + if not response: + break + responses.append(response) + except asyncio.QueueEmpty: + break + logger.info(responses) + console_output = list( 
+ filter( + lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout, + responses, + ) + ) + assert len(console_output) == 1, "bad console output count" + assert console_output[0].console_stdout == "hello world\n", "bad console output" + +@pytest.mark.asyncio +async def test_openai_agent_direct_message(mocker, kernel_sdk): + kernel_sdk.connect() + arguments = { + "message": "hello world", + } + stream1 = FunctionCallPayloadStream("direct_message", json.dumps(arguments)) + call_mock = MultiCallMock([stream1]) + with mocker.patch( + "og_agent.openai_agent.openai.ChatCompletion.acreate", + side_effect=call_mock.call, + ) as mock_openai: + agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) + queue = asyncio.Queue() + task_opt = ProcessOptions( + streaming=False, + llm_name="gpt4", + input_token_limit=100000, + output_token_limit=100000, + timeout=5, + ) + request = ProcessTaskRequest( + input_files=[], + task="say hello world", + context_id="", + options=task_opt, + ) + await agent.arun(request, queue, MockContext(), task_opt) + responses = [] + while True: + try: + response = await queue.get() + if not response: + break + responses.append(response) + except asyncio.QueueEmpty: + break + logger.info(responses) + assert responses[0].final_answer.answer == "hello world" + + +@pytest.mark.asyncio +async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk): + kernel_sdk.connect() + arguments = { + "explanation": "the hello world in python", + "code": "print('hello world')", + "language": "python", + "saved_filenames": [], + } + stream1 = FunctionCallPayloadStream("execute", json.dumps(arguments)) + sentence = "The output 'hello world' is the result" + stream2 = PayloadStream(sentence) + call_mock = MultiCallMock([stream1, stream2]) + with mocker.patch( + "og_agent.openai_agent.openai.ChatCompletion.acreate", + side_effect=call_mock.call, + ) as mock_openai: + agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) + queue = asyncio.Queue() + task_opt = ProcessOptions( + streaming=True, + llm_name="gpt4", + input_token_limit=100000, + output_token_limit=100000, + timeout=5, + ) + request = ProcessTaskRequest( + input_files=[], + task="write a hello world in python", + context_id="", + options=task_opt, + ) + await agent.arun(request, queue, MockContext(), task_opt) + responses = [] + while True: + try: + response = await queue.get() + if not response: + break + responses.append(response) + except asyncio.QueueEmpty: + break + logger.info(responses) + console_output = list( + filter( + lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout, + responses, + ) + ) + assert len(console_output) == 1, "bad console output count" + assert console_output[0].console_stdout == "hello world\n", "bad console output" + + +@pytest.mark.asyncio +async def test_openai_agent_smoke_test(mocker, kernel_sdk): + sentence = "Hello, how can I help you?" 
+ stream = PayloadStream(sentence) + with mocker.patch( + "og_agent.openai_agent.openai.ChatCompletion.acreate", return_value=stream + ) as mock_openai: + agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) + queue = asyncio.Queue() + task_opt = ProcessOptions( + streaming=True, + llm_name="gpt4", + input_token_limit=100000, + output_token_limit=100000, + timeout=5, + ) + request = ProcessTaskRequest( + input_files=[], task="hello", context_id="", options=task_opt + ) + await agent.arun(request, queue, MockContext(), task_opt) + responses = [] + while True: + try: + response = await queue.get() + if not response: + break + responses.append(response) + except asyncio.QueueEmpty: + break + logger.info(responses) + assert len(responses) == len(sentence) + 1, "bad response count" + assert ( + responses[-1].response_type == TaskResponse.OnFinalAnswer + ), "bad response type" + assert responses[-1].state.input_token_count == 325 + assert responses[-1].state.output_token_count == 8 diff --git a/agent/tests/tokenizer_test.py b/agent/tests/tokenizer_test.py new file mode 100644 index 0000000..63049df --- /dev/null +++ b/agent/tests/tokenizer_test.py @@ -0,0 +1,21 @@ +# vim:fenc=utf-8 + +# SPDX-FileCopyrightText: 2023 imotai +# SPDX-FileContributor: imotai +# +# SPDX-License-Identifier: Elastic-2.0 + +""" + +""" + +import logging +import io +from og_agent.tokenizer import tokenize + +logger = logging.getLogger(__name__) +def test_parse_explanation(): + arguments="""{"function_call":"execute", "arguments": {"explanation":"h""" + for token_state, token in tokenize(io.StringIO(arguments)): + logger.info(f"token_state: {token_state}, token: {token}") + diff --git a/roles/README.md b/roles/README.md new file mode 100644 index 0000000..2bbb978 --- /dev/null +++ b/roles/README.md @@ -0,0 +1 @@ +# the role module diff --git a/roles/setup.py b/roles/setup.py new file mode 100644 index 0000000..032689a --- /dev/null +++ b/roles/setup.py @@ -0,0 +1,26 @@ +# Copyright (C) 2023 dbpunk.com Author imotai +# SPDX-FileCopyrightText: 2023 imotai +# SPDX-FileContributor: imotai +# +# SPDX-License-Identifier: Elastic-2.0 + +""" """ +from setuptools import setup, find_packages + +setup( + name="og_roles", + version="0.3.6", + description="Open source llm agent service", + author="imotai", + author_email="wangtaize@dbpunk.com", + url="https://github.com/dbpunk-labs/octogen", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + packages=[ + "og_roles", + ], + package_dir={ + "og_roles": "src/og_roles", + }, + package_data={}, +) diff --git a/roles/src/og_roles/__init__.py b/roles/src/og_roles/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/roles/src/og_roles/code_interpreter.py b/roles/src/og_roles/code_interpreter.py new file mode 100644 index 0000000..c3727cd --- /dev/null +++ b/roles/src/og_roles/code_interpreter.py @@ -0,0 +1,77 @@ +# vim:fenc=utf-8 + +# SPDX-FileCopyrightText: 2023 imotai +# SPDX-FileContributor: imotai +# +# SPDX-License-Identifier: Elastic-2.0 + +""" + +""" +import json +from og_proto.prompt_pb2 import ActionDesc + + +ROLE = f"""You are the Programming Copilot, a world-class programmer to complete any goal by executing code""" +RULES = [ + "To complete the goal, write a plan and execute it step-by-step, limiting the number of steps to five", + "Every step must include the explanation and the code block. 
if the code block has any display data, save it as a file and add it to saved_filenames field", + "You have a fully controlled programming environment to execute code with internet connection but sudo is not allowed", + "You must try to correct your code when you get errors from the output", + "You can install new package with pip", + "Use `execute` action to execute any code and `direct_message` action to send message to user", +] +FUNCTION_EXECUTE= ActionDesc( + name="execute", + desc="This action executes code in your programming environment and returns the output", + parameters=json.dumps({ + "type": "object", + "properties": { + "explanation": { + "type": "string", + "description": "the explanation about the code parameters", + }, + "code": { + "type": "string", + "description": "the bash code to be executed", + }, + "language": { + "type": "string", + "description": "the language of the code, only python and bash are supported", + }, + "saved_filenames": { + "type": "array", + "items": {"type": "string"}, + "description": "A list of filenames that were created by the code", + }, + }, + "required": ["explanation", "code", "language"], + }), + ) + +FUNCTION_DIRECT_MESSAGE= ActionDesc( + name="direct_message", + desc="This action sends a direct message to user.", + parameters=json.dumps({ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "the message will be sent to user", + }, + }, + "required": ["message"], + }), +) +ACTIONS = [ + FUNCTION_EXECUTE +] +OUTPUT_FORMAT = """The output format must be a JSON format with the following fields: +* function_call: The name of the action +* arguments: The arguments of the action +""" + +OCTOGEN_CODELLAMA_MID_INS = """The above output of the %s determines whether the execution is successful. +If successful, go to the next step. If the current step is the final step, summarize the entire plan. 
If not, adjust the input and try again""" + +OCTOGEN_CODELLAMA_MID_ERROR_INS = """Adjust the action input and try again for the above output of %s showing the error message""" From 46f957af8c978ffb874ec749345cb49185e2db6e Mon Sep 17 00:00:00 2001 From: imotai Date: Wed, 13 Dec 2023 15:50:39 +0800 Subject: [PATCH 2/2] test: remove llama test --- agent/src/og_agent/agent_api_server.py | 5 +- agent/src/og_agent/base_agent.py | 6 +- agent/src/og_agent/llama_agent.py | 42 ++-- agent/src/og_agent/llama_client.py | 2 +- agent/src/og_agent/prompt.py | 80 ++++--- agent/tests/llama_agent_tests.py | 288 ------------------------- agent/tests/openai_agent_tests.py | 1 - agent/tests/tokenizer_test.py | 9 +- sdk/src/og_sdk/agent_sdk.py | 4 +- 9 files changed, 83 insertions(+), 354 deletions(-) delete mode 100644 agent/tests/llama_agent_tests.py diff --git a/agent/src/og_agent/agent_api_server.py b/agent/src/og_agent/agent_api_server.py index e62e62a..f28b25a 100644 --- a/agent/src/og_agent/agent_api_server.py +++ b/agent/src/og_agent/agent_api_server.py @@ -178,10 +178,13 @@ class TaskRequest(BaseModel): async def run_task(task: TaskRequest, key): - async for respond in agent_sdk.prompt(task.prompt, key, files=task.input_files, context_id=task.context_id): + async for respond in agent_sdk.prompt( + task.prompt, key, files=task.input_files, context_id=task.context_id + ): response = StepResponse.new_from(respond).model_dump(exclude_none=True) yield "data: %s\n" % json.dumps(response) + @app.post("/process") async def process_task( task: TaskRequest, diff --git a/agent/src/og_agent/base_agent.py b/agent/src/og_agent/base_agent.py index 97a8286..117315d 100644 --- a/agent/src/og_agent/base_agent.py +++ b/agent/src/og_agent/base_agent.py @@ -63,6 +63,7 @@ class TypingState: MESSAGE = 4 OTHER = 5 + class BaseAgent: def __init__(self, sdk): @@ -70,7 +71,9 @@ def __init__(self, sdk): self.model_name = "" self.agent_memories = {} - def create_new_memory_with_default_prompt(self, user_name, user_id, actions = ACTIONS): + def create_new_memory_with_default_prompt( + self, user_name, user_id, actions=ACTIONS + ): """ create a new memory for the user """ @@ -386,7 +389,6 @@ async def extract_message( response_token_count + context_output_token_count ) if is_json_format: - ( new_text_content, new_code_content, diff --git a/agent/src/og_agent/llama_agent.py b/agent/src/og_agent/llama_agent.py index 176e7ad..7c2652c 100644 --- a/agent/src/og_agent/llama_agent.py +++ b/agent/src/og_agent/llama_agent.py @@ -85,7 +85,7 @@ async def handle_python_function( state=task_context.to_context_state_proto(), response_type=TaskResponse.OnStepActionStart, on_step_action_start=OnStepActionStart( - input=tool_input, tool='execute' + input=tool_input, tool="execute" ), ) ) @@ -131,10 +131,10 @@ async def arun(self, request, queue, context, task_opt): context_id = ( request.context_id if request.context_id - else self.create_new_memory_with_default_prompt("", "", actions=[FUNCTION_EXECUTE, - FUNCTION_DIRECT_MESSAGE]) + else self.create_new_memory_with_default_prompt( + "", "", actions=[FUNCTION_EXECUTE, FUNCTION_DIRECT_MESSAGE] + ) ) - if context_id not in self.agent_memories: await queue.put( TaskResponse( @@ -145,7 +145,6 @@ async def arun(self, request, queue, context, task_opt): ) ) return - agent_memory = self.agent_memories[context_id] agent_memory.update_options(self.memory_option) agent_memory.append_chat_message( @@ -208,7 +207,8 @@ async def arun(self, request, queue, context, task_opt): break logger.debug(f" llama response 
{json_response}") if ( - 'function_call'in json_response and json_response["function_call"] == "execute" + "function_call" in json_response + and json_response["function_call"] == "execute" ): agent_memory.append_chat_message(message) tools_mapping = { @@ -216,8 +216,14 @@ async def arun(self, request, queue, context, task_opt): "bash": self.handle_bash_code, } - function_result = await tools_mapping[json_response["arguments"]['language']]( - json_response['arguments'], queue, context, task_context, task_opt + function_result = await tools_mapping[ + json_response["arguments"]["language"] + ]( + json_response["arguments"], + queue, + context, + task_context, + task_opt, ) logger.debug(f"the function result {function_result}") @@ -242,23 +248,31 @@ async def arun(self, request, queue, context, task_opt): "role": "user", "content": f"{action_output} \n {function_result.console_stdout}", }) - agent_memory.append_chat_message({"role": "user", "content": current_question}) + agent_memory.append_chat_message( + {"role": "user", "content": current_question} + ) elif function_result.has_error: agent_memory.append_chat_message({ "role": "user", "content": f"{action_output} \n {function_result.console_stderr}", }) current_question = f"Generate a new step to fix the above error" - agent_memory.append_chat_message({"role": "user", "content": current_question}) + agent_memory.append_chat_message( + {"role": "user", "content": current_question} + ) else: agent_memory.append_chat_message({ "role": "user", "content": f"{action_output} \n {function_result.console_stdout}", }) - agent_memory.append_chat_message({ - "role": "user", "content": current_question}) - elif 'function_call' in json_response and json_response["function_call"] == "direct_message": - message = json_response['arguments']['message'] + agent_memory.append_chat_message( + {"role": "user", "content": current_question} + ) + elif ( + "function_call" in json_response + and json_response["function_call"] == "direct_message" + ): + message = json_response["arguments"]["message"] await queue.put( TaskResponse( state=task_context.to_context_state_proto(), diff --git a/agent/src/og_agent/llama_client.py b/agent/src/og_agent/llama_client.py index 01879c9..465ab19 100644 --- a/agent/src/og_agent/llama_client.py +++ b/agent/src/og_agent/llama_client.py @@ -20,7 +20,7 @@ def __init__(self, endpoint, key, grammar): super().__init__(endpoint + "/v1/chat/completions", key) self.grammar = grammar - async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=['\n']): + async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=["\n"]): data = { "messages": messages, "temperature": temperature, diff --git a/agent/src/og_agent/prompt.py b/agent/src/og_agent/prompt.py index 82b6a15..9ba4b77 100644 --- a/agent/src/og_agent/prompt.py +++ b/agent/src/og_agent/prompt.py @@ -17,52 +17,50 @@ "Use `execute` action to execute any code and `direct_message` action to send message to user", ] -FUNCTION_EXECUTE= ActionDesc( - name="execute", - desc="This action executes code in your programming environment and returns the output", - parameters=json.dumps({ - "type": "object", - "properties": { - "explanation": { - "type": "string", - "description": "the explanation about the code parameters", - }, - "code": { - "type": "string", - "description": "the bash code to be executed", - }, - "language": { - "type": "string", - "description": "the language of the code, only python and bash are supported", - }, - "saved_filenames": { - "type": 
"array", - "items": {"type": "string"}, - "description": "A list of filenames that were created by the code", - }, +FUNCTION_EXECUTE = ActionDesc( + name="execute", + desc="This action executes code in your programming environment and returns the output", + parameters=json.dumps({ + "type": "object", + "properties": { + "explanation": { + "type": "string", + "description": "the explanation about the code parameters", }, - "required": ["explanation", "code", "language"], - }), - ) + "code": { + "type": "string", + "description": "the bash code to be executed", + }, + "language": { + "type": "string", + "description": "the language of the code, only python and bash are supported", + }, + "saved_filenames": { + "type": "array", + "items": {"type": "string"}, + "description": "A list of filenames that were created by the code", + }, + }, + "required": ["explanation", "code", "language"], + }), +) -FUNCTION_DIRECT_MESSAGE= ActionDesc( - name="direct_message", - desc="This action sends a direct message to user.", - parameters=json.dumps({ - "type": "object", - "properties": { - "message": { - "type": "string", - "description": "the message will be sent to user", - }, +FUNCTION_DIRECT_MESSAGE = ActionDesc( + name="direct_message", + desc="This action sends a direct message to user.", + parameters=json.dumps({ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "the message will be sent to user", }, - "required": ["message"], - }), + }, + "required": ["message"], + }), ) -ACTIONS = [ - FUNCTION_EXECUTE -] +ACTIONS = [FUNCTION_EXECUTE] OUTPUT_FORMAT = """The output format must be a JSON format with the following fields: * function_call: The name of the action diff --git a/agent/tests/llama_agent_tests.py b/agent/tests/llama_agent_tests.py deleted file mode 100644 index ae09a0d..0000000 --- a/agent/tests/llama_agent_tests.py +++ /dev/null @@ -1,288 +0,0 @@ -# vim:fenc=utf-8 - -# SPDX-FileCopyrightText: 2023 imotai -# SPDX-FileContributor: imotai -# -# SPDX-License-Identifier: Elastic-2.0 - -""" """ - -import json -import logging -import pytest -from og_sdk.kernel_sdk import KernelSDK -from og_agent import openai_agent -from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse, ProcessTaskRequest -from openai.openai_object import OpenAIObject -import asyncio -import pytest_asyncio - -api_base = "127.0.0.1:9528" -api_key = "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH" - -logger = logging.getLogger(__name__) - -class PayloadStream: - - def __init__(self, payload): - self.payload = payload - - def __aiter__(self): - # create an iterator of the input keys - self.iter_keys = iter(self.payload) - return self - - async def __anext__(self): - try: - k = next(self.iter_keys) - obj = OpenAIObject() - delta = OpenAIObject() - content = OpenAIObject() - content.content = k - delta.delta = content - obj.choices = [delta] - return obj - except StopIteration: - # raise stopasynciteration at the end of iterator - raise StopAsyncIteration - - -class FunctionCallPayloadStream: - - def __init__(self, name, arguments): - self.name = name - self.arguments = arguments - - def __aiter__(self): - # create an iterator of the input keys - self.iter_keys = iter(self.arguments) - return self - - async def __anext__(self): - try: - k = next(self.iter_keys) - obj = OpenAIObject() - delta = OpenAIObject() - function_para = OpenAIObject() - function_para.name = self.name - function_para.arguments = k - function_call = OpenAIObject() - function_call.function_call = function_para - delta.delta = 
function_call - obj.choices = [delta] - return obj - except StopIteration: - # raise stopasynciteration at the end of iterator - raise StopAsyncIteration - - -class MockContext: - - def done(self): - return False - - -class MultiCallMock: - - def __init__(self, responses): - self.responses = responses - self.index = 0 - - def call(self, *args, **kwargs): - if self.index >= len(self.responses): - raise Exception("no more response") - self.index += 1 - logger.debug("call index %d", self.index) - return self.responses[self.index - 1] - - -@pytest.fixture -def kernel_sdk(): - endpoint = ( - "localhost:9527" # Replace with the actual endpoint of your test gRPC server - ) - return KernelSDK(endpoint, "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH") - - -@pytest.mark.asyncio -async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk): - kernel_sdk.connect() - arguments = { - "explanation": "the hello world in bash", - "code": "echo 'hello world'", - "saved_filenames": [], - "language": "bash", - } - stream1 = FunctionCallPayloadStream("execute", json.dumps(arguments)) - sentence = "The output 'hello world' is the result" - stream2 = PayloadStream(sentence) - call_mock = MultiCallMock([stream1, stream2]) - with mocker.patch( - "og_agent.openai_agent.openai.ChatCompletion.acreate", - side_effect=call_mock.call, - ) as mock_openai: - agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) - queue = asyncio.Queue() - task_opt = ProcessOptions( - streaming=True, - llm_name="gpt4", - input_token_limit=100000, - output_token_limit=100000, - timeout=5, - ) - request = ProcessTaskRequest( - input_files=[], - task="write a hello world in bash", - context_id="", - options=task_opt, - ) - await agent.arun(request, queue, MockContext(), task_opt) - responses = [] - while True: - try: - response = await queue.get() - if not response: - break - responses.append(response) - except asyncio.QueueEmpty: - break - logger.info(responses) - console_output = list( - filter( - lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout, - responses, - ) - ) - assert len(console_output) == 1, "bad console output count" - assert console_output[0].console_stdout == "hello world\n", "bad console output" - -@pytest.mark.asyncio -async def test_openai_agent_direct_message(mocker, kernel_sdk): - kernel_sdk.connect() - arguments = { - "message": "hello world", - } - stream1 = FunctionCallPayloadStream("direct_message", json.dumps(arguments)) - call_mock = MultiCallMock([stream1]) - with mocker.patch( - "og_agent.openai_agent.openai.ChatCompletion.acreate", - side_effect=call_mock.call, - ) as mock_openai: - agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) - queue = asyncio.Queue() - task_opt = ProcessOptions( - streaming=False, - llm_name="gpt4", - input_token_limit=100000, - output_token_limit=100000, - timeout=5, - ) - request = ProcessTaskRequest( - input_files=[], - task="say hello world", - context_id="", - options=task_opt, - ) - await agent.arun(request, queue, MockContext(), task_opt) - responses = [] - while True: - try: - response = await queue.get() - if not response: - break - responses.append(response) - except asyncio.QueueEmpty: - break - logger.info(responses) - assert responses[0].final_answer.answer == "hello world" - - -@pytest.mark.asyncio -async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk): - kernel_sdk.connect() - arguments = { - "explanation": "the hello world in python", - "code": "print('hello world')", - "language": "python", - 
"saved_filenames": [], - } - stream1 = FunctionCallPayloadStream("execute", json.dumps(arguments)) - sentence = "The output 'hello world' is the result" - stream2 = PayloadStream(sentence) - call_mock = MultiCallMock([stream1, stream2]) - with mocker.patch( - "og_agent.openai_agent.openai.ChatCompletion.acreate", - side_effect=call_mock.call, - ) as mock_openai: - agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) - queue = asyncio.Queue() - task_opt = ProcessOptions( - streaming=True, - llm_name="gpt4", - input_token_limit=100000, - output_token_limit=100000, - timeout=5, - ) - request = ProcessTaskRequest( - input_files=[], - task="write a hello world in python", - context_id="", - options=task_opt, - ) - await agent.arun(request, queue, MockContext(), task_opt) - responses = [] - while True: - try: - response = await queue.get() - if not response: - break - responses.append(response) - except asyncio.QueueEmpty: - break - logger.info(responses) - console_output = list( - filter( - lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout, - responses, - ) - ) - assert len(console_output) == 1, "bad console output count" - assert console_output[0].console_stdout == "hello world\n", "bad console output" - - -@pytest.mark.asyncio -async def test_openai_agent_smoke_test(mocker, kernel_sdk): - sentence = "Hello, how can I help you?" - stream = PayloadStream(sentence) - with mocker.patch( - "og_agent.openai_agent.openai.ChatCompletion.acreate", return_value=stream - ) as mock_openai: - agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) - queue = asyncio.Queue() - task_opt = ProcessOptions( - streaming=True, - llm_name="gpt4", - input_token_limit=100000, - output_token_limit=100000, - timeout=5, - ) - request = ProcessTaskRequest( - input_files=[], task="hello", context_id="", options=task_opt - ) - await agent.arun(request, queue, MockContext(), task_opt) - responses = [] - while True: - try: - response = await queue.get() - if not response: - break - responses.append(response) - except asyncio.QueueEmpty: - break - logger.info(responses) - assert len(responses) == len(sentence) + 1, "bad response count" - assert ( - responses[-1].response_type == TaskResponse.OnFinalAnswer - ), "bad response type" - assert responses[-1].state.input_token_count == 325 - assert responses[-1].state.output_token_count == 8 diff --git a/agent/tests/openai_agent_tests.py b/agent/tests/openai_agent_tests.py index f353899..9013190 100644 --- a/agent/tests/openai_agent_tests.py +++ b/agent/tests/openai_agent_tests.py @@ -158,7 +158,6 @@ async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk): assert console_output[0].console_stdout == "hello world\n", "bad console output" - @pytest.mark.asyncio async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk): kernel_sdk.connect() diff --git a/agent/tests/tokenizer_test.py b/agent/tests/tokenizer_test.py index 63049df..3540d7d 100644 --- a/agent/tests/tokenizer_test.py +++ b/agent/tests/tokenizer_test.py @@ -5,17 +5,16 @@ # # SPDX-License-Identifier: Elastic-2.0 -""" - -""" +""" """ import logging import io from og_agent.tokenizer import tokenize logger = logging.getLogger(__name__) + + def test_parse_explanation(): - arguments="""{"function_call":"execute", "arguments": {"explanation":"h""" + arguments = """{"function_call":"execute", "arguments": {"explanation":"h""" for token_state, token in tokenize(io.StringIO(arguments)): logger.info(f"token_state: {token_state}, token: {token}") - diff --git 
a/sdk/src/og_sdk/agent_sdk.py b/sdk/src/og_sdk/agent_sdk.py index 17595cc..23e68ba 100644 --- a/sdk/src/og_sdk/agent_sdk.py +++ b/sdk/src/og_sdk/agent_sdk.py @@ -186,7 +186,9 @@ async def prompt(self, prompt, api_key, files=[], context_id=None): metadata = aio.Metadata( ("api_key", api_key), ) - request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files, context_id=context_id) + request = agent_server_pb2.ProcessTaskRequest( + task=prompt, input_files=files, context_id=context_id + ) async for respond in self.stub.process_task(request, metadata=metadata): yield respond
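
A minimal usage sketch for the streaming prompt API reformatted in the last hunk: `AgentSDK.prompt` is an async generator that yields one response per step, exactly as `agent_api_server.py` consumes it above. The constructor signature, the `connect()` call, and the endpoint/key values below are assumptions (modeled on the `KernelSDK(endpoint, key)` pattern used in the tests), not something this patch defines; adjust them to the real SDK surface.

    import asyncio
    from og_sdk.agent_sdk import AgentSDK  # module touched in the last hunk

    async def main():
        # endpoint and api key are placeholders, not taken from this patch
        sdk = AgentSDK("127.0.0.1:9528", "your-api-key")
        sdk.connect()  # assumed helper, mirroring KernelSDK.connect() in the tests
        async for respond in sdk.prompt(
            "write a hello world in python",
            "your-api-key",
            files=[],
            context_id=None,
        ):
            # each yielded item is a TaskResponse-style step from the agent
            print(respond)

    asyncio.run(main())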