Skip to content

Commit

Permalink
feat: ai-proxy plugin (#11499)
Browse files Browse the repository at this point in the history
  • Loading branch information
shreemaan-abhishek authored Sep 17, 2024
1 parent b5ea128 commit d46737f
Show file tree
Hide file tree
Showing 18 changed files with 1,542 additions and 1 deletion.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,12 @@ install: runtime
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/utils
$(ENV_INSTALL) apisix/utils/*.lua $(ENV_INST_LUADIR)/apisix/utils/

$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
$(ENV_INSTALL) apisix/plugins/ai-proxy/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy

$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
$(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers

$(ENV_INSTALL) bin/apisix $(ENV_INST_BINDIR)/apisix


Expand Down
1 change: 1 addition & 0 deletions apisix/cli/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ local _M = {
"proxy-rewrite",
"workflow",
"api-breaker",
"ai-proxy",
"limit-conn",
"limit-count",
"limit-req",
Expand Down
16 changes: 16 additions & 0 deletions apisix/core/request.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

local lfs = require("lfs")
local log = require("apisix.core.log")
local json = require("apisix.core.json")
local io = require("apisix.core.io")
local req_add_header
if ngx.config.subsystem == "http" then
Expand Down Expand Up @@ -334,6 +335,21 @@ function _M.get_body(max_size, ctx)
end


function _M.get_json_request_body_table()
local body, err = _M.get_body()
if not body then
return nil, { message = "could not get body: " .. (err or "request body is empty") }
end

local body_tab, err = json.decode(body)
if not body_tab then
return nil, { message = "could not get parse JSON request body: " .. err }
end

return body_tab
end


function _M.get_scheme(ctx)
if not ctx then
ctx = ngx.ctx.api_ctx
Expand Down
138 changes: 138 additions & 0 deletions apisix/plugins/ai-proxy.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local schema = require("apisix.plugins.ai-proxy.schema")
local require = require
local pcall = pcall
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
local bad_request = ngx.HTTP_BAD_REQUEST
local ngx_req = ngx.req
local ngx_print = ngx.print
local ngx_flush = ngx.flush

local plugin_name = "ai-proxy"
local _M = {
version = 0.5,
priority = 999,
name = plugin_name,
schema = schema,
}


function _M.check_schema(conf)
local ai_driver = pcall(require, "apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
if not ai_driver then
return false, "provider: " .. conf.model.provider .. " is not supported."
end
return core.schema.check(schema.plugin_schema, conf)
end


local CONTENT_TYPE_JSON = "application/json"


local function keepalive_or_close(conf, httpc)
if conf.set_keepalive then
httpc:set_keepalive(10000, 100)
return
end
httpc:close()
end


function _M.access(conf, ctx)
local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
return bad_request, "unsupported content-type: " .. ct
end

local request_table, err = core.request.get_json_request_body_table()
if not request_table then
return bad_request, err
end

local ok, err = core.schema.check(schema.chat_request_schema, request_table)
if not ok then
return bad_request, "request format doesn't match schema: " .. err
end

if conf.model.name then
request_table.model = conf.model.name
end

if core.table.try_read_attr(conf, "model", "options", "stream") then
request_table.stream = true
end

local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
local res, err, httpc = ai_driver.request(conf, request_table, ctx)
if not res then
core.log.error("failed to send request to LLM service: ", err)
return internal_server_error
end

local body_reader = res.body_reader
if not body_reader then
core.log.error("LLM sent no response body")
return internal_server_error
end

if conf.passthrough then
ngx_req.init_body()
while true do
local chunk, err = body_reader() -- will read chunk by chunk
if err then
core.log.error("failed to read response chunk: ", err)
break
end
if not chunk then
break
end
ngx_req.append_body(chunk)
end
ngx_req.finish_body()
keepalive_or_close(conf, httpc)
return
end

if request_table.stream then
while true do
local chunk, err = body_reader() -- will read chunk by chunk
if err then
core.log.error("failed to read response chunk: ", err)
break
end
if not chunk then
break
end
ngx_print(chunk)
ngx_flush(true)
end
keepalive_or_close(conf, httpc)
return
else
local res_body, err = res:read_body()
if not res_body then
core.log.error("failed to read response body: ", err)
return internal_server_error
end
keepalive_or_close(conf, httpc)
return res.status, res_body
end
end

return _M
85 changes: 85 additions & 0 deletions apisix/plugins/ai-proxy/drivers/openai.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local _M = {}

local core = require("apisix.core")
local http = require("resty.http")
local url = require("socket.url")

local pairs = pairs

-- globals
local DEFAULT_HOST = "api.openai.com"
local DEFAULT_PORT = 443
local DEFAULT_PATH = "/v1/chat/completions"


function _M.request(conf, request_table, ctx)
local httpc, err = http.new()
if not httpc then
return nil, "failed to create http client to send request to LLM server: " .. err
end
httpc:set_timeout(conf.timeout)

local endpoint = core.table.try_read_attr(conf, "override", "endpoint")
local parsed_url
if endpoint then
parsed_url = url.parse(endpoint)
end

local ok, err = httpc:connect({
scheme = parsed_url.scheme or "https",
host = parsed_url.host or DEFAULT_HOST,
port = parsed_url.port or DEFAULT_PORT,
ssl_verify = conf.ssl_verify,
ssl_server_name = parsed_url.host or DEFAULT_HOST,
pool_size = conf.keepalive and conf.keepalive_pool,
})

if not ok then
return nil, "failed to connect to LLM server: " .. err
end

local path = (parsed_url.path or DEFAULT_PATH)

local headers = (conf.auth.header or {})
headers["Content-Type"] = "application/json"
local params = {
method = "POST",
headers = headers,
keepalive = conf.keepalive,
ssl_verify = conf.ssl_verify,
path = path,
query = conf.auth.query
}

if conf.model.options then
for opt, val in pairs(conf.model.options) do
request_table[opt] = val
end
end
params.body = core.json.encode(request_table)

local res, err = httpc:request(params)
if not res then
return nil, err
end

return res, nil, httpc
end

return _M
Loading

0 comments on commit d46737f

Please sign in to comment.