diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua index 94843621a74b..7f15542b1d7e 100644 --- a/apisix/cli/config.lua +++ b/apisix/cli/config.lua @@ -213,6 +213,7 @@ local _M = { "authz-keycloak", "proxy-cache", "body-transformer", + "ai-prompt-template", "proxy-mirror", "proxy-rewrite", "workflow", diff --git a/apisix/plugins/ai-prompt-template.lua b/apisix/plugins/ai-prompt-template.lua new file mode 100644 index 000000000000..0a092c3f77c0 --- /dev/null +++ b/apisix/plugins/ai-prompt-template.lua @@ -0,0 +1,146 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +local core = require("apisix.core") +local body_transformer = require("apisix.plugins.body-transformer") +local ipairs = ipairs + +local prompt_schema = { + properties = { + role = { + type = "string", + enum = { "system", "user", "assistant" } + }, + content = { + type = "string", + minLength = 1, + } + }, + required = { "role", "content" } +} + +local prompts = { + type = "array", + minItems = 1, + items = prompt_schema +} + +local schema = { + type = "object", + properties = { + templates = { + type = "array", + minItems = 1, + items = { + type = "object", + properties = { + name = { + type = "string", + minLength = 1, + }, + template = { + type = "object", + properties = { + model = { + type = "string", + minLength = 1, + }, + messages = prompts + } + } + }, + required = {"name", "template"} + } + }, + }, + required = {"templates"}, +} + + +local _M = { + version = 0.1, + priority = 1060, + name = "ai-prompt-template", + schema = schema, +} + +local templates_lrucache = core.lrucache.new({ + ttl = 300, count = 256 +}) + +local templates_json_lrucache = core.lrucache.new({ + ttl = 300, count = 256 +}) + +function _M.check_schema(conf) + return core.schema.check(schema, conf) +end + + +local function get_request_body_table() + local body, err = core.request.get_body() + if not body then + return nil, { message = "could not get body: " .. err } + end + + local body_tab, err = core.json.decode(body) + if not body_tab then + return nil, { message = "could not get parse JSON request body: ", err } + end + + return body_tab +end + + +local function find_template(conf, template_name) + for _, template in ipairs(conf.templates) do + if template.name == template_name then + return template.template + end + end + return nil +end + +function _M.rewrite(conf, ctx) + local body_tab, err = get_request_body_table() + if not body_tab then + return 400, err + end + local template_name = body_tab.template_name + if not template_name then + return 400, { message = "template name is missing in request." } + end + + local template = templates_lrucache(template_name, conf, find_template, conf, template_name) + if not template then + return 400, { message = "template: " .. template_name .. " not configured." } + end + + local template_json = templates_json_lrucache(template, template, core.json.encode, template) + core.log.info("sending template to body_transformer: ", template_json) + return body_transformer.rewrite( + { + request = { + template = template_json, + input_format = "json" + } + }, + ctx + ) +end + + +return _M diff --git a/conf/config.yaml.example b/conf/config.yaml.example index 5a490a4bb4b3..5d22418caeb5 100644 --- a/conf/config.yaml.example +++ b/conf/config.yaml.example @@ -476,6 +476,7 @@ plugins: # plugin list (sorted by priority) #- error-log-logger # priority: 1091 - proxy-cache # priority: 1085 - body-transformer # priority: 1080 + - ai-prompt-template # priority: 1060 - proxy-mirror # priority: 1010 - proxy-rewrite # priority: 1008 - workflow # priority: 1006 diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json index 928aec3b2b62..0998ec730cd1 100644 --- a/docs/en/latest/config.json +++ b/docs/en/latest/config.json @@ -91,6 +91,7 @@ "plugins/proxy-rewrite", "plugins/grpc-transcode", "plugins/grpc-web", + "plugins/ai-prompt-template", "plugins/fault-injection", "plugins/mocking", "plugins/degraphql", diff --git a/docs/en/latest/plugins/ai-prompt-template.md b/docs/en/latest/plugins/ai-prompt-template.md new file mode 100644 index 000000000000..9ca4e1f70aa9 --- /dev/null +++ b/docs/en/latest/plugins/ai-prompt-template.md @@ -0,0 +1,102 @@ +--- +title: ai-prompt-template +keywords: + - Apache APISIX + - API Gateway + - Plugin + - ai-prompt-template +description: This document contains information about the Apache APISIX ai-prompt-template Plugin. +--- + + + +## Description + +The `ai-prompt-template` plugin simplifies access to LLM providers, such as OpenAI and Anthropic, and their models by predefining the request format +using a template, which only allows users to pass customized values into template variables. + +## Plugin Attributes + +| **Field** | **Required** | **Type** | **Description** | +| ------------------------------------- | ------------ | -------- | --------------------------------------------------------------------------------------------------------------------------- | +| `templates` | Yes | Array | An array of template objects | +| `templates.name` | Yes | String | Name of the template. | +| `templates.template.model` | Yes | String | Model of the AI Model, for example `gpt-4` or `gpt-3.5`. See your LLM provider API documentation for more available models. | +| `templates.template.messages.role` | Yes | String | Role of the message (`system`, `user`, `assistant`) | +| `templates.template.messages.content` | Yes | String | Content of the message. | + +## Example usage + +Create a route with the `ai-prompt-template` plugin like so: + +```shell +curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \ + -H "X-API-KEY: ${ADMIN_API_KEY}" \ + -d '{ + "uri": "/v1/chat/completions", + "upstream": { + "type": "roundrobin", + "nodes": { + "api.openai.com:443": 1 + }, + "scheme": "https", + "pass_host": "node" + }, + "plugins": { + "ai-prompt-template": { + "templates": [ + { + "name": "level of detail", + "template": { + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Explain about {{ topic }} in {{ level }}." + } + ] + } + } + ] + } + } + }' +``` + +Now send a request: + +```shell +curl http://127.0.0.1:9080/v1/chat/completions -i -XPOST -H 'Content-Type: application/json' -d '{ + "template_name": "level of detail", + "topic": "psychology", + "level": "brief" +}' -H "Authorization: Bearer " +``` + +Then the request body will be modified to something like this: + +```json +{ + "model": "some model", + "messages": [ + { "role": "user", "content": "Explain about psychology in brief." } + ] +} +``` diff --git a/t/admin/plugins.t b/t/admin/plugins.t index 911205f48cb4..547b1a316d56 100644 --- a/t/admin/plugins.t +++ b/t/admin/plugins.t @@ -93,6 +93,7 @@ opa authz-keycloak proxy-cache body-transformer +ai-prompt-template proxy-mirror proxy-rewrite workflow diff --git a/t/plugin/ai-prompt-template.t b/t/plugin/ai-prompt-template.t new file mode 100644 index 000000000000..050e0f246268 --- /dev/null +++ b/t/plugin/ai-prompt-template.t @@ -0,0 +1,403 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +use t::APISIX 'no_plan'; + +repeat_each(1); +log_level('info'); +no_root_location(); +no_shuffle(); + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!$block->request) { + $block->set_value("request", "GET /t"); + } + +}); + +run_tests(); + +__DATA__ + +=== TEST 1: sanity +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + }, + "plugins": { + "ai-prompt-template": { + "templates":[ + { + "name": "programming question", + "template": { + "model": "some model", + "messages": [ + { "role": "system", "content": "You are a {{ language }} programmer." }, + { "role": "user", "content": "Write a {{ program_name }} program." } + ] + } + }, + { + "name": "level of detail", + "template": { + "model": "some model", + "messages": [ + { "role": "user", "content": "Explain about {{ topic }} in {{ level }}." } + ] + } + } + ] + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } +} +--- response_body +passed + + + +=== TEST 2: no templates +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + }, + "plugins": { + "ai-prompt-template": { + "templates":[] + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } +} +--- error_code: 400 +--- response_body eval +qr/.*property \\"templates\\" validation failed: expect array to have at least 1 items.*/ + + + +=== TEST 3: test template insertion +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local json = require("apisix.core.json") + local code, body, actual_resp = t('/echo', + ngx.HTTP_POST, + [[{ + "template_name": "programming question", + "language": "python", + "program_name": "quick sort" + }]], + [[{ + "model": "some model", + "messages": [ + { "role": "system", "content": "You are a python programmer." }, + { "role": "user", "content": "Write a quick sort program." } + ] + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.say("passed") + } + } +--- response_body +passed + + + +=== TEST 4: multiple templates +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + }, + "plugins": { + "ai-prompt-template": { + "templates":[ + { + "name": "programming question", + "template": { + "model": "some model", + "messages": [ + { "role": "system", "content": "You are a {{ language }} programmer." }, + { "role": "user", "content": "Write a {{ program_name }} program." } + ] + } + }, + { + "name": "level of detail", + "template": { + "model": "some model", + "messages": [ + { "role": "user", "content": "Explain about {{ topic }} in {{ level }}." } + ] + } + } + ] + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } +} +--- response_body +passed + + + +=== TEST 5: test second template +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local json = require("apisix.core.json") + local code, body, actual_resp = t('/echo', + ngx.HTTP_POST, + [[{ + "template_name": "level of detail", + "topic": "psychology", + "level": "brief" + }]], + [[{ + "model": "some model", + "messages": [ + { "role": "user", "content": "Explain about psychology in brief." } + ] + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.say("passed") + } + } +--- response_body +passed + + + +=== TEST 6: missing template items +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local json = require("apisix.core.json") + local code, body, actual_resp = t('/echo', + ngx.HTTP_POST, + [[{ + "template_name": "level of detail", + "topic-missing": "psychology", + "level-missing": "brief" + }]], + [[{ + "model": "some model", + "messages": [ + { "role": "user", "content": "Explain about in ." } + ] + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.say("passed") + } + } +--- response_body +passed + + + +=== TEST 7: request body contains non-existent template +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local json = require("apisix.core.json") + local code, body, actual_resp = t('/echo', + ngx.HTTP_POST, + [[{ + "template_name": "random", + "some-key": "some-value" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.say("passed") + } + } +--- error_code: 400 +--- response_body eval +qr/.*template: random not configured.*/ + + + +=== TEST 8: request body contains non-existent template +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local json = require("apisix.core.json") + local code, body, actual_resp = t('/echo', + ngx.HTTP_POST, + [[{ + "missing-template-name": "haha" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.say("passed") + } + } +--- error_code: 400 +--- response_body eval +qr/.*template name is missing in request.*/ + + + +=== TEST 9: (cache test) same template name in different routes +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + for i = 1, 5, 1 do + local code = t('/apisix/admin/routes/' .. i, + ngx.HTTP_PUT, + [[{ + "uri": "/]] .. i .. [[", + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + }, + "plugins": { + "ai-prompt-template": { + "templates":[ + { + "name": "same name", + "template": { + "model": "some model", + "messages": [ + { "role": "system", "content": "Field: {{ field }} in route]] .. i .. [[." } + ] + } + } + ] + }, + "proxy-rewrite": { + "uri": "/echo" + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + ngx.say("failed") + return + end + end + + for i = 1, 5, 1 do + local code, body = t('/' .. i, + ngx.HTTP_POST, + [[{ + "template_name": "same name", + "field": "foo" + }]], + [[{ + "model": "some model", + "messages": [ + { "role": "system", "content": "Field: foo in route]] .. i .. [[." } + ] + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + end + ngx.status = 200 + ngx.say("passed") + } + } + +--- response_body +passed