diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index cf57b5a..e6625ba 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -24,19 +24,23 @@ namespace duckdb { struct OpenPromptData: FunctionData { idx_t model_idx; idx_t json_schema_idx; + idx_t json_system_prompt_idx; unique_ptr Copy() const { auto res = make_uniq(); res->model_idx = model_idx; res->json_schema_idx = json_schema_idx; + res->json_system_prompt_idx = json_system_prompt_idx; return res; }; bool Equals(const FunctionData &other) const { return model_idx == other.Cast().model_idx && - json_schema_idx == other.Cast().json_schema_idx; + json_schema_idx == other.Cast().json_schema_idx && + json_system_prompt_idx==other.Cast().json_system_prompt_idx; }; OpenPromptData() { model_idx = 0; json_schema_idx = 0; + json_system_prompt_idx = 0; } }; @@ -49,6 +53,8 @@ namespace duckdb { res->model_idx = i; } else if (argument->alias == "json_schema") { res->json_schema_idx = i; + } else if (argument->alias == "system_prompt") { + res->json_system_prompt_idx = i; } } return std::move(res); @@ -182,6 +188,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); std::string json_schema; + std::string system_prompt; if (info.model_idx != 0) { model_name = args.data[info.model_idx].GetValue(0).ToString(); @@ -189,19 +196,57 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V if (info.json_schema_idx != 0) { json_schema = args.data[info.json_schema_idx].GetValue(0).ToString(); } + if (info.json_system_prompt_idx != 0) { + system_prompt = args.data[info.json_system_prompt_idx].GetValue(0).ToString(); + } - std::string request_body = "{"; - request_body += "\"model\":\"" + model_name + "\","; + unique_ptr doc( + duckdb_yyjson::yyjson_mut_doc_new(nullptr), &duckdb_yyjson::yyjson_mut_doc_free); + auto obj = duckdb_yyjson::yyjson_mut_obj(doc.get()); + duckdb_yyjson::yyjson_mut_doc_set_root(doc.get(), obj); + duckdb_yyjson::yyjson_mut_obj_add(obj, + duckdb_yyjson::yyjson_mut_str(doc.get(), "model"), + duckdb_yyjson::yyjson_mut_str(doc.get(), model_name.c_str()) + ); if (!json_schema.empty()) { - request_body += "\"response_format\":{\"type\":\"json_object\", \"schema\":"; - request_body += json_schema; - request_body += "},"; + auto response_format = duckdb_yyjson::yyjson_mut_obj(doc.get()); + duckdb_yyjson::yyjson_mut_obj_add(response_format, + duckdb_yyjson::yyjson_mut_str(doc.get(), "type"), + duckdb_yyjson::yyjson_mut_str(doc.get(), "json_object")); + auto yyschema = duckdb_yyjson::yyjson_mut_raw(doc.get(), json_schema.c_str()); + duckdb_yyjson::yyjson_mut_obj_add(response_format, + duckdb_yyjson::yyjson_mut_str(doc.get(), "schema"), + yyschema); + duckdb_yyjson::yyjson_mut_obj_add(obj, + duckdb_yyjson::yyjson_mut_str(doc.get(),"response_format"), + response_format); } - request_body += "\"messages\":["; - request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; - request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; - request_body += "]}"; - + auto messages = duckdb_yyjson::yyjson_mut_arr(doc.get()); + string str_messages[2][2] = { + {"system", system_prompt}, + {"user", user_prompt.GetString()} + }; + for (auto message : str_messages) { + if (message[1].empty()) { + continue; + } + auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj(doc.get(),messages); + duckdb_yyjson::yyjson_mut_obj_add(yymessage, + duckdb_yyjson::yyjson_mut_str(doc.get(), "role"), + duckdb_yyjson::yyjson_mut_str(doc.get(), message[0].c_str())); + duckdb_yyjson::yyjson_mut_obj_add(yymessage, + duckdb_yyjson::yyjson_mut_str(doc.get(), "content"), + duckdb_yyjson::yyjson_mut_str(doc.get(), message[1].c_str())); + } + duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "messages"), + messages); + duckdb_yyjson::yyjson_write_err err; + auto request_body = duckdb_yyjson::yyjson_mut_write_opts(doc.get(), 0, nullptr, nullptr, &err); + if (request_body == nullptr) { + throw std::runtime_error(err.msg); + } + string str_request_body(request_body); + free(request_body); try { auto client_and_path = SetupHttpClient(api_url); @@ -214,7 +259,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V headers.emplace("Authorization", "Bearer " + api_token); } - auto res = client.Post(path.c_str(), headers, request_body, "application/json"); + auto res = client.Post(path.c_str(), headers, str_request_body, "application/json"); if (!res) { HandleHttpError(res, "POST"); @@ -286,10 +331,14 @@ static void LoadInternal(DatabaseInstance &instance) { open_prompt.AddFunction(ScalarFunction( {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind)); - open_prompt.AddFunction(ScalarFunction( + open_prompt.AddFunction(ScalarFunction( {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind)); + open_prompt.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, OpenPromptRequestFunction, + OpenPromptBind)); ExtensionUtil::RegisterFunction(instance, open_prompt);