Skip to content

Commit

Permalink
Openprompt polishing (#8)
Browse files Browse the repository at this point in the history
* yyjson rewrite

* yyjson fix

* yyjson fix
  • Loading branch information
akvlad authored Oct 26, 2024
1 parent 060bb45 commit e333d4e
Showing 1 changed file with 62 additions and 13 deletions.
75 changes: 62 additions & 13 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,23 @@ namespace duckdb {
struct OpenPromptData: FunctionData {
idx_t model_idx;
idx_t json_schema_idx;
idx_t json_system_prompt_idx;
unique_ptr<FunctionData> Copy() const {
auto res = make_uniq<OpenPromptData>();
res->model_idx = model_idx;
res->json_schema_idx = json_schema_idx;
res->json_system_prompt_idx = json_system_prompt_idx;
return res;
};
bool Equals(const FunctionData &other) const {
return model_idx == other.Cast<OpenPromptData>().model_idx &&
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx;
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx &&
json_system_prompt_idx==other.Cast<OpenPromptData>().json_system_prompt_idx;
};
OpenPromptData() {
model_idx = 0;
json_schema_idx = 0;
json_system_prompt_idx = 0;
}
};

Expand All @@ -49,6 +53,8 @@ namespace duckdb {
res->model_idx = i;
} else if (argument->alias == "json_schema") {
res->json_schema_idx = i;
} else if (argument->alias == "system_prompt") {
res->json_system_prompt_idx = i;
}
}
return std::move(res);
Expand Down Expand Up @@ -182,26 +188,65 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
std::string json_schema;
std::string system_prompt;

if (info.model_idx != 0) {
model_name = args.data[info.model_idx].GetValue(0).ToString();
}
if (info.json_schema_idx != 0) {
json_schema = args.data[info.json_schema_idx].GetValue(0).ToString();
}
if (info.json_system_prompt_idx != 0) {
system_prompt = args.data[info.json_system_prompt_idx].GetValue(0).ToString();
}

std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
unique_ptr<duckdb_yyjson::yyjson_mut_doc, void (*)(duckdb_yyjson::yyjson_mut_doc*)> doc(
duckdb_yyjson::yyjson_mut_doc_new(nullptr), &duckdb_yyjson::yyjson_mut_doc_free);
auto obj = duckdb_yyjson::yyjson_mut_obj(doc.get());
duckdb_yyjson::yyjson_mut_doc_set_root(doc.get(), obj);
duckdb_yyjson::yyjson_mut_obj_add(obj,
duckdb_yyjson::yyjson_mut_str(doc.get(), "model"),
duckdb_yyjson::yyjson_mut_str(doc.get(), model_name.c_str())
);
if (!json_schema.empty()) {
request_body += "\"response_format\":{\"type\":\"json_object\", \"schema\":";
request_body += json_schema;
request_body += "},";
auto response_format = duckdb_yyjson::yyjson_mut_obj(doc.get());
duckdb_yyjson::yyjson_mut_obj_add(response_format,
duckdb_yyjson::yyjson_mut_str(doc.get(), "type"),
duckdb_yyjson::yyjson_mut_str(doc.get(), "json_object"));
auto yyschema = duckdb_yyjson::yyjson_mut_raw(doc.get(), json_schema.c_str());
duckdb_yyjson::yyjson_mut_obj_add(response_format,
duckdb_yyjson::yyjson_mut_str(doc.get(), "schema"),
yyschema);
duckdb_yyjson::yyjson_mut_obj_add(obj,
duckdb_yyjson::yyjson_mut_str(doc.get(),"response_format"),
response_format);
}
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";

auto messages = duckdb_yyjson::yyjson_mut_arr(doc.get());
string str_messages[2][2] = {
{"system", system_prompt},
{"user", user_prompt.GetString()}
};
for (auto message : str_messages) {
if (message[1].empty()) {
continue;
}
auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj(doc.get(),messages);
duckdb_yyjson::yyjson_mut_obj_add(yymessage,
duckdb_yyjson::yyjson_mut_str(doc.get(), "role"),
duckdb_yyjson::yyjson_mut_str(doc.get(), message[0].c_str()));
duckdb_yyjson::yyjson_mut_obj_add(yymessage,
duckdb_yyjson::yyjson_mut_str(doc.get(), "content"),
duckdb_yyjson::yyjson_mut_str(doc.get(), message[1].c_str()));
}
duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "messages"),
messages);
duckdb_yyjson::yyjson_write_err err;
auto request_body = duckdb_yyjson::yyjson_mut_write_opts(doc.get(), 0, nullptr, nullptr, &err);
if (request_body == nullptr) {
throw std::runtime_error(err.msg);
}
string str_request_body(request_body);
free(request_body);

try {
auto client_and_path = SetupHttpClient(api_url);
Expand All @@ -214,7 +259,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
headers.emplace("Authorization", "Bearer " + api_token);
}

auto res = client.Post(path.c_str(), headers, request_body, "application/json");
auto res = client.Post(path.c_str(), headers, str_request_body, "application/json");

if (!res) {
HandleHttpError(res, "POST");
Expand Down Expand Up @@ -286,10 +331,14 @@ static void LoadInternal(DatabaseInstance &instance) {
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));
open_prompt.AddFunction(ScalarFunction(
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));

ExtensionUtil::RegisterFunction(instance, open_prompt);

Expand Down

0 comments on commit e333d4e

Please sign in to comment.