From 9cedcf86ee1b9200d02866583833ca08ab7c0428 Mon Sep 17 00:00:00 2001 From: akvlad Date: Sat, 26 Oct 2024 15:15:29 +0300 Subject: [PATCH 1/3] yyjson rewrite --- src/open_prompt_extension.cpp | 38 ++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index cf57b5a..767df49 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -190,13 +190,41 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V json_schema = args.data[info.json_schema_idx].GetValue(0).ToString(); } - std::string request_body = "{"; - request_body += "\"model\":\"" + model_name + "\","; + unique_ptr doc( + new duckdb_yyjson::yyjson_mut_doc(), &duckdb_yyjson::yyjson_mut_doc_free); + auto obj = duckdb_yyjson::yyjson_mut_obj(doc.get()); + duckdb_yyjson::yyjson_mut_obj_add(obj, + duckdb_yyjson::yyjson_mut_str(doc.get(), "model"), + duckdb_yyjson::yyjson_mut_str(doc.get(), model_name.c_str()) + ); if (!json_schema.empty()) { - request_body += "\"response_format\":{\"type\":\"json_object\", \"schema\":"; - request_body += json_schema; - request_body += "},"; + auto response_format = duckdb_yyjson::yyjson_mut_obj(doc.get()); + duckdb_yyjson::yyjson_mut_obj_add(response_format, + duckdb_yyjson::yyjson_mut_str(doc.get(), "type"), + duckdb_yyjson::yyjson_mut_str(doc.get(), "json_object")); + auto yyschema = duckdb_yyjson::yyjson_mut_raw(doc.get(), json_schema.c_str()); + duckdb_yyjson::yyjson_mut_obj_add(response_format, + duckdb_yyjson::yyjson_mut_str(doc.get(), "schema"), + yyschema); + duckdb_yyjson::yyjson_mut_obj_add(obj, + duckdb_yyjson::yyjson_mut_str(doc.get(),"response_format"), + response_format); + } + auto messages = duckdb_yyjson::yyjson_mut_arr(doc.get()); + string str_messages[2][2] = { + {"system", "You are a helpful assistant."}, + {"user", user_prompt.GetString()} + }; + for (auto message : str_messages) { + if (message[1].empty()) { + continue; + } + auto yymessage = duckdb_yyjson::yyjson_mut_obj(doc.get()); + } + duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "messages"), + ) request_body += "\"messages\":["; request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; From f84601993b379a596a98a14293511ca4d358d024 Mon Sep 17 00:00:00 2001 From: akvlad Date: Sat, 26 Oct 2024 16:37:53 +0300 Subject: [PATCH 2/3] yyjson fix --- src/open_prompt_extension.cpp | 50 ++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index 767df49..b4c8155 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -24,15 +24,18 @@ namespace duckdb { struct OpenPromptData: FunctionData { idx_t model_idx; idx_t json_schema_idx; + idx_t json_system_prompt_idx; unique_ptr Copy() const { auto res = make_uniq(); res->model_idx = model_idx; res->json_schema_idx = json_schema_idx; + res->json_system_prompt_idx = json_system_prompt_idx; return res; }; bool Equals(const FunctionData &other) const { return model_idx == other.Cast().model_idx && - json_schema_idx == other.Cast().json_schema_idx; + json_schema_idx == other.Cast().json_schema_idx && + json_system_prompt_idx==other.Cast().json_system_prompt_idx; }; OpenPromptData() { model_idx = 0; @@ -49,6 +52,8 @@ namespace duckdb { res->model_idx = i; } else if (argument->alias == "json_schema") { res->json_schema_idx = i; + } else if (argument->alias == "system_prompt") { + res->json_system_prompt_idx = i; } } return std::move(res); @@ -182,6 +187,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); std::string json_schema; + std::string system_prompt; if (info.model_idx != 0) { model_name = args.data[info.model_idx].GetValue(0).ToString(); @@ -189,11 +195,14 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V if (info.json_schema_idx != 0) { json_schema = args.data[info.json_schema_idx].GetValue(0).ToString(); } + if (info.json_system_prompt_idx != 0) { + system_prompt = args.data[info.json_system_prompt_idx].GetValue(0).ToString(); + } - unique_ptr doc( - new duckdb_yyjson::yyjson_mut_doc(), &duckdb_yyjson::yyjson_mut_doc_free); + unique_ptr doc( + duckdb_yyjson::yyjson_mut_doc_new(nullptr), &duckdb_yyjson::yyjson_mut_doc_free); auto obj = duckdb_yyjson::yyjson_mut_obj(doc.get()); + duckdb_yyjson::yyjson_mut_doc_set_root(doc.get(), obj); duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "model"), duckdb_yyjson::yyjson_mut_str(doc.get(), model_name.c_str()) @@ -213,23 +222,30 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V } auto messages = duckdb_yyjson::yyjson_mut_arr(doc.get()); string str_messages[2][2] = { - {"system", "You are a helpful assistant."}, + {"system", system_prompt}, {"user", user_prompt.GetString()} }; for (auto message : str_messages) { if (message[1].empty()) { continue; } - auto yymessage = duckdb_yyjson::yyjson_mut_obj(doc.get()); - + auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj(doc.get(),messages); + duckdb_yyjson::yyjson_mut_obj_add(yymessage, + duckdb_yyjson::yyjson_mut_str(doc.get(), "role"), + duckdb_yyjson::yyjson_mut_str(doc.get(), message[0].c_str())); + duckdb_yyjson::yyjson_mut_obj_add(yymessage, + duckdb_yyjson::yyjson_mut_str(doc.get(), "content"), + duckdb_yyjson::yyjson_mut_str(doc.get(), message[1].c_str())); } duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "messages"), - ) - request_body += "\"messages\":["; - request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; - request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; - request_body += "]}"; - + messages); + duckdb_yyjson::yyjson_write_err err; + auto request_body = duckdb_yyjson::yyjson_mut_write_opts(doc.get(), 0, nullptr, nullptr, &err); + if (request_body == nullptr) { + throw std::runtime_error(err.msg); + } + string str_request_body(request_body); + free(request_body); try { auto client_and_path = SetupHttpClient(api_url); @@ -242,7 +258,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V headers.emplace("Authorization", "Bearer " + api_token); } - auto res = client.Post(path.c_str(), headers, request_body, "application/json"); + auto res = client.Post(path.c_str(), headers, str_request_body, "application/json"); if (!res) { HandleHttpError(res, "POST"); @@ -314,10 +330,14 @@ static void LoadInternal(DatabaseInstance &instance) { open_prompt.AddFunction(ScalarFunction( {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind)); - open_prompt.AddFunction(ScalarFunction( + open_prompt.AddFunction(ScalarFunction( {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind)); + open_prompt.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, OpenPromptRequestFunction, + OpenPromptBind)); ExtensionUtil::RegisterFunction(instance, open_prompt); From 42c47046485de6a863b7b17c6dfd1cb52e3d0ce2 Mon Sep 17 00:00:00 2001 From: akvlad Date: Sat, 26 Oct 2024 16:42:11 +0300 Subject: [PATCH 3/3] yyjson fix --- src/open_prompt_extension.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index b4c8155..e6625ba 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -40,6 +40,7 @@ namespace duckdb { OpenPromptData() { model_idx = 0; json_schema_idx = 0; + json_system_prompt_idx = 0; } };