[VLM] Set stop token ids from default generation config in VLM pipeline #1612
base: master
Conversation
m_inputs_embedder->set_stop_token_ids(generation_config.stop_token_ids);

auto start_get_inputs_embeds = std::chrono::steady_clock::now();
ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics);
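For reference, a minimal sketch of what the setter used above could look like (an assumed implementation, not quoted from the PR): the pipeline forwards the stop token ids it resolved from the generation config so that the embedder's history handling uses the same set.

// Assumed sketch; the real signature in the PR may differ.
void InputsEmbedder::set_stop_token_ids(const std::set<int64_t>& stop_token_ids) {
    // Store the ids resolved by the VLM pipeline so that the history
    // trimming logic uses the same stop tokens as generation.
    m_stop_token_ids = stop_token_ids;
}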
Another approach is to pass generation_config.stop_token_ids to m_inputs_embedder->get_inputs_embeds() and then to InputsEmbedder::get_encoded_input_ids(), but this will require updating the get_inputs_embeds() methods of all inherited InputsEmbedder classes for the existing VLM models.
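A hedged sketch of what that alternative would imply (hypothetical signature; parameter types are guessed from the call site above), which every derived InputsEmbedder class would have to mirror:

// Hypothetical: stop_token_ids threaded through the call chain instead of a setter.
virtual ov::Tensor get_inputs_embeds(const std::string& prompt,
                                     const std::vector<ov::Tensor>& rgbs,
                                     ov::genai::VLMPerfMetrics& perf_metrics,
                                     const std::set<int64_t>& stop_token_ids) = 0;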
@@ -194,8 +199,8 @@ class InputsEmbedder::IInputsEmbedder {
         // so let's check it out, find the trusted part and use it on the next step
         size_t trusted_history_length = 0;
         if (!m_tokenized_history.empty()) {
-            std::set<int64_t> stop_tokens = {m_tokenizer.get_eos_token_id()};
-            trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens, m_tokenized_history, stop_tokens);
+            OPENVINO_ASSERT(!m_stop_token_ids.empty(), "Stop tokens are not set for InputsEmbedder");
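For context, the call that presumably replaces the removed line and follows the new assert (reconstructed from the hunk and the discussion below, not quoted verbatim):

// m_stop_token_ids now comes from the pipeline's generation config instead
// of the single tokenizer EOS token:
trusted_history_length = ov::genai::utils::get_first_history_difference(
    prev_chat_tokens, m_tokenized_history, m_stop_token_ids);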
Why should an error be raised when stop_token_ids is not set? Generation could run and finish without it.
As far as I can see, we do not have a generation config in InputsEmbedder, but we do in the VLM pipeline (as well as in the other pipelines).
Previously we were relying on a single stop token retrieved from the tokenizer:
std::set<int64_t> stop_tokens = {m_tokenizer.get_eos_token_id()};
Now that several stop tokens can be present in the generation config, we need to ensure that the same stop tokens are used in InputsEmbedder. For this I am using the m_stop_token_ids field. This assert checks that we did not forget to call set_stop_token_ids() for InputsEmbedder in the VLM pipeline.
Does it make sense, or is it still not needed here? @sbalandi @ilya-lavrenov
If I understand correctly, it is possible to set a generation config in which stop_token_ids is empty; in that case the assert will fail incorrectly.
Agree with @sbalandi. According to our tests (openvino.genai/tests/python_tests/test_generation_config.py, lines 25 to 26 at 38ab055):

dict(max_new_tokens=12),
dict(max_length=12),

the absence of any stop token is a valid scenario.
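To put the same point in C++ terms, a minimal sketch (assuming the public ov::genai::GenerationConfig fields) of a config that relies on the token budget alone and would trip the assert above:

ov::genai::GenerationConfig config;
config.max_new_tokens = 12;  // generation finishes on the length limit
config.stop_token_ids = {};  // deliberately empty: a valid configuration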
In the VLM pipeline, the generate() method has a generation_config parameter. If eos_token_id is not set explicitly in this generation config, it is taken from the default m_generation_config and, as an effect, added to stop_token_ids.

However, if a model (e.g. Qwen2-VL or InternVL2) has multiple eos_token_id values in its generation_config.json, the stop_token_ids set is not updated accordingly, as it includes only one eos_token_id.

This issue results in excessive output, e.g. when a model (Qwen2-VL w/o instruct) adds system messages after the real answer until the max new tokens limit is reached.
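A rough sketch of the resolution described above (assumed logic, not the exact pipeline code), showing where the extra eos ids get lost:

// If the user did not set eos_token_id, inherit it from the default config.
if (generation_config.eos_token_id == -1)
    generation_config.eos_token_id = m_generation_config.eos_token_id;
// The single eos id is folded into stop_token_ids...
generation_config.stop_token_ids.insert(generation_config.eos_token_id);
// ...but additional eos ids listed in generation_config.json never reach
// this set, which is why this PR copies the default stop_token_ids instead.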