From 3b8250a2690b66c94e9049c1de5e7f03284e59ce Mon Sep 17 00:00:00 2001
From: mzegla
Date: Fri, 6 Sep 2024 16:50:28 +0200
Subject: [PATCH] revert request class move

---
 ci/cppclean.sh                      |   4 +-
 src/llm/apis/openai_completions.cpp | 193 +++++++---------------------
 src/llm/apis/openai_completions.hpp |  91 ++++++++++++-
 3 files changed, 138 insertions(+), 150 deletions(-)

diff --git a/ci/cppclean.sh b/ci/cppclean.sh
index f327739578..40b0a3cfcf 100755
--- a/ci/cppclean.sh
+++ b/ci/cppclean.sh
@@ -42,7 +42,7 @@ fi
 if [ ${NO_WARNINGS_DIRECT} -gt 15 ]; then
     errors+="Failed probably due to not using static keyword with functions definitions: ${NO_WARNINGS_DIRECT}"$'\n'
 fi
-if [ ${NO_WARNINGS_NOTUSED} -gt 5 ]; then
+if [ ${NO_WARNINGS_NOTUSED} -gt 4 ]; then
     errors+="Failed probably due to unnecessary forward includes: ${NO_WARNINGS_NOTUSED}"$'\n'
 fi
 if [ ${NO_WARNINGS_TEST_FORWARD} -gt 1 ]; then
@@ -54,7 +54,7 @@ fi
 if [ ${NO_WARNINGS_TEST_NOTUSED} -gt 0 ]; then
     errors+="Failed probably due to unnecessary forward includes: ${NO_WARNINGS_TEST_NOTUSED}"$'\n'
 fi
-if [ ${NO_WARNINGS} -gt 165 ]; then
+if [ ${NO_WARNINGS} -gt 164 ]; then
     errors+="Failed due to higher than allowed number of issues in code: ${NO_WARNINGS}"$'\n'
 fi
 if [ ${NO_WARNINGS_TEST} -gt 52 ]; then
diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp
index 1391977b64..84435c1ba0 100644
--- a/src/llm/apis/openai_completions.cpp
+++ b/src/llm/apis/openai_completions.cpp
@@ -16,8 +16,6 @@
 
 #include "openai_completions.hpp"
 
-#include <limits>
-
 #include <string>
 #include <utility>
 
@@ -28,97 +26,6 @@ using namespace rapidjson;
 
 namespace ovms {
 
-// Class that maps OpenAI request content and provides methods to create GenerationConfig from it.
-struct OpenAIChatCompletionsRequest {
-    chat_t messages;
-    std::optional<std::string> prompt{std::nullopt};
-    bool stream{false};
-    StreamOptions streamOptions;
-    std::string model;
-    std::optional<uint32_t> maxTokens{std::nullopt};
-    std::optional<float> frequencyPenalty{std::nullopt};
-    std::optional<float> presencePenalty{std::nullopt};
-    std::optional<float> diversityPenalty{std::nullopt};
-    std::optional<float> repetitionPenalty{std::nullopt};
-    std::optional<float> lengthPenalty{std::nullopt};
-    std::optional<uint32_t> numReturnSequences{std::nullopt};
-    std::optional<float> temperature{std::nullopt};
-    std::optional<float> topP{std::nullopt};
-    std::optional<int> topK{std::nullopt};
-    std::optional<uint32_t> seed{std::nullopt};
-    std::optional<std::set<std::string>> stop{std::nullopt};
-    std::optional<bool> includeStopStrInOutput{std::nullopt};
-    std::optional<uint32_t> bestOf{std::nullopt};
-    std::optional<bool> ignoreEOS{std::nullopt};
-
-    OpenAIChatCompletionsRequest() = default;
-    ~OpenAIChatCompletionsRequest() = default;
-
-    ov::genai::GenerationConfig createGenerationConfig() {
-        ov::genai::GenerationConfig config;
-
-        // Generic
-        if (maxTokens.has_value())
-            config.max_new_tokens = maxTokens.value();
-        // TODO: max_length = ?
-        if (ignoreEOS.has_value())
-            config.ignore_eos = ignoreEOS.value();
-
-        // Beam search specific
-        config.num_beam_groups = 1;  // OpenAI hardcoded
-        config.num_beams = 1;        // OpenAI hardcoded
-        config.no_repeat_ngram_size = std::numeric_limits<size_t>::max();
-
-        if (bestOf.has_value())
-            config.num_beams = bestOf.value();
-
-        if (diversityPenalty.has_value())
-            config.diversity_penalty = diversityPenalty.value();  // TODO: Not available in OpenAI nor vLLM
-        // TODO: stop_criteria = ?
-        if (numReturnSequences.has_value())
-            config.num_return_sequences = numReturnSequences.value();
-        if (repetitionPenalty.has_value())
-            config.repetition_penalty = repetitionPenalty.value();
-        if (lengthPenalty.has_value())
-            config.length_penalty = lengthPenalty.value();
-        // TODO: no_repeat_ngram_size = ?
-        // TODO: early_finish = ?
-        // TODO use_beam_search is unused ?
-
-        // Multinomial specific
-        if (temperature.has_value())
-            config.temperature = temperature.value();
-        if (topK.has_value())
-            config.top_k = topK.value();
-        if (topP.has_value())
-            config.top_p = topP.value();
-        if (seed.has_value())
-            config.rng_seed = seed.value();
-        if (stop.has_value())
-            config.stop_strings = stop.value();
-        if (includeStopStrInOutput.has_value())
-            config.include_stop_str_in_output = includeStopStrInOutput.value();
-        if (frequencyPenalty.has_value())
-            config.frequency_penalty = frequencyPenalty.value();
-        if (presencePenalty.has_value())
-            config.presence_penalty = presencePenalty.value();
-        config.do_sample = config.temperature > 0.0f && config.num_beams == 1;
-
-        return config;
-    }
-};
-
-OpenAIChatCompletionsHandler::OpenAIChatCompletionsHandler(Document& doc, Endpoint endpoint, std::chrono::time_point<std::chrono::system_clock> creationTime,
-    ov::genai::Tokenizer tokenizer) :
-    doc(doc),
-    endpoint(endpoint),
-    created(creationTime),
-    tokenizer(tokenizer) {
-    request = new OpenAIChatCompletionsRequest;
-}
-
-OpenAIChatCompletionsHandler::~OpenAIChatCompletionsHandler() { delete request; }
-
 absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() {
     // prompt: string
     auto it = doc.FindMember("prompt");
@@ -126,10 +33,10 @@ absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() {
         if (!it->value.IsString()) {
             return absl::InvalidArgumentError("prompt is not a string");
         } else {
-            request->prompt = it->value.GetString();
+            request.prompt = it->value.GetString();
         }
     }
-    if (!request->prompt.has_value() || !request->prompt.value().size()) {
+    if (!request.prompt.has_value() || !request.prompt.value().size()) {
         return absl::Status(absl::StatusCode::kInvalidArgument, "prompt is missing");
     }
     return absl::OkStatus();
 }
@@ -144,13 +51,13 @@ absl::Status OpenAIChatCompletionsHandler::parseChatCompletionsPart() {
         return absl::InvalidArgumentError("Messages are not an array");
     if (it->value.GetArray().Size() == 0)
         return absl::InvalidArgumentError("Messages array cannot be empty");
-    request->messages.clear();
-    request->messages.reserve(it->value.GetArray().Size());
+    request.messages.clear();
+    request.messages.reserve(it->value.GetArray().Size());
     for (size_t i = 0; i < it->value.GetArray().Size(); i++) {
         const auto& obj = it->value.GetArray()[i];
         if (!obj.IsObject())
             return absl::InvalidArgumentError("Message is not a JSON object");
-        auto& chat = request->messages.emplace_back(chat_entry_t{});
+        auto& chat = request.messages.emplace_back(chat_entry_t{});
         for (auto member = obj.MemberBegin(); member != obj.MemberEnd(); member++) {
             if (!member->name.IsString())
                 return absl::InvalidArgumentError("Invalid message structure");
@@ -160,7 +67,7 @@ absl::Status OpenAIChatCompletionsHandler::parseChatCompletionsPart() {
         }
     }
 
-    if (request->messages.size() <= 0) {
+    if (request.messages.size() <= 0) {
         return absl::Status(absl::StatusCode::kInvalidArgument, "messages are missing");
     }
     return absl::OkStatus();
 }
@@ -175,12 +82,12 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim
     if (it != doc.MemberEnd()) {
         if (!it->value.IsBool())
             return absl::InvalidArgumentError("Stream is not bool");
-        request->stream = it->value.GetBool();
+        request.stream = it->value.GetBool();
     }
 
     it = doc.FindMember("stream_options");
     if (it != doc.MemberEnd()) {
-        if (!request->stream)
+        if (!request.stream)
             return absl::InvalidArgumentError("stream_options provided, but stream not set to true");
         if (!it->value.IsObject())
             return absl::InvalidArgumentError("stream_options is not an object");
@@ -191,7 +98,7 @@
     if (it != streamOptionsObj.MemberEnd()) {
         if (!it->value.IsBool())
             return absl::InvalidArgumentError("stream_options.include_usage is not a boolean");
-        request->streamOptions.includeUsage = it->value.GetBool();
+        request.streamOptions.includeUsage = it->value.GetBool();
         streamOptionsFound++;
     }
@@ -205,7 +112,7 @@
     if (it != doc.MemberEnd()) {
         if (!it->value.IsString())
             return absl::InvalidArgumentError("model is not a string");
-        request->model = it->value.GetString();
+        request.model = it->value.GetString();
     } else {
         return absl::InvalidArgumentError("model missing in request");
     }
@@ -216,7 +123,7 @@
     if (it != doc.MemberEnd()) {
         if (!it->value.IsBool())
             return absl::InvalidArgumentError("ignore_eos accepts values true or false");
-        request->ignoreEOS = it->value.GetBool();
+        request.ignoreEOS = it->value.GetBool();
     }
 
     // max_tokens: uint; optional
@@ -231,14 +138,14 @@
             return absl::InvalidArgumentError("max_tokens value should be greater than 0");
         if (!(it->value.GetUint() < maxTokensLimit))
             return absl::InvalidArgumentError(absl::StrCat("max_tokens exceeds limit provided in graph config: ", maxTokensLimit));
-        request->maxTokens = it->value.GetUint();
+        request.maxTokens = it->value.GetUint();
     }
-    if (request->ignoreEOS.value_or(false)) {
-        if (request->maxTokens.has_value()) {
+    if (request.ignoreEOS.value_or(false)) {
+        if (request.maxTokens.has_value()) {
            if (it->value.GetUint() > IGNORE_EOS_MAX_TOKENS_LIMIT)
                return absl::InvalidArgumentError("when ignore_eos is true max_tokens can not be greater than 4000");
        } else {
-            request->maxTokens = IGNORE_EOS_MAX_TOKENS_LIMIT;
+            request.maxTokens = IGNORE_EOS_MAX_TOKENS_LIMIT;
        }
     }
 
     // frequency_penalty: float; optional - defaults to 0.0
@@ -247,8 +154,8 @@
     if (it != doc.MemberEnd()) {
         if (!it->value.IsDouble() && !it->value.IsInt())
             return absl::InvalidArgumentError("frequency_penalty is not a valid number");
-        request->frequencyPenalty = it->value.GetDouble();
-        if (request->frequencyPenalty < -2.0f || request->frequencyPenalty > 2.0f)
+        request.frequencyPenalty = it->value.GetDouble();
+        if (request.frequencyPenalty < -2.0f || request.frequencyPenalty > 2.0f)
             return absl::InvalidArgumentError("frequency_penalty out of range(-2.0, 2.0)");
     }
 
     // presence_penalty: float; optional - defaults to 0.0
@@ -257,8 +164,8 @@
     if (it != doc.MemberEnd()) {
         if (!it->value.IsDouble() && !it->value.IsInt())
             return absl::InvalidArgumentError("presence_penalty is not a valid number");
-        request->presencePenalty = it->value.GetDouble();
-        if (request->presencePenalty < -2.0f || request->presencePenalty > 2.0f)
+        request.presencePenalty = it->value.GetDouble();
+        if (request.presencePenalty < -2.0f || request.presencePenalty > 2.0f)
             return absl::InvalidArgumentError("presence_penalty out of range(-2.0, 2.0)");
2.0)"); } @@ -268,7 +175,7 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim if (it != doc.MemberEnd()) { if (!it->value.IsDouble() && !it->value.IsInt()) return absl::InvalidArgumentError("repetition_penalty is not a valid number"); - request->repetitionPenalty = it->value.GetDouble(); + request.repetitionPenalty = it->value.GetDouble(); } // diversity_penalty: float; optional - defaults to 1.0 @@ -277,7 +184,7 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim if (it != doc.MemberEnd()) { if (!it->value.IsDouble() && !it->value.IsInt()) return absl::InvalidArgumentError("diversity_penalty is not a valid number"); - request->diversityPenalty = it->value.GetDouble(); + request.diversityPenalty = it->value.GetDouble(); } // length_penalty: float; optional - defaults to 1.0 @@ -286,7 +193,7 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim if (it != doc.MemberEnd()) { if (!it->value.IsDouble() && !it->value.IsInt()) return absl::InvalidArgumentError("length_penalty is not a valid number"); - request->lengthPenalty = it->value.GetDouble(); + request.lengthPenalty = it->value.GetDouble(); } // temperature: float; optional - defaults to 0.0 (different than OpenAI which is 1.0) @@ -294,8 +201,8 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim if (it != doc.MemberEnd()) { if (!it->value.IsDouble() && !it->value.IsInt()) return absl::InvalidArgumentError("temperature is not a valid number"); - request->temperature = it->value.GetDouble(); - if (request->temperature < 0.0f || request->temperature > 2.0f) + request.temperature = it->value.GetDouble(); + if (request.temperature < 0.0f || request.temperature > 2.0f) return absl::InvalidArgumentError("temperature out of range(0.0, 2.0)"); } @@ -304,8 +211,8 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim if (it != doc.MemberEnd()) { if (!it->value.IsDouble() && !it->value.IsInt()) return absl::InvalidArgumentError("top_p is not a valid number"); - request->topP = it->value.GetDouble(); - if (request->topP < 0.0f || request->topP > 1.0f) + request.topP = it->value.GetDouble(); + if (request.topP < 0.0f || request.topP > 1.0f) return absl::InvalidArgumentError("top_p out of range(0.0, 1.0)"); } @@ -315,7 +222,7 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim if (it != doc.MemberEnd()) { if (!it->value.IsInt()) return absl::InvalidArgumentError("top_k is not an integer"); - request->topK = it->value.GetInt(); + request.topK = it->value.GetInt(); } // seed: int; optional - defaults to 0 (not set) @@ -323,26 +230,26 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim if (it != doc.MemberEnd()) { if (!it->value.IsUint()) return absl::InvalidArgumentError("seed is not an unsigned integer"); - request->seed = it->value.GetUint(); + request.seed = it->value.GetUint(); } // stop: string or array; optional - defaults to null (not set) it = doc.FindMember("stop"); if (it != doc.MemberEnd()) { if (it->value.IsString()) { - request->stop = std::set{it->value.GetString()}; + request.stop = std::set{it->value.GetString()}; } else if (it->value.IsArray()) { auto stopArray = it->value.GetArray(); // TODO: OpenAI API defines upper bound but do we want it? 
             if (stopArray.Size() < 1 || stopArray.Size() > 4)
                 return absl::InvalidArgumentError("stop array must have a least 1 and no more than 4 strings");
-            request->stop = std::set<std::string>{};
+            request.stop = std::set<std::string>{};
             for (size_t i = 0; i < stopArray.Size(); i++) {
                 const auto& element = stopArray[i];
                 if (!element.IsString())
                     return absl::InvalidArgumentError("stop array contains non string element");
-                request->stop->insert(element.GetString());
+                request.stop->insert(element.GetString());
             }
         } else {
             return absl::InvalidArgumentError("stop is not a string or array of strings");
         }
     }
@@ -353,17 +260,17 @@
     // Extension, unsupported by OpenAI API, however supported by vLLM and CB lib
     // If stream is true, then include stop string in output by default
-    if (request->stream) {
-        request->includeStopStrInOutput = true;
+    if (request.stream) {
+        request.includeStopStrInOutput = true;
     }
 
     it = doc.FindMember("include_stop_str_in_output");
     if (it != doc.MemberEnd()) {
         if (!it->value.IsBool())
             return absl::InvalidArgumentError("include_stop_str_in_output accepts values true or false");
-        if (!it->value.GetBool() && request->stream)
+        if (!it->value.GetBool() && request.stream)
             return absl::InvalidArgumentError("include_stop_str_in_output cannot be set to false if streaming is used");
-        request->includeStopStrInOutput = it->value.GetBool();
+        request.includeStopStrInOutput = it->value.GetBool();
     }
 
     // best_of: int; optional - defaults to 1
@@ -376,7 +283,7 @@
             return absl::InvalidArgumentError("best_of value should be greater than 0");
         if (!(it->value.GetUint() < bestOfLimit))
             return absl::InvalidArgumentError(absl::StrCat("best_of exceeds limit provided in graph config: ", bestOfLimit));
-        request->bestOf = it->value.GetUint();
+        request.bestOf = it->value.GetUint();
     }
 
     // n: int; optional - defaults to 1
@@ -387,11 +294,11 @@
         if (!it->value.IsUint())
             return absl::InvalidArgumentError("n is not an unsigned integer");
         if (it->value.GetUint() == 0)
             return absl::InvalidArgumentError("n value should be greater than 0");
-        size_t bestOf = request->bestOf.has_value() ? request->bestOf.value() : 1;  // 1 is default best_of value
+        size_t bestOf = request.bestOf.has_value() ? request.bestOf.value() : 1;  // 1 is default best_of value
         if (bestOf < it->value.GetUint()) {
             return absl::InvalidArgumentError("n value cannot be greater than best_of");
         }
-        request->numReturnSequences = it->value.GetUint();
+        request.numReturnSequences = it->value.GetUint();
     }
 
     // use_beam_search: bool; optional - defaults to false
@@ -401,7 +308,7 @@
     // if (it != doc.MemberEnd()) {
     //     if (!it->value.IsBool())
     //         return false;
-    //     request->useBeamSearch = it->value.GetBool();
+    //     request.useBeamSearch = it->value.GetBool();
     // }
 
     // logit_bias TODO
@@ -416,12 +323,12 @@
     return absl::OkStatus();
 }
 
-std::optional<std::string> OpenAIChatCompletionsHandler::getPrompt() const { return request->prompt; }
-std::optional<uint32_t> OpenAIChatCompletionsHandler::getNumReturnSequences() const { return request->numReturnSequences; }
-StreamOptions OpenAIChatCompletionsHandler::getStreamOptions() const { return request->streamOptions; }
+std::optional<std::string> OpenAIChatCompletionsHandler::getPrompt() const { return request.prompt; }
+std::optional<uint32_t> OpenAIChatCompletionsHandler::getNumReturnSequences() const { return request.numReturnSequences; }
+StreamOptions OpenAIChatCompletionsHandler::getStreamOptions() const { return request.streamOptions; }
 
-bool OpenAIChatCompletionsHandler::isStream() const { return request->stream; }
-std::string OpenAIChatCompletionsHandler::getModel() const { return request->model; }
+bool OpenAIChatCompletionsHandler::isStream() const { return request.stream; }
+std::string OpenAIChatCompletionsHandler::getModel() const { return request.model; }
 
 void OpenAIChatCompletionsHandler::setPromptTokensUsage(int promptTokens) {
     usage.promptTokens = promptTokens;
@@ -432,7 +339,7 @@ void OpenAIChatCompletionsHandler::incrementCompletionTokensUsage() {
 }
 
 ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig() const {
-    return request->createGenerationConfig();
+    return request.createGenerationConfig();
 }
 
 absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
@@ -460,7 +367,7 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const std::vect
     writer.String("choices");
     writer.StartArray();  // [
     int i = 0;
-    int n = request->numReturnSequences.value_or(1);
+    int n = request.numReturnSequences.value_or(1);
     usage.completionTokens = 0;
     for (const ov::genai::GenerationOutput& generationOutput : generationOutputs) {
         if (i >= n)
             break;
@@ -519,7 +426,7 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const std::vect
     // model: string; copied from the request
     writer.String("model");
-    writer.String(request->model.c_str());
+    writer.String(request.model.c_str());
 
     // object: string; defined that the type is unary rather than streamed chunk
     if (endpoint == Endpoint::CHAT_COMPLETIONS) {
@@ -610,7 +517,7 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::str
     // model: string; copied from the request
     writer.String("model");
-    writer.String(request->model.c_str());
+    writer.String(request.model.c_str());
 
     // object: string; defined that the type streamed chunk rather than complete response
     if (endpoint == Endpoint::CHAT_COMPLETIONS) {
@@ -621,7 +528,7 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::str
         writer.String("chat.completion.chunk");
     } else if (endpoint == Endpoint::COMPLETIONS) {
         writer.String("text_completion.chunk");
     }
 
-    if (request->streamOptions.includeUsage) {
+    if (request.streamOptions.includeUsage) {
         writer.String("usage");
         writer.Null();
     }
@@ -654,7 +561,7 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingUsageChunk() {
     // model: string; copied from the request
     writer.String("model");
-    writer.String(request->model.c_str());
+    writer.String(request.model.c_str());
 
     // object: string; defined that the type streamed chunk rather than complete response
     if (endpoint == Endpoint::CHAT_COMPLETIONS) {
diff --git a/src/llm/apis/openai_completions.hpp b/src/llm/apis/openai_completions.hpp
index bc108188ac..183dae7793 100644
--- a/src/llm/apis/openai_completions.hpp
+++ b/src/llm/apis/openai_completions.hpp
@@ -15,6 +15,7 @@
 //*****************************************************************************
 #pragma once
 
+#include <limits>
 #include <optional>
 #include <set>
 #include <string>
@@ -56,7 +57,85 @@
     }
 };
 
-struct OpenAIChatCompletionsRequest;
+// Class that maps OpenAI request content and provides methods to create GenerationConfig from it.
+struct OpenAIChatCompletionsRequest {
+    chat_t messages;
+    std::optional<std::string> prompt{std::nullopt};
+    bool stream{false};
+    StreamOptions streamOptions;
+    std::string model;
+    std::optional<uint32_t> maxTokens{std::nullopt};
+    std::optional<float> frequencyPenalty{std::nullopt};
+    std::optional<float> presencePenalty{std::nullopt};
+    std::optional<float> diversityPenalty{std::nullopt};
+    std::optional<float> repetitionPenalty{std::nullopt};
+    std::optional<float> lengthPenalty{std::nullopt};
+    std::optional<uint32_t> numReturnSequences{std::nullopt};
+    std::optional<float> temperature{std::nullopt};
+    std::optional<float> topP{std::nullopt};
+    std::optional<int> topK{std::nullopt};
+    std::optional<uint32_t> seed{std::nullopt};
+    std::optional<std::set<std::string>> stop{std::nullopt};
+    std::optional<bool> includeStopStrInOutput{std::nullopt};
+    std::optional<uint32_t> bestOf{std::nullopt};
+    std::optional<bool> ignoreEOS{std::nullopt};
+
+    OpenAIChatCompletionsRequest() = default;
+    ~OpenAIChatCompletionsRequest() = default;
+
+    ov::genai::GenerationConfig createGenerationConfig() const {
+        ov::genai::GenerationConfig config;
+
+        // Generic
+        if (maxTokens.has_value())
+            config.max_new_tokens = maxTokens.value();
+        // TODO: max_length = ?
+        if (ignoreEOS.has_value())
+            config.ignore_eos = ignoreEOS.value();
+
+        // Beam search specific
+        config.num_beam_groups = 1;  // OpenAI hardcoded
+        config.num_beams = 1;        // OpenAI hardcoded
+        config.no_repeat_ngram_size = std::numeric_limits<size_t>::max();
+
+        if (bestOf.has_value())
+            config.num_beams = bestOf.value();
+
+        if (diversityPenalty.has_value())
+            config.diversity_penalty = diversityPenalty.value();  // TODO: Not available in OpenAI nor vLLM
+        // TODO: stop_criteria = ?
+        if (numReturnSequences.has_value())
+            config.num_return_sequences = numReturnSequences.value();
+        if (repetitionPenalty.has_value())
+            config.repetition_penalty = repetitionPenalty.value();
+        if (lengthPenalty.has_value())
+            config.length_penalty = lengthPenalty.value();
+        // TODO: no_repeat_ngram_size = ?
+        // TODO: early_finish = ?
+        // TODO use_beam_search is unused ?
+
+        // Multinomial specific
+        if (temperature.has_value())
+            config.temperature = temperature.value();
+        if (topK.has_value())
+            config.top_k = topK.value();
+        if (topP.has_value())
+            config.top_p = topP.value();
+        if (seed.has_value())
+            config.rng_seed = seed.value();
+        if (stop.has_value())
+            config.stop_strings = stop.value();
+        if (includeStopStrInOutput.has_value())
+            config.include_stop_str_in_output = includeStopStrInOutput.value();
+        if (frequencyPenalty.has_value())
+            config.frequency_penalty = frequencyPenalty.value();
+        if (presencePenalty.has_value())
+            config.presence_penalty = presencePenalty.value();
+        config.do_sample = config.temperature > 0.0f && config.num_beams == 1;
+
+        return config;
+    }
+};
 
 // Class that wraps OpenAI request, holds and processes raw JSON, provides methods for serialization and keeps track of usage.
 // It is used in the calculator.
@@ -64,7 +143,7 @@ class OpenAIChatCompletionsHandler {
     Document& doc;
     Endpoint endpoint;
     CompletionUsageStatistics usage;
-    OpenAIChatCompletionsRequest* request = nullptr;
+    OpenAIChatCompletionsRequest request;
     std::chrono::time_point<std::chrono::system_clock> created;
     ov::genai::Tokenizer tokenizer;
@@ -74,9 +153,11 @@ class OpenAIChatCompletionsHandler {
 public:
     OpenAIChatCompletionsHandler(Document& doc, Endpoint endpoint, std::chrono::time_point<std::chrono::system_clock> creationTime,
-        ov::genai::Tokenizer tokenizer);
-
-    ~OpenAIChatCompletionsHandler();
+        ov::genai::Tokenizer tokenizer) :
+        doc(doc),
+        endpoint(endpoint),
+        created(creationTime),
+        tokenizer(tokenizer) {}
 
     std::optional<std::string> getPrompt() const;
     std::optional<uint32_t> getNumReturnSequences() const;
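
Editor's note, not part of the patch itself: the substance of this revert is that OpenAIChatCompletionsHandler stores its OpenAIChatCompletionsRequest by value again instead of through a raw owning pointer. The sketch below contrasts the two ownership models; it uses a hypothetical stripped-down Request type and is not code from the repository.

    #include <optional>
    #include <string>

    // Hypothetical stand-in for OpenAIChatCompletionsRequest.
    struct Request {
        std::optional<std::string> prompt;
    };

    // Raw owning pointer, as before this revert: the class needs a
    // user-written constructor/destructor pair to match new with delete,
    // and the implicitly generated copy operations would alias the pointer
    // and double-delete it, so they have to be suppressed as well.
    class HandlerWithPointer {
        Request* request = nullptr;

    public:
        HandlerWithPointer() :
            request(new Request) {}
        ~HandlerWithPointer() { delete request; }
        HandlerWithPointer(const HandlerWithPointer&) = delete;
        HandlerWithPointer& operator=(const HandlerWithPointer&) = delete;
    };

    // By-value member, as after this revert: no heap allocation and no
    // manual cleanup; the compiler-generated special members are correct,
    // so the constructor can be a trivial inline definition in the header.
    class HandlerWithValue {
        Request request;
    };

The cost of the by-value member is that the complete type must be visible where the member is declared, which is why the full OpenAIChatCompletionsRequest definition moves back into openai_completions.hpp and the forward declaration is dropped, with the ci/cppclean.sh warning allowances tightened by one to match.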