From d46737fe70b6ce332146a9eb322e76997c8fa8ba Mon Sep 17 00:00:00 2001
From: Shreemaan Abhishek
Date: Tue, 17 Sep 2024 10:08:58 +0545
Subject: [PATCH] feat: ai-proxy plugin (#11499)

---
 Makefile                                   |   6 +
 apisix/cli/config.lua                      |   1 +
 apisix/core/request.lua                    |  16 +
 apisix/plugins/ai-proxy.lua                | 138 ++++
 apisix/plugins/ai-proxy/drivers/openai.lua |  85 +++
 apisix/plugins/ai-proxy/schema.lua         | 154 +++++
 ci/common.sh                               |  21 +
 ci/linux_openresty_common_runner.sh        |   2 +
 ci/redhat-ci.sh                            |   2 +
 conf/config.yaml.example                   |   1 +
 docs/en/latest/config.json                 |   3 +-
 docs/en/latest/plugins/ai-proxy.md         | 144 +++++
 t/admin/plugins.t                          |   1 +
 t/assets/ai-proxy-response.json            |  15 +
 t/plugin/ai-proxy.t                        | 693 +++++++++++++++++++++
 t/plugin/ai-proxy2.t                       | 200 ++++++
 t/sse_server_example/go.mod                |   3 +
 t/sse_server_example/main.go               |  58 ++
 18 files changed, 1542 insertions(+), 1 deletion(-)
 create mode 100644 apisix/plugins/ai-proxy.lua
 create mode 100644 apisix/plugins/ai-proxy/drivers/openai.lua
 create mode 100644 apisix/plugins/ai-proxy/schema.lua
 create mode 100644 docs/en/latest/plugins/ai-proxy.md
 create mode 100644 t/assets/ai-proxy-response.json
 create mode 100644 t/plugin/ai-proxy.t
 create mode 100644 t/plugin/ai-proxy2.t
 create mode 100644 t/sse_server_example/go.mod
 create mode 100644 t/sse_server_example/main.go

diff --git a/Makefile b/Makefile
index 21a2389633b3..545a21e4f29f 100644
--- a/Makefile
+++ b/Makefile
@@ -374,6 +374,12 @@ install: runtime
 	$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/utils
 	$(ENV_INSTALL) apisix/utils/*.lua $(ENV_INST_LUADIR)/apisix/utils/
 
+	$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
+	$(ENV_INSTALL) apisix/plugins/ai-proxy/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
+
+	$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
+	$(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
+
 	$(ENV_INSTALL) bin/apisix $(ENV_INST_BINDIR)/apisix
 
diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua
index 6ab10c9256cd..f5c5d8dcaf94 100644
--- a/apisix/cli/config.lua
+++ b/apisix/cli/config.lua
@@ -219,6 +219,7 @@ local _M = {
             "proxy-rewrite",
             "workflow",
             "api-breaker",
+            "ai-proxy",
             "limit-conn",
             "limit-count",
             "limit-req",
diff --git a/apisix/core/request.lua b/apisix/core/request.lua
index c5278b6b8072..fef4bf17e3f7 100644
--- a/apisix/core/request.lua
+++ b/apisix/core/request.lua
@@ -21,6 +21,7 @@
 local lfs = require("lfs")
 local log = require("apisix.core.log")
+local json = require("apisix.core.json")
 local io = require("apisix.core.io")
 local req_add_header
 if ngx.config.subsystem == "http" then
@@ -334,6 +335,21 @@ function _M.get_body(max_size, ctx)
 end
 
 
+function _M.get_json_request_body_table()
+    local body, err = _M.get_body()
+    if not body then
+        return nil, { message = "could not get body: " .. (err or "request body is empty") }
+    end
+
+    local body_tab, err = json.decode(body)
+    if not body_tab then
+        return nil, { message = "could not parse JSON request body: " .. err }
+    end
+
+    return body_tab
+end
+
+
 function _M.get_scheme(ctx)
     if not ctx then
         ctx = ngx.ctx.api_ctx
diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua
new file mode 100644
index 000000000000..8a0d8fa970d4
--- /dev/null
+++ b/apisix/plugins/ai-proxy.lua
@@ -0,0 +1,138 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local core = require("apisix.core")
+local schema = require("apisix.plugins.ai-proxy.schema")
+local require = require
+local pcall = pcall
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local bad_request = ngx.HTTP_BAD_REQUEST
+local ngx_req = ngx.req
+local ngx_print = ngx.print
+local ngx_flush = ngx.flush
+
+local plugin_name = "ai-proxy"
+local _M = {
+    version = 0.5,
+    priority = 999,
+    name = plugin_name,
+    schema = schema,
+}
+
+
+function _M.check_schema(conf)
+    local ok = pcall(require, "apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
+    if not ok then
+        return false, "provider: " .. conf.model.provider .. " is not supported."
+    end
+    return core.schema.check(schema.plugin_schema, conf)
+end
+
+
+local CONTENT_TYPE_JSON = "application/json"
+
+
+local function keepalive_or_close(conf, httpc)
+    if conf.keepalive then
+        httpc:set_keepalive(conf.keepalive_timeout, conf.keepalive_pool)
+        return
+    end
+    httpc:close()
+end
+
+
+function _M.access(conf, ctx)
+    local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
+    if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
+        return bad_request, "unsupported content-type: " .. ct
+    end
+
+    local request_table, err = core.request.get_json_request_body_table()
+    if not request_table then
+        return bad_request, err
+    end
+
+    local ok, err = core.schema.check(schema.chat_request_schema, request_table)
+    if not ok then
+        return bad_request, "request format doesn't match schema: " .. err
+    end
+
+    if conf.model.name then
+        request_table.model = conf.model.name
+    end
+
+    if core.table.try_read_attr(conf, "model", "options", "stream") then
+        request_table.stream = true
+    end
+
+    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
+    local res, err, httpc = ai_driver.request(conf, request_table, ctx)
+    if not res then
+        core.log.error("failed to send request to LLM service: ", err)
+        return internal_server_error
+    end
+
+    local body_reader = res.body_reader
+    if not body_reader then
+        core.log.error("LLM sent no response body")
+        return internal_server_error
+    end
+
+    if conf.passthrough then
+        ngx_req.init_body()
+        while true do
+            local chunk, err = body_reader() -- will read chunk by chunk
+            if err then
+                core.log.error("failed to read response chunk: ", err)
+                break
+            end
+            if not chunk then
+                break
+            end
+            ngx_req.append_body(chunk)
+        end
+        ngx_req.finish_body()
+        keepalive_or_close(conf, httpc)
+        return
+    end
+
+    if request_table.stream then
+        while true do
+            local chunk, err = body_reader() -- will read chunk by chunk
+            if err then
+                core.log.error("failed to read response chunk: ", err)
+                break
+            end
+            if not chunk then
+                break
+            end
+            ngx_print(chunk)
+            ngx_flush(true)
+        end
+        keepalive_or_close(conf, httpc)
+        return
+    else
+        local res_body, err = res:read_body()
+        if not res_body then
+            core.log.error("failed to read response body: ", err)
+            return internal_server_error
+        end
+        keepalive_or_close(conf, httpc)
+        return res.status, res_body
+    end
+end
+
+return _M
diff --git a/apisix/plugins/ai-proxy/drivers/openai.lua b/apisix/plugins/ai-proxy/drivers/openai.lua
new file mode 100644
index 000000000000..c8f7f4b6223f
--- /dev/null
+++ b/apisix/plugins/ai-proxy/drivers/openai.lua
@@ -0,0 +1,85 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local _M = {}
+
+local core = require("apisix.core")
+local http = require("resty.http")
+local url = require("socket.url")
+
+local pairs = pairs
+
+-- globals
+local DEFAULT_HOST = "api.openai.com"
+local DEFAULT_PORT = 443
+local DEFAULT_PATH = "/v1/chat/completions"
+
+
+function _M.request(conf, request_table, ctx)
+    local httpc, err = http.new()
+    if not httpc then
+        return nil, "failed to create http client to send request to LLM server: " .. err
+    end
+    httpc:set_timeout(conf.timeout)
+
+    local endpoint = core.table.try_read_attr(conf, "override", "endpoint")
+    local parsed_url
+    if endpoint then
+        parsed_url = url.parse(endpoint)
+    end
+
+    local ok, err = httpc:connect({
+        scheme = parsed_url and parsed_url.scheme or "https",
+        host = parsed_url and parsed_url.host or DEFAULT_HOST,
+        port = parsed_url and parsed_url.port or DEFAULT_PORT,
+        ssl_verify = conf.ssl_verify,
+        ssl_server_name = parsed_url and parsed_url.host or DEFAULT_HOST,
+        pool_size = conf.keepalive and conf.keepalive_pool,
+    })
+
+    if not ok then
+        return nil, "failed to connect to LLM server: " .. err
+    end
+
+    local path = (parsed_url and parsed_url.path or DEFAULT_PATH)
+
+    local headers = core.table.clone(conf.auth.header or {}) -- don't mutate the shared conf table
+    headers["Content-Type"] = "application/json"
+    local params = {
+        method = "POST",
+        headers = headers,
+        keepalive = conf.keepalive,
+        ssl_verify = conf.ssl_verify,
+        path = path,
+        query = conf.auth.query
+    }
+
+    if conf.model.options then
+        for opt, val in pairs(conf.model.options) do
+            request_table[opt] = val
+        end
+    end
+    params.body = core.json.encode(request_table)
+
+    local res, err = httpc:request(params)
+    if not res then
+        return nil, err
+    end
+
+    return res, nil, httpc
+end
+
+return _M
diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua
new file mode 100644
index 000000000000..382644dc2147
--- /dev/null
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -0,0 +1,154 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local _M = {}
+
+local auth_item_schema = {
+    type = "object",
+    patternProperties = {
+        ["^[a-zA-Z0-9._-]+$"] = {
+            type = "string"
+        }
+    }
+}
+
+local auth_schema = {
+    type = "object",
+    properties = {
+        header = auth_item_schema,
+        query = auth_item_schema,
+    },
+    additionalProperties = false,
+}
+
+local model_options_schema = {
+    description = "Key/value settings for the model",
+    type = "object",
+    properties = {
+        max_tokens = {
+            type = "integer",
+            description = "Defines the max_tokens, if using chat or completion models.",
+            default = 256
+        },
+        input_cost = {
+            type = "number",
+            description = "Defines the cost per 1M tokens in your prompt.",
+            minimum = 0
+        },
+        output_cost = {
+            type = "number",
+            description = "Defines the cost per 1M tokens in the output of the AI.",
+            minimum = 0
+        },
+        temperature = {
+            type = "number",
+            description = "Defines the matching temperature, if using chat or completion models.",
+            minimum = 0.0,
+            maximum = 5.0,
+        },
+        top_p = {
+            type = "number",
+            description = "Defines the top-p probability mass, if supported.",
+            minimum = 0,
+            maximum = 1,
+        },
+        stream = {
+            description = "Stream response by SSE",
+            type = "boolean",
+            default = false,
+        }
+    }
+}
+
+local model_schema = {
+    type = "object",
+    properties = {
+        provider = {
+            type = "string",
+            description = "Name of the AI service provider.",
+            enum = { "openai" }, -- add more providers later
+        },
+        name = {
+            type = "string",
+            description = "Model name to execute.",
+        },
+        options = model_options_schema,
+    },
+    required = {"provider", "name"}
+}
+
+_M.plugin_schema = {
+    type = "object",
+    properties = {
+        auth = auth_schema,
+        model = model_schema,
+        override = {
+            type = "object",
+            properties = {
+                endpoint = {
+                    type = "string",
+                    description = "To be specified to override the endpoint of the AI provider",
+                },
+            }
+        },
+        passthrough = { type = "boolean", default = false },
+        timeout = {
+            type = "integer",
+            minimum = 1,
+            maximum = 60000,
+            default = 3000,
+            description = "timeout in milliseconds",
+        },
+        keepalive = {type = "boolean", default = true},
+        keepalive_timeout = {type = "integer", minimum = 1000, default = 60000},
+        keepalive_pool = {type = "integer", minimum = 1, default = 30},
+        ssl_verify = {type = "boolean", default = true},
+    },
+    required = {"model", "auth"}
+}
+
+_M.chat_request_schema = {
+    type = "object",
+    properties = {
+        messages = {
+            type = "array",
+            minItems = 1,
+            items = {
+                properties = {
+                    role = {
+                        type = "string",
+                        enum = {"system", "user", "assistant"}
+                    },
+                    content = {
+                        type = "string",
+                        minLength = 1,
+                    },
+                },
+                additionalProperties = false,
+                required = {"role", "content"},
+            },
+        }
+    },
+    required = {"messages"}
+}
+
+return _M
diff --git a/ci/common.sh b/ci/common.sh
index 146b7aa5080a..ae5d12b2b7c6 100644
--- a/ci/common.sh
+++ b/ci/common.sh
@@ -203,3 +203,24 @@ function start_grpc_server_example() {
         ss -lntp | grep 10051 | grep grpc_server && break
     done
 }
+
+
+function start_sse_server_example() {
+    # build sse_server_example
+    pushd t/sse_server_example
+    go build
+    ./sse_server_example 7737 2>&1 &
+
+    for (( i = 0; i <= 10; i++ )); do
+        sleep 0.5
+        SSE_PROC=`ps -ef | grep sse_server_example | grep -v grep || echo "none"`
+        if [[ $SSE_PROC != "none" ]]; then
+            break
+        elif [[ "$i" -eq 10 ]]; then
+            echo "failed to start sse_server_example"
+            ss -antp | grep 7737 || echo "no proc listen port 7737"
+            exit 1
+        fi
+    done
+    popd
+}
diff --git a/ci/linux_openresty_common_runner.sh b/ci/linux_openresty_common_runner.sh
index ea2e8b41c8bb..1b73ceec92c6 100755
--- a/ci/linux_openresty_common_runner.sh
+++ b/ci/linux_openresty_common_runner.sh
@@ -77,6 +77,8 @@ script() {
 
     start_grpc_server_example
 
+    start_sse_server_example
+
     # APISIX_ENABLE_LUACOV=1 PERL5LIB=.:$PERL5LIB prove -Itest-nginx/lib -r t
     FLUSH_ETCD=1 TEST_EVENTS_MODULE=$TEST_EVENTS_MODULE prove --timer -Itest-nginx/lib -I./ -r $TEST_FILE_SUB_DIR | tee /tmp/test.result
     rerun_flaky_tests /tmp/test.result
diff --git a/ci/redhat-ci.sh b/ci/redhat-ci.sh
index 3cad10b5992b..da9839d4e699 100755
--- a/ci/redhat-ci.sh
+++ b/ci/redhat-ci.sh
@@ -77,6 +77,8 @@ install_dependencies() {
     yum install -y iproute procps
     start_grpc_server_example
 
+    start_sse_server_example
+
     # installing grpcurl
     install_grpcurl
diff --git a/conf/config.yaml.example b/conf/config.yaml.example
index da125f77daa2..bd741b2f767b 100644
--- a/conf/config.yaml.example
+++ b/conf/config.yaml.example
@@ -486,6 +486,7 @@ plugins: # plugin list (sorted by priority)
   - limit-count                    # priority: 1002
   - limit-req                      # priority: 1001
   #- node-status                   # priority: 1000
+  - ai-proxy                       # priority: 999
   #- brotli                        # priority: 996
   - gzip                           # priority: 995
   - server-info                    # priority: 990
diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json
index 2195688a365c..ad9c1e051523 100644
--- a/docs/en/latest/config.json
+++ b/docs/en/latest/config.json
@@ -96,7 +96,8 @@
         "plugins/fault-injection",
         "plugins/mocking",
         "plugins/degraphql",
-        "plugins/body-transformer"
+        "plugins/body-transformer",
+        "plugins/ai-proxy"
       ]
     },
     {
diff --git a/docs/en/latest/plugins/ai-proxy.md b/docs/en/latest/plugins/ai-proxy.md
new file mode 100644
index 000000000000..a6a4e35426eb
--- /dev/null
+++ b/docs/en/latest/plugins/ai-proxy.md
@@ -0,0 +1,144 @@
+---
+title: ai-proxy
+keywords:
+  - Apache APISIX
+  - API Gateway
+  - Plugin
+  - ai-proxy
+description: This document contains information about the Apache APISIX ai-proxy Plugin.
+--- + + + +## Description + +The `ai-proxy` plugin simplifies access to LLM providers and models by defining a standard request format +that allows key fields in plugin configuration to be embedded into the request. + +Proxying requests to OpenAI is supported now. Other LLM services will be supported soon. + +## Request Format + +### OpenAI + +- Chat API + +| Name | Type | Required | Description | +| ------------------ | ------ | -------- | --------------------------------------------------- | +| `messages` | Array | Yes | An array of message objects | +| `messages.role` | String | Yes | Role of the message (`system`, `user`, `assistant`) | +| `messages.content` | String | Yes | Content of the message | + +## Plugin Attributes + +| **Field** | **Required** | **Type** | **Description** | +| ------------------------- | ------------ | -------- | ------------------------------------------------------------------------------------ | +| auth | Yes | Object | Authentication configuration | +| auth.header | No | Object | Authentication headers. Key must match pattern `^[a-zA-Z0-9._-]+$`. | +| auth.query | No | Object | Authentication query parameters. Key must match pattern `^[a-zA-Z0-9._-]+$`. | +| model.provider | Yes | String | Name of the AI service provider (`openai`). | +| model.name | Yes | String | Model name to execute. | +| model.options | No | Object | Key/value settings for the model | +| model.options.max_tokens | No | Integer | Defines the max tokens if using chat or completion models. Default: 256 | +| model.options.input_cost | No | Number | Cost per 1M tokens in your prompt. Minimum: 0 | +| model.options.output_cost | No | Number | Cost per 1M tokens in the output of the AI. Minimum: 0 | +| model.options.temperature | No | Number | Matching temperature for models. Range: 0.0 - 5.0 | +| model.options.top_p | No | Number | Top-p probability mass. Range: 0 - 1 | +| model.options.stream | No | Boolean | Stream response by SSE. Default: false | +| model.override.endpoint | No | String | Override the endpoint of the AI provider | +| passthrough | No | Boolean | If enabled, the response from LLM will be sent to the upstream. Default: false | +| timeout | No | Integer | Timeout in milliseconds for requests to LLM. Range: 1 - 60000. Default: 3000 | +| keepalive | No | Boolean | Enable keepalive for requests to LLM. Default: true | +| keepalive_timeout | No | Integer | Keepalive timeout in milliseconds for requests to LLM. Minimum: 1000. Default: 60000 | +| keepalive_pool | No | Integer | Keepalive pool size for requests to LLM. Minimum: 1. Default: 30 | +| ssl_verify | No | Boolean | SSL verification for requests to LLM. Default: true | + +## Example usage + +Create a route with the `ai-proxy` plugin like so: + +```shell +curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \ + -H "X-API-KEY: ${ADMIN_API_KEY}" \ + -d '{ + "uri": "/anything", + "plugins": { + "ai-proxy": { + "auth": { + "header": { + "Authorization": "Bearer " + } + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + } + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "somerandom.com:443": 1 + }, + "scheme": "https", + "pass_host": "node" + } + }' +``` + +Since `passthrough` is not enabled upstream node can be any arbitrary value because it won't be contacted. 
+
+Now send a request:
+
+```shell
+curl http://127.0.0.1:9080/anything -i -XPOST -H 'Content-Type: application/json' -d '{
+    "messages": [
+      { "role": "system", "content": "You are a mathematician" },
+      { "role": "user", "content": "What is 1+1?" }
+    ]
+  }'
+```
+
+You will receive a response like this:
+
+```json
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "message": {
+        "content": "The sum of \\(1 + 1\\) is \\(2\\).",
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1723777034,
+  "id": "chatcmpl-9whRKFodKl5sGhOgHIjWltdeB8sr7",
+  "model": "gpt-4o-2024-05-13",
+  "object": "chat.completion",
+  "system_fingerprint": "fp_abc28019ad",
+  "usage": { "completion_tokens": 15, "prompt_tokens": 23, "total_tokens": 38 }
+}
+```
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index ef43ea9f3965..bf3d485e8b31 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -102,6 +102,7 @@ api-breaker
 limit-conn
 limit-count
 limit-req
+ai-proxy
 gzip
 server-info
 traffic-split
diff --git a/t/assets/ai-proxy-response.json b/t/assets/ai-proxy-response.json
new file mode 100644
index 000000000000..94665e5eaea9
--- /dev/null
+++ b/t/assets/ai-proxy-response.json
@@ -0,0 +1,15 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "message": { "content": "1 + 1 = 2.", "role": "assistant" }
+    }
+  ],
+  "created": 1723780938,
+  "id": "chatcmpl-9wiSIg5LYrrpxwsr2PubSQnbtod1P",
+  "model": "gpt-4o-2024-05-13",
+  "object": "chat.completion",
+  "system_fingerprint": "fp_abc28019ad",
+  "usage": { "completion_tokens": 8, "prompt_tokens": 23, "total_tokens": 31 }
+}
diff --git a/t/plugin/ai-proxy.t b/t/plugin/ai-proxy.t
new file mode 100644
index 000000000000..445e406f60ab
--- /dev/null
+++ b/t/plugin/ai-proxy.t
@@ -0,0 +1,693 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /anything { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body = ngx.req.get_body_data() + + if body ~= "SELECT * FROM STUDENTS" then + ngx.status = 503 + ngx.say("passthrough doesn't work") + return + end + ngx.say('{"foo", "bar"}') + } + } + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local test_type = ngx.req.get_headers()["test-type"] + if test_type == "options" then + if body.foo == "bar" then + ngx.status = 200 + ngx.say("options works") + else + ngx.status = 500 + ngx.say("model options feature doesn't work") + end + return + end + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + if body.messages[1].content == "write an SQL query to get all rows from student table" then + ngx.print("SELECT * FROM STUDENTS") + return + end + + ngx.status = 200 + ngx.say([[$resp]]) + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + + location /random { + content_by_lua_block { + ngx.say("path override works") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: minimal viable configuration +--- config + location /t { + content_by_lua_block { + local plugin = require("apisix.plugins.ai-proxy") + local ok, err = plugin.check_schema({ + model = { + provider = "openai", + name = "gpt-4", + }, + auth = { + header = { + some_header = "some_value" + } + } + }) + + if not ok then + ngx.say(err) + else + ngx.say("passed") + end + } + } +--- response_body +passed + + + +=== TEST 2: unsupported provider +--- config + location /t { + content_by_lua_block { + local plugin = require("apisix.plugins.ai-proxy") + local ok, err = plugin.check_schema({ + model = { + provider = "some-unique", + name = "gpt-4", + }, + auth = { + header = { + some_header = "some_value" + } + } + }) + + if not ok then + ngx.say(err) + else + ngx.say("passed") + end + } + } +--- response_body eval +qr/.*provider: some-unique is not supported.*/ + + + +=== TEST 3: set route with wrong 
auth header
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy": {
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer wrongtoken"
+                                }
+                            },
+                            "model": {
+                                "provider": "openai",
+                                "name": "gpt-35-turbo-instruct",
+                                "options": {
+                                    "max_tokens": 512,
+                                    "temperature": 1.0
+                                }
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:6724"
+                            },
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+--- response_body
+Unauthorized
+
+
+
+=== TEST 5: set route with right auth header
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy": {
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "model": {
+                                "provider": "openai",
+                                "name": "gpt-35-turbo-instruct",
+                                "options": {
+                                    "max_tokens": 512,
+                                    "temperature": 1.0
+                                }
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:6724"
+                            },
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 6: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 7: send request with empty body
+--- request
+POST /anything
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body
+{"message":"could not get body: request body is empty"}
+
+
+
+=== TEST 8: send request with wrong method (GET) should work
+--- request
+GET /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 9: wrong JSON in request body should give error
+--- request
+GET /anything
+{}"messages": [ { "role": "system", "cont
+--- error_code: 400
+--- response_body
+{"message":"could not parse JSON request body: Expected the end but found T_STRING at character 3"}
+
+
+
+=== TEST 10: content-type should be JSON
+--- request
+POST /anything
+prompt%3Dwhat%2520is%25201%2520%252B%25201
+--- more_headers
+Content-Type: application/x-www-form-urlencoded
+--- error_code: 400
+--- response_body chomp
+unsupported content-type: application/x-www-form-urlencoded
+
+
+
+=== TEST 11: request schema validity check
+--- request
+POST /anything
+{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body chomp
+request format doesn't match schema: property "messages" is required
+
+
+
+=== TEST 12: model options being merged to request body
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy": {
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "model": {
+                                "provider": "openai",
+                                "name": "some-model",
+                                "options": {
+                                    "foo": "bar",
+                                    "temperature": 1.0
+                                }
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:6724"
+                            },
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+                ngx.say(body)
+                return
+            end
+
+            local code, body, actual_body = t("/anything",
+                ngx.HTTP_POST,
+                [[{
+                    "messages": [
+                        { "role": "system", "content": "You are a mathematician" },
+                        { "role": "user", "content": "What is 1+1?" }
+                    ]
+                }]],
+                nil,
+                {
+                    ["test-type"] = "options",
+                    ["Content-Type"] = "application/json",
+                }
+            )
+
+            ngx.status = code
+            ngx.say(actual_body)
+
+        }
+    }
+--- error_code: 200
+--- response_body_like eval
+qr/options works/
+
+
+
+=== TEST 13: override path
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy": {
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "model": {
+                                "provider": "openai",
+                                "name": "some-model",
+                                "options": {
+                                    "foo": "bar",
+                                    "temperature": 1.0
+                                }
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:6724/random"
+                            },
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+                ngx.say(body)
+                return
+            end
+
+            local code, body, actual_body = t("/anything",
+                ngx.HTTP_POST,
+                [[{
+                    "messages": [
+                        { "role": "system", "content": "You are a mathematician" },
+                        { "role": "user", "content": "What is 1+1?" }
+                    ]
+                }]],
+                nil,
+                {
+                    ["test-type"] = "path",
+                    ["Content-Type"] = "application/json",
+                }
+            )
+
+            ngx.status = code
+            ngx.say(actual_body)
+
+        }
+    }
+--- response_body_like eval
+qr/path override works/
+
+
+
+=== TEST 14: set route with passthrough enabled
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy": {
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "model": {
+                                "provider": "openai",
+                                "name": "gpt-35-turbo-instruct",
+                                "options": {
+                                    "max_tokens": 512,
+                                    "temperature": 1.0
+                                }
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:6724"
+                            },
+                            "ssl_verify": false,
+                            "passthrough": true
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "127.0.0.1:6724": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 15: send request (passthrough)
+--- request
+POST /anything
+{ "messages": [ { "role": "user", "content": "write an SQL query to get all rows from student table" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body
+{"foo", "bar"}
+
+
+
+=== TEST 16: set route with stream = true (SSE)
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy": {
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "model": {
+                                "provider": "openai",
+                                "name": "gpt-35-turbo-instruct",
+                                "options": {
+                                    "max_tokens": 512,
+                                    "temperature": 1.0,
+                                    "stream": true
+                                }
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:7737"
+                            },
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 17: test SSE works as expected
+--- config
+    location /t {
+        content_by_lua_block {
+            local http = require("resty.http")
+            local httpc = http.new()
+            local core = require("apisix.core")
+
+            local ok, err = httpc:connect({
+                scheme = "http",
+                host = "localhost",
+                port = ngx.var.server_port,
+            })
+
+            if not ok then
+                ngx.status = 500
+                ngx.say(err)
+                return
+            end
+
+            local params = {
+                method = "POST",
+                headers = {
+                    ["Content-Type"] = "application/json",
+                },
+                path = "/anything",
+                body = [[{
+                    "messages": [
+                        { "role": "system", "content": "some content" }
+                    ]
+                }]],
+            }
+
+            local res, err = httpc:request(params)
+            if not res then
+                ngx.status = 500
+                ngx.say(err)
+                return
+            end
+
+            local final_res = {}
+            while true do
+                local chunk, err = res.body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                core.table.insert_tail(final_res, chunk)
+            end
+
+            ngx.print(#final_res .. final_res[6])
+        }
+    }
+--- response_body_like eval
+qr/6data: \[DONE\]\n\n/
diff --git a/t/plugin/ai-proxy2.t b/t/plugin/ai-proxy2.t
new file mode 100644
index 000000000000..6e398e5665a4
--- /dev/null
+++ b/t/plugin/ai-proxy2.t
@@ -0,0 +1,200 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local query_auth = ngx.req.get_uri_args()["api_key"] + + if query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + + ngx.status = 200 + ngx.say("passed") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: set route with wrong query param +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy": { + "auth": { + "query": { + "api_key": "wrong_key" + } + }, + "model": { + "provider": "openai", + "name": "gpt-35-turbo-instruct", + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + }, + "override": { + "endpoint": "http://localhost:6724" + }, + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 401 +--- response_body +Unauthorized + + + +=== TEST 3: set route with right query param +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy": { + "auth": { + "query": { + "api_key": "apikey" + } + }, + "model": { + "provider": "openai", + "name": "gpt-35-turbo-instruct", + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + }, + "override": { + "endpoint": "http://localhost:6724" + }, + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body 
passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 200
+--- response_body
+passed
diff --git a/t/sse_server_example/go.mod b/t/sse_server_example/go.mod
new file mode 100644
index 000000000000..9cc909d0338e
--- /dev/null
+++ b/t/sse_server_example/go.mod
@@ -0,0 +1,3 @@
+module foo.bar/apache/sse_server_example
+
+go 1.17
diff --git a/t/sse_server_example/main.go b/t/sse_server_example/main.go
new file mode 100644
index 000000000000..ab976c86094a
--- /dev/null
+++ b/t/sse_server_example/main.go
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package main
+
+import (
+	"fmt"
+	"log"
+	"net/http"
+	"os"
+	"time"
+)
+
+func sseHandler(w http.ResponseWriter, r *http.Request) {
+	// Set the headers for SSE
+	w.Header().Set("Content-Type", "text/event-stream")
+	w.Header().Set("Cache-Control", "no-cache")
+	w.Header().Set("Connection", "keep-alive")
+
+	f, ok := w.(http.Flusher)
+	if !ok {
+		fmt.Fprintf(w, "[ERROR]")
+		return
+	}
+	// A simple loop that sends a message every 500ms
+	for i := 0; i < 5; i++ {
+		// Create a message to send to the client
+		fmt.Fprintf(w, "data: %s\n\n", time.Now().Format(time.RFC3339))
+
+		// Flush the data immediately to the client
+		f.Flush()
+		time.Sleep(500 * time.Millisecond)
+	}
+	fmt.Fprintf(w, "data: %s\n\n", "[DONE]")
+}
+
+func main() {
+	// Create a simple route
+	http.HandleFunc("/v1/chat/completions", sseHandler)
+	port := os.Args[1]
+	// Start the server
+	log.Println("Starting server on :", port)
+	log.Fatal(http.ListenAndServe(":"+port, nil))
+}