Skip to content

Commit

Permalink
Set stop token ids from generation config to VLM pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
yatarkan committed Jan 21, 2025
1 parent d4504e3 commit 7ecbb82
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/cpp/src/generation_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ void GenerationConfig::set_eos_token_id(size_t tokenizer_eos_token_id) {
tokenizer_eos_token_id, ")");
}
// Merge user defined stop tokens with model EOS token
stop_token_ids.insert(eos_token_id);
if (stop_token_ids.find(eos_token_id) == stop_token_ids.end())
stop_token_ids.insert(eos_token_id);
}

void GenerationConfig::update_generation_config(const ov::AnyMap& properties) {
Expand Down
13 changes: 11 additions & 2 deletions src/cpp/src/visual_language/inputs_embedder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class InputsEmbedder::IInputsEmbedder {
// so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0};

std::set<int64_t> m_stop_token_ids;
public:
virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) = 0;

Expand All @@ -74,6 +75,10 @@ class InputsEmbedder::IInputsEmbedder {
return m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
}

void set_stop_token_ids(const std::set<int64_t>& stop_token_ids) {
m_stop_token_ids = stop_token_ids;
}

void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
if (is_beam_search) {
m_kv_history_manager.trusted_history_length = m_tokenized_history.size();
Expand Down Expand Up @@ -186,8 +191,8 @@ class InputsEmbedder::IInputsEmbedder {
// so let's check it out, find the trusted part and use it on the next step
size_t trusted_history_length = 0;
if (!m_tokenized_history.empty()) {
std::set<int64_t> stop_tokens = {m_tokenizer.get_eos_token_id()};
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens);
OPENVINO_ASSERT(!m_stop_token_ids.empty(), "Stop tokens are not set for InputsEmbedder");
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, m_stop_token_ids);
}

if (m_tokenized_history.empty()) {
Expand Down Expand Up @@ -1617,6 +1622,10 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
return m_impl->get_embedding_model();
}

// Forwards the stop token ids to the implementation object.
void InputsEmbedder::set_stop_token_ids(const std::set<int64_t>& stop_token_ids) {
    // The callee returns void, so this is a plain call statement.
    m_impl->set_stop_token_ids(stop_token_ids);
}

// Returns the tokenized chat history as reported by the implementation object.
std::vector<int64_t> InputsEmbedder::get_tokenized_history() const {
    std::vector<int64_t> history = m_impl->get_tokenized_history();
    return history;
}
Expand Down
2 changes: 2 additions & 0 deletions src/cpp/src/visual_language/inputs_embedder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class InputsEmbedder {
// returns tokenizer
Tokenizer get_tokenizer() const;

void set_stop_token_ids(const std::set<int64_t>& stop_token_ids);

// returns tokenized chat history
std::vector<int64_t> get_tokenized_history() const;

Expand Down
5 changes: 5 additions & 0 deletions src/cpp/src/visual_language/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,15 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
VLMPerfMetrics perf_metrics;
auto& raw_counters = perf_metrics.raw_metrics;
auto& raw_vlm_counters = perf_metrics.vlm_raw_metrics;
// If stop_token_ids were not provided, take value from default m_generation_config
if (generation_config.stop_token_ids.empty())
generation_config.stop_token_ids = m_generation_config.stop_token_ids;
// If eos_token_id was not provided, take value from default m_generation_config
if (generation_config.eos_token_id == -1)
generation_config.set_eos_token_id(m_generation_config.eos_token_id);
generation_config.validate();

m_inputs_embedder->set_stop_token_ids(generation_config.stop_token_ids);

auto start_get_inputs_embeds = std::chrono::steady_clock::now();
ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics);
Expand Down

0 comments on commit 7ecbb82

Please sign in to comment.