From 2dab07bca0b45f55ea9e7a6b634bd1088c18c252 Mon Sep 17 00:00:00 2001 From: Oleg Pipikin Date: Tue, 21 Jan 2025 13:58:14 +0100 Subject: [PATCH 1/7] Update samples readme (#1545) Apply the rest comments from https://github.com/openvinotoolkit/openvino.genai/pull/1411 --- samples/cpp/text_generation/README.md | 75 +++++++++++++++--------- samples/python/text_generation/README.md | 73 ++++++++++++++--------- 2 files changed, 91 insertions(+), 57 deletions(-) diff --git a/samples/cpp/text_generation/README.md b/samples/cpp/text_generation/README.md index d9e5bd8d22..f370c74a80 100644 --- a/samples/cpp/text_generation/README.md +++ b/samples/cpp/text_generation/README.md @@ -2,7 +2,7 @@ These samples showcase the use of OpenVINO's inference capabilities for text generation tasks, including different decoding strategies such as beam search, multinomial sampling, and speculative decoding. Each sample has a specific focus and demonstrates a unique aspect of text generation. The applications don't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. -There are also Jupyter notebooks for some samples. You can find links to them in the appropriate sample descritions. +There are also Jupyter notebooks for some samples. You can find links to them in the appropriate sample descriptions. ## Table of Contents 1. [Download and Convert the Model and Tokenizers](#download-and-convert-the-model-and-tokenizers) @@ -11,25 +11,50 @@ There are also Jupyter notebooks for some samples. You can find links to them in 4. [Support and Contribution](#support-and-contribution) ## Download and convert the model and tokenizers - The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version. - -It's not required to install [../../export-requirements.txt](../../export-requirements.txt) for deployment if the model has already been exported. - +Install [../../export-requirements.txt](../../export-requirements.txt) if model conversion is required. ```sh -pip install --upgrade-strategy eager -r ../../requirements.txt +pip install --upgrade-strategy eager -r ../../export-requirements.txt optimim-cli export openvino --model ``` +If a converted model in OpenVINO IR format is already available in the collection of [OpenVINO optimized LLMs](https://huggingface.co/collections/OpenVINO/llm-6687aaa2abca3bbcec71a9bd) on Hugging Face, it can be downloaded directly via huggingface-cli. +```sh +pip install --upgrade-strategy eager -r ../../export-requirements.txt +huggingface-cli download --local-dir +``` ## Sample Descriptions ### Common information Follow [Get Started with Samples](https://docs.openvino.ai/2024/learn-openvino/openvino-samples/get-started-demos.html) to get common information about OpenVINO samples. +Follow [build instruction](https://github.com/openvinotoolkit/openvino.genai/blob/master/src/docs/BUILD.md) to build GenAI samples + +GPUs usually provide better performance compared to CPUs. Modify the source code to change the device for inference to the GPU. -Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is recommended to run larger models on a dGPU with 32GB+ RAM. For example, the model meta-llama/Llama-2-13b-chat-hf can benefit from being run on a dGPU. Modify the source code to change the device for inference to the GPU. +See https://github.com/openvinotoolkit/openvino.genai/blob/master/SUPPORTED_MODELS.md for the list of supported models. 
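For illustration of the "change the device for inference" note above: in practice this amounts to editing the device string passed to `ov::genai::LLMPipeline`. The snippet below is a minimal sketch rather than one of the shipped samples; the model path and prompt are placeholders.

```cpp
// Minimal sketch: the device string ("CPU", "GPU", ...) is the only change needed
// to move inference to another device.
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>

int main(int argc, char* argv[]) {
    std::string models_path = argv[1];                // directory with the exported OpenVINO IR model
    ov::genai::LLMPipeline pipe(models_path, "GPU");  // "CPU" in the original sample code

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    std::string result = pipe.generate("What is OpenVINO?", config);
    std::cout << result << std::endl;
}
```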
-See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models. +Install [../../deployment-requirements.txt](../../deployment-requirements.txt) to run samples +```sh +pip install --upgrade-strategy eager -r ../../deployment-requirements.txt +``` -### 1. Greedy Causal LM (`greedy_causal_lm`) +### 1. Chat Sample (`chat_sample`) +- **Description:** +Interactive chat interface powered by OpenVINO. +Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-chatbot) that provides an example of LLM-powered text generation in Python. +Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat-v1.0, etc +- **Main Feature:** Real-time chat-like text generation. +- **Run Command:** + ```bash + ./chat_sample + ``` +#### Missing chat template +If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model. +The following template can be used as a default, but it may not work properly with every model: +``` +"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}", +``` + +### 2. Greedy Causal LM (`greedy_causal_lm`) - **Description:** Basic text generation using a causal language model. Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-question-answering) that provides an example of LLM-powered text generation in Python. @@ -40,7 +65,7 @@ Recommended models: meta-llama/Llama-2-7b-hf, etc ./greedy_causal_lm "" ``` -### 2. Beam Search Causal LM (`beam_search_causal_lm`) +### 3. Beam Search Causal LM (`beam_search_causal_lm`) - **Description:** Uses beam search for more coherent text generation. Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-question-answering) that provides an example of LLM-powered text generation in Python. @@ -51,23 +76,6 @@ Recommended models: meta-llama/Llama-2-7b-hf, etc ./beam_search_causal_lm "" ["" ...] ``` -### 3. Chat Sample (`chat_sample`) -- **Description:** -Interactive chat interface powered by OpenVINO. -Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-chatbot) that provides an example of LLM-powered text generation in Python. -Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat-v1.0, etc -- **Main Feature:** Real-time chat-like text generation. -- **Run Command:** - ```bash - ./chat_sample - ``` -#### Missing chat template -If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model. 
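As a rough sketch of the chat loop idea behind `chat_sample` (simplified, not the actual sample source; the model directory is a placeholder):

```cpp
// Simplified chat loop: start_chat() keeps the conversation history between turns
// and applies the model's chat template; finish_chat() resets the pipeline state.
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>

int main(int argc, char* argv[]) {
    ov::genai::LLMPipeline pipe(argv[1], "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    pipe.start_chat();
    std::string prompt;
    std::cout << "question:\n";
    while (std::getline(std::cin, prompt)) {
        std::string answer = pipe.generate(prompt, config);
        std::cout << answer << "\n----------\nquestion:\n";
    }
    pipe.finish_chat();
}
```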
-The following template can be used as a default, but it may not work properly with every model: -``` -"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}", -``` - ### 4. Multinomial Causal LM (`multinomial_causal_lm`) - **Description:** Text generation with multinomial sampling for diversity. Recommended models: meta-llama/Llama-2-7b-hf, etc @@ -104,7 +112,16 @@ Recommended models: meta-llama/Llama-2-13b-hf as main model and TinyLlama/TinyLl ./speculative_decoding_lm "" ``` -### 7. Encrypted Model Causal LM (`encrypted_model_causal_lm`) +### 7. LoRA Greedy Causal LM (`lora_greedy_causal_lm`) +- **Description:** +This sample demonstrates greedy decoding using Low-Rank Adaptation (LoRA) fine-tuned causal language models. LoRA enables efficient fine-tuning, reducing resource requirements for adapting large models to specific tasks. +- **Main Feature:** Lightweight fine-tuning with LoRA for efficient text generation +- **Run Command:** + ```bash + ./lora_greedy_causal_lm "" + ``` + +### 8. Encrypted Model Causal LM (`encrypted_model_causal_lm`) - **Description:** LLMPipeline and Tokenizer objects can be initialized directly from the memory buffer, e.g. when user stores only encrypted files and decrypts them on-the-fly. The following code snippet demonstrates how to load the model from the memory buffer: @@ -120,7 +137,7 @@ For the sake of brevity the code above does not include Tokenizer decryption. Fo ./encrypted_model_causal_lm "" ``` -### 8. LLMs benchmarking sample (`benchmark_genai`) +### 9. LLMs benchmarking sample (`benchmark_genai`) - **Description:** This sample script demonstrates how to benchmark an LLMs in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics. diff --git a/samples/python/text_generation/README.md b/samples/python/text_generation/README.md index 9940904cfb..84b5302639 100644 --- a/samples/python/text_generation/README.md +++ b/samples/python/text_generation/README.md @@ -2,7 +2,7 @@ These samples showcase the use of OpenVINO's inference capabilities for text generation tasks, including different decoding strategies such as beam search, multinomial sampling, and speculative decoding. Each sample has a specific focus and demonstrates a unique aspect of text generation. The applications don't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. -There are also Jupyter notebooks for some samples. You can find links to them in the appropriate sample descritions. +There are also Jupyter notebooks for some samples. You can find links to them in the appropriate sample descriptions. ## Table of Contents 1. [Download and Convert the Model and Tokenizers](#download-and-convert-the-model-and-tokenizers) @@ -11,25 +11,50 @@ There are also Jupyter notebooks for some samples. You can find links to them in 4. [Support and Contribution](#support-and-contribution) ## Download and convert the model and tokenizers - The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version. - -It's not required to install [../../export-requirements.txt](../../export-requirements.txt) for deployment if the model has already been exported. 
- +Install [../../export-requirements.txt](../../export-requirements.txt) if model conversion is required. ```sh -pip install --upgrade-strategy eager -r ../../requirements.txt +pip install --upgrade-strategy eager -r ../../export-requirements.txt optimim-cli export openvino --model ``` +If a converted model in OpenVINO IR format is already available in the collection of [OpenVINO optimized LLMs](https://huggingface.co/collections/OpenVINO/llm-6687aaa2abca3bbcec71a9bd) on Hugging Face, it can be downloaded directly via huggingface-cli. +```sh +pip install --upgrade-strategy eager -r ../../export-requirements.txt +huggingface-cli download --local-dir +``` ## Sample Descriptions ### Common information Follow [Get Started with Samples](https://docs.openvino.ai/2024/learn-openvino/openvino-samples/get-started-demos.html) to get common information about OpenVINO samples. +Follow [build instruction](https://github.com/openvinotoolkit/openvino.genai/blob/master/src/docs/BUILD.md) to build GenAI samples + +GPUs usually provide better performance compared to CPUs. Modify the source code to change the device for inference to the GPU. -Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is recommended to run larger models on a dGPU with 32GB+ RAM. For example, the model meta-llama/Llama-2-13b-chat-hf can benefit from being run on a dGPU. Modify the source code to change the device for inference to the GPU. +See https://github.com/openvinotoolkit/openvino.genai/blob/master/SUPPORTED_MODELS.md for the list of supported models. -See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models. +Install [../../deployment-requirements.txt](../../deployment-requirements.txt) to run samples +```sh +pip install --upgrade-strategy eager -r ../../deployment-requirements.txt +``` -### 1. Greedy Causal LM (`greedy_causal_lm`) +### 1. Chat Sample (`chat_sample`) +- **Description:** +Interactive chat interface powered by OpenVINO. +Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-chatbot) that provides an example of LLM-powered text generation in Python. +Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat-v1.0, etc +- **Main Feature:** Real-time chat-like text generation. +- **Run Command:** + ```bash + python chat_sample.py model_dir + ``` +#### Missing chat template +If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model. +The following template can be used as a default, but it may not work properly with every model: +``` +"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}", +``` + +### 2. Greedy Causal LM (`greedy_causal_lm`) - **Description:** Basic text generation using a causal language model. Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-question-answering) that provides an example of LLM-powered text generation in Python. 
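The samples above and below differ mainly in their generation parameters. A hedged illustration of how the decoding strategies map onto `ov::genai::GenerationConfig` fields follows (shown in C++ for brevity; the same fields exist in the Python API, and the concrete values are arbitrary):

```cpp
// Illustrative only: GenerationConfig fields selecting the decoding strategy.
#include "openvino/genai/llm_pipeline.hpp"

void configure_examples() {
    ov::genai::GenerationConfig greedy;
    greedy.max_new_tokens = 100;        // greedy decoding: do_sample == false, num_beams == 1 (defaults)

    ov::genai::GenerationConfig beam_search;
    beam_search.max_new_tokens = 100;
    beam_search.num_beams = 4;          // keep several hypotheses, return the highest-scoring one

    ov::genai::GenerationConfig multinomial;
    multinomial.max_new_tokens = 100;
    multinomial.do_sample = true;       // sample from the token distribution
    multinomial.temperature = 0.7f;
    multinomial.top_k = 30;
    multinomial.top_p = 0.9f;
}
```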
@@ -40,7 +65,7 @@ Recommended models: meta-llama/Llama-2-7b-hf, etc python greedy_causal_lm.py [-h] model_dir prompt ``` -### 2. Beam Search Causal LM (`beam_search_causal_lm`) +### 3. Beam Search Causal LM (`beam_search_causal_lm`) - **Description:** Uses beam search for more coherent text generation. Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-question-answering) that provides an example of LLM-powered text generation in Python. @@ -51,23 +76,6 @@ Recommended models: meta-llama/Llama-2-7b-hf, etc python beam_search_causal_lm.py model_dir prompt [prompts ...] ``` -### 3. Chat Sample (`chat_sample`) -- **Description:** -Interactive chat interface powered by OpenVINO. -Here is a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-chatbot) that provides an example of LLM-powered text generation in Python. -Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat-v1.0, etc -- **Main Feature:** Real-time chat-like text generation. -- **Run Command:** - ```bash - python chat_sample.py model_dir - ``` -#### Missing chat template -If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model. -The following template can be used as a default, but it may not work properly with every model: -``` -"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}", -``` - ### 4. Multinomial Causal LM (`multinomial_causal_lm`) - **Description:** Text generation with multinomial sampling for diversity. Recommended models: meta-llama/Llama-2-7b-hf, etc @@ -104,7 +112,16 @@ Recommended models: meta-llama/Llama-2-13b-hf as main model and TinyLlama/TinyLl python speculative_decoding_lm.py model_dir draft_model_dir prompt ``` -### 7. LLMs benchmarking sample (`benchmark_genai`) +### 7. LoRA Greedy Causal LM (`lora_greedy_causal_lm`) +- **Description:** +This sample demonstrates greedy decoding using Low-Rank Adaptation (LoRA) fine-tuned causal language models. LoRA enables efficient fine-tuning, reducing resource requirements for adapting large models to specific tasks. +- **Main Feature:** Lightweight fine-tuning with LoRA for efficient text generation +- **Run Command:** + ```bash + ./lora_greedy_causal_lm "" + ``` + +### 8. LLMs benchmarking sample (`benchmark_genai`) - **Description:** This sample script demonstrates how to benchmark an LLMs in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics. 
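The benchmarking sample above reads the `PerfMetrics` object that `generate()` attaches to its results, and the next patch extends this reporting to the speculative decoding and prompt lookup pipelines. A simplified sketch of consuming those metrics from user code (no warm-up loop, placeholder model path):

```cpp
// Sketch of reading performance metrics from a generation result, similar in spirit
// to what the benchmark_genai sample reports.
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>

int main(int argc, char* argv[]) {
    ov::genai::LLMPipeline pipe(argv[1], "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;

    ov::genai::DecodedResults result = pipe.generate("The Sun is yellow because", config);
    ov::genai::PerfMetrics metrics = result.perf_metrics;

    std::cout << "Generate time: " << metrics.get_generate_duration().mean << " ms" << std::endl;
    std::cout << "TTFT: "          << metrics.get_ttft().mean              << " ms" << std::endl;
    std::cout << "TPOT: "          << metrics.get_tpot().mean              << " ms/token" << std::endl;
    std::cout << "Throughput: "    << metrics.get_throughput().mean        << " tokens/s" << std::endl;
}
```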
From a9048d3fef5b98a69826f31f25398f90a20dade0 Mon Sep 17 00:00:00 2001 From: Irina Efode Date: Tue, 21 Jan 2025 18:16:06 +0400 Subject: [PATCH 2/7] [ Speculative decoding ][ Prompt lookup ] Enable Perf Metrics for assisting pipelines (#1599) Ticket: *[160440](https://jira.devtools.intel.com/browse/CVS-160440) --- src/cpp/src/continuous_batching_impl.cpp | 8 +---- src/cpp/src/continuous_batching_impl.hpp | 2 +- .../continuous_batching_for_prompt_lookup.cpp | 4 +++ .../continuous_batching_for_prompt_lookup.hpp | 2 ++ .../src/prompt_lookup/prompt_lookup_impl.cpp | 29 ++++++++++++++++- .../src/prompt_lookup/prompt_lookup_impl.hpp | 1 + src/cpp/src/sampler.cpp | 2 ++ src/cpp/src/sampler.hpp | 2 ++ ...batching_for_speculative_decoding_impl.cpp | 4 +++ ...batching_for_speculative_decoding_impl.hpp | 2 ++ .../speculative_decoding_impl.cpp | 31 +++++++++++++++++++ .../speculative_decoding_impl.hpp | 3 ++ src/cpp/src/timer.hpp | 22 ++++++++++--- 13 files changed, 98 insertions(+), 14 deletions(-) diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index 4877860442..4c035fbd7b 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -156,13 +156,6 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { m_pipeline_metrics.max_cache_usage = std::max(m_pipeline_metrics.max_cache_usage, scheduler_output.m_cache_usage); _register_step_cache_usage(scheduler_output.m_cache_usage); m_pipeline_metrics.avg_cache_usage = _get_current_running_average_cache_usage(); - - m_batch_size = 0; // total number of running sequences - for (size_t i = 0; i < scheduler_output.m_scheduled_sequence_groups_ids.size(); ++i) { - size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; - SequenceGroup::CPtr sequence_group = m_requests[seq_group_id]; - m_batch_size += sequence_group->num_running_seqs(); - } } // if no tokens were scheduled, we are out of memory => free all requests and return @@ -210,6 +203,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { static ManualTimer timer("sample"); timer.start(); sampler_output = m_sampler->sample(m_requests, logits, m_is_validation_mode_enabled); + m_batch_size = sampler_output.num_generated_tokens; timer.end(); } diff --git a/src/cpp/src/continuous_batching_impl.hpp b/src/cpp/src/continuous_batching_impl.hpp index 8980038f73..7e2480e5b0 100644 --- a/src/cpp/src/continuous_batching_impl.hpp +++ b/src/cpp/src/continuous_batching_impl.hpp @@ -31,7 +31,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc // for perf metrics float m_load_time_ms = 0.0f; - size_t m_batch_size = 0; // stored number of scheduled sequences on last step + size_t m_batch_size = 0; // stored number of processed tokens on last step // flag to enable validation mode for sampler bool m_is_validation_mode_enabled = false; diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp index d01c863549..aa4ea8a53a 100644 --- a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp +++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp @@ -91,4 +91,8 @@ std::vector ContinuousBatchingPipeline::ContinuousBatchingFo return m_awaiting_requests; } +size_t ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::get_processed_tokens_per_iteration() { + return m_batch_size; +} + } \ No newline at end of file diff --git 
a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.hpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.hpp index fc4942701e..98b2d71586 100644 --- a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.hpp +++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.hpp @@ -37,6 +37,8 @@ class ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl : public bool is_requests_empty(); std::vector get_awaiting_requests(); + size_t get_processed_tokens_per_iteration(); + using ContinuousBatchingPipeline::ContinuousBatchingImpl::drop_requests; protected: TokenIds generate_candidates(const TokenIds& input_ids, size_t num_pred_tokens, size_t max_ngram_size); diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp index 41c3e6370f..9eb54f700c 100644 --- a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp +++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp @@ -29,6 +29,11 @@ bool ContinuousBatchingPipeline::PromptLookupImpl::has_non_finished_requests() { } void ContinuousBatchingPipeline::PromptLookupImpl::step() { + auto& raw_perf_counters = m_perf_metrics.raw_metrics; + + ManualTimer step_timer("prompt_lookup_decoding: step()"); + step_timer.start(); + ManualTimer candidates_timer("prompt_lookup_decoding: generate_candidates()"); candidates_timer.start(); m_pipeline->generate_candidates(); @@ -36,7 +41,7 @@ void ContinuousBatchingPipeline::PromptLookupImpl::step() { m_sd_metrics.draft_duration += candidates_timer.get_duration(); auto generated_len_before = m_pipeline->get_generated_request_len(); - ManualTimer main_timer("prompt_lookup_decoding: step()"); + ManualTimer main_timer("prompt_lookup_decoding: pipeline: step()"); main_timer.start(); m_pipeline->step(); main_timer.end(); @@ -63,6 +68,18 @@ void ContinuousBatchingPipeline::PromptLookupImpl::step() { m_sd_metrics.update_draft_accepted_tokens(request_id, num_matches); } + // update perf metrics + const auto num_generated_tokens = m_pipeline->get_processed_tokens_per_iteration(); + if (num_generated_tokens > 0) { + raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); + + auto infer_duration = step_timer.get_duration_microsec(); + + raw_perf_counters.m_token_infer_durations.emplace_back(infer_duration); + raw_perf_counters.m_inference_durations[0] += MicroSeconds(infer_duration); + raw_perf_counters.m_new_token_times.emplace_back(main_timer.get_end_time()); + } + if (generated_len_after.empty() && 0) { m_sd_metrics.print(true); m_sd_metrics.clean_up(); @@ -73,6 +90,9 @@ std::vector ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector& input_ids, const std::vector& sampling_params, const StreamerVariant& streamer) { + m_perf_metrics = PerfMetrics(); + m_perf_metrics.raw_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request"); OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); @@ -173,6 +193,13 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vectorget_status(); + + // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
+ m_perf_metrics.raw_metrics.generate_durations.clear(); + m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + m_perf_metrics.num_input_tokens = request->get_prompt_len(); + m_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + results.push_back(std::move(result)); } diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.hpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.hpp index 1499bcc76e..0535931d81 100644 --- a/src/cpp/src/prompt_lookup/prompt_lookup_impl.hpp +++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.hpp @@ -15,6 +15,7 @@ class ContinuousBatchingPipeline::PromptLookupImpl : public ContinuousBatchingPi protected: std::shared_ptr m_pipeline; SpeculativeDecodingMetrics m_sd_metrics; + PerfMetrics m_perf_metrics; void drop_requests(); diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index 827309724e..4c399ab641 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -250,6 +250,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, for (Group& group : m_groups) { if (!group.done) { for (Beam& beam : group.ongoing) { + sampler_output.num_generated_tokens++; uint64_t parent_seq_id = beam.m_sequence->get_id(); // here we need to map index of sequence in beam search group(s) and sequence group @@ -793,6 +794,7 @@ SamplerOutput Sampler::sample(const std::vector & sequence_g bool is_validation_passed = true; // make `num_tokens_to_process` iteration to validate a candidate generated by `draft_model` + 1 iteration to generate one more token by `main_model` for (size_t i = 0; i <= num_tokens_to_process; ++i) { + sampler_output.num_generated_tokens++; // calculate token offset from the end of logit size_t token_offset = num_tokens_to_process - i; // max counter of needed to be sampled tokens diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index ca8937cb60..73c656a41d 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -38,6 +38,8 @@ struct SamplerOutput { // IDs of sequences that need to be forked (note, the same sequence can be forked multiple times) // it will later be used by scheduler to fork block_tables for child sequences std::unordered_map> m_forked_sequences; + // store number of generated_tokens + size_t num_generated_tokens = 0; }; class Sampler { diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index dccc633d4d..bec2b75e0d 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -292,6 +292,10 @@ std::vector ContinuousBatchingPipeline::ContinuousBatchingFo return m_awaiting_requests; } +size_t ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::get_processed_tokens_per_iteration() { + return m_batch_size; +} + void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::pull_awaiting_requests(bool is_pause_request) { std::lock_guard lock{m_awaiting_requests_mutex}; diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index 3777d9b87b..e4e4be63d8 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ 
b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -32,6 +32,8 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl : bool is_requests_empty(); std::vector get_awaiting_requests(); + size_t get_processed_tokens_per_iteration(); + UpdateRequestResult init_request_by_candidate(uint64_t request_id, const GeneratedSequences& candidates); protected: diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 5483523698..7a6066fc5c 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -131,6 +131,12 @@ void print_generated_request(const ov::genai::GeneratedRequests& requests) { void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() { // this blocks adding new requests during step as it may break coherence between main and draft models std::lock_guard lock{m_draft_generations_mutex}; + + auto& raw_perf_counters = m_perf_metrics.raw_metrics; + + ManualTimer step_timer("speculative_decoding: step()"); + step_timer.start(); + m_draft_pipeline->pull_awaiting_requests(true); m_main_pipeline->pull_awaiting_requests(); @@ -182,6 +188,18 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() { m_sd_metrics.update_draft_accepted_tokens(request_id, (updated_seq_info.inserted_tokens_cnt - updated_seq_info.removed_tokens_cnt)); } + // update perf metrics + const auto num_generated_tokens = m_main_pipeline->get_processed_tokens_per_iteration(); + if (num_generated_tokens > 0) { + auto infer_duration = step_timer.get_duration_microsec(); + + raw_perf_counters.m_token_infer_durations.emplace_back(infer_duration); + raw_perf_counters.m_inference_durations[0] += MicroSeconds(infer_duration); + raw_perf_counters.m_new_token_times.emplace_back(main_timer.get_end_time()); + + raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); + } + if (main_generated_requests.empty() && 0) { m_sd_metrics.print(true); m_sd_metrics.clean_up(); @@ -192,6 +210,9 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector& input_ids, const std::vector& sampling_params, const StreamerVariant& streamer) { + m_perf_metrics = PerfMetrics(); + m_perf_metrics.raw_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request"); OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); @@ -273,6 +294,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< std::vector results; results.reserve(all_requests.size()); + generate_timer.end(); + for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { const auto& request = all_requests[request_id]; auto sampling_params = request->get_sampling_parameters(); @@ -297,6 +320,14 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< } result.m_status = main_generations[request_id]->get_status(); + + // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
+ m_perf_metrics.raw_metrics.generate_durations.clear(); + m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + m_perf_metrics.num_input_tokens = request->get_prompt_len(); + m_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + + result.perf_metrics = m_perf_metrics; results.push_back(std::move(result)); } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 7475d9d766..4023519287 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -37,7 +37,10 @@ struct ModelDesc { class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline { protected: std::shared_ptr m_main_pipeline, m_draft_pipeline; + // Metrics SpeculativeDecodingMetrics m_sd_metrics; + PerfMetrics m_perf_metrics; + // Mutex protecting access to m_draft_generations, so add_request and step methods can be called from different threads std::mutex m_draft_generations_mutex; std::map m_draft_generations; diff --git a/src/cpp/src/timer.hpp b/src/cpp/src/timer.hpp index f389e10d5d..588fb4967d 100644 --- a/src/cpp/src/timer.hpp +++ b/src/cpp/src/timer.hpp @@ -9,7 +9,7 @@ class ManualTimer { double m_total; - decltype(std::chrono::steady_clock::now()) m_start; + std::chrono::steady_clock::time_point m_start, m_end; std::string m_title; public: ManualTimer(const std::string& title) : @@ -22,15 +22,27 @@ class ManualTimer { } void end() { - auto m_end = std::chrono::steady_clock::now(); - m_total += std::chrono::duration(m_end - m_start).count(); + m_end = std::chrono::steady_clock::now(); + m_total += std::chrono::duration_cast(m_end - m_start).count(); + } + + std::chrono::steady_clock::time_point get_start_time() { + return m_start; + } + + std::chrono::steady_clock::time_point get_end_time() { + return m_end; } float get_duration() const { - return m_total / 1000.; + return m_total / 1e6; + } + + float get_duration_microsec() const { + return m_total; } ~ManualTimer() { - // std::cout << m_title << ": " << m_total / 1000. << " secs" << std::endl; + // std::cout << m_title << ": " << m_total / 1e6. 
<< " secs" << std::endl; } }; From bb62b71fee7124166dd2168d0dd92e0bde868fda Mon Sep 17 00:00:00 2001 From: Alexey Smirnov Date: Tue, 21 Jan 2025 15:33:12 +0000 Subject: [PATCH 3/7] [LLM] [NPU] StaticLLMPipeline: Export blob (#1601) Release: https://github.com/openvinotoolkit/openvino.genai/pull/1603 --- src/cpp/src/llm_pipeline_static.cpp | 32 +++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 971378bc42..db25572d33 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -698,12 +698,13 @@ StatefulLLMPipeline::StatefulLLMPipeline( utils::from_config_json_if_exists(models_path)), m_sampler(m_tokenizer) { ov::AnyMap properties = config; - const auto use_blob = pop_or_default(properties, "USE_BLOB", false); - if (use_blob) { - auto blob_path = pop_or_default(properties, "BLOB_PATH", std::string{}); - if (blob_path.empty()) { - blob_path = (models_path / "openvino_model.blob").string(); - } + + auto blob_path = pop_or_default(properties, "BLOB_PATH", std::string{}); + const auto export_blob = pop_or_default(properties, "EXPORT_BLOB", false); + + bool do_import = (!blob_path.empty() && !export_blob); + + if (do_import) { if (!std::filesystem::exists(blob_path)) { OPENVINO_THROW("Blob file is not found at: " + blob_path); } @@ -721,6 +722,25 @@ StatefulLLMPipeline::StatefulLLMPipeline( ModelConfigDesc model_desc = get_modeldesc_from_json(models_path / "config.json"); ov::AnyMap properties = config; auto compiled = setupAndCompileModel(model, model_desc, properties); + // Also export compiled model if required + if (export_blob) { + if (blob_path.empty()) { + blob_path = (models_path / "openvino_model.blob").string(); + } + // Check the path is full + const int EXT_SIZE = 5; // ".blob" + if (blob_path.size() < EXT_SIZE) { + OPENVINO_THROW("Please provide a full path to blob file in BLOB_PATH: " + blob_path); + } + if (strncmp(".blob", &blob_path[blob_path.size() - EXT_SIZE], EXT_SIZE) != 0) { + OPENVINO_THROW("Please provide a full path to blob file in BLOB_PATH: " + blob_path); + } + std::ofstream fout(blob_path, std::ios::out | std::ios::binary); + if (!fout.is_open()) { + OPENVINO_THROW("Blob file can't be exported to: " + blob_path); + } + compiled->export_model(fout); + } m_request = compiled->create_infer_request(); m_sampler.set_seed(m_generation_config.rng_seed); } From 2da00a047e54fd6b3e36cad9c562a0a005b796b8 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Tue, 21 Jan 2025 21:18:41 +0400 Subject: [PATCH 4/7] LLM: use set_output_seq_len instead of WA (#1611) Such method `set_output_seq_len` of `SequenceGroup` was introduced here https://github.com/openvinotoolkit/openvino.genai/pull/1261 --- src/cpp/src/llm_pipeline_static.cpp | 8 ++++---- src/cpp/src/lm_encoding.cpp | 7 ++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index db25572d33..b17ee959c5 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -940,8 +940,8 @@ EncodedResults StatefulLLMPipeline::generate( auto sequence_group = std::make_shared( 0 /* request_id */, input_ids, config, 1 /* block_size */); - sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - output_sequence_len); - sequence_group->schedule_tokens(output_sequence_len); + sequence_group->schedule_tokens(sequence_group->get_prompt_len()); + 
sequence_group->set_output_seq_len(output_sequence_len); // NB: Controls what tokens are ready to be pushed into the streamer GenerationHandle handle = std::make_shared( @@ -1412,8 +1412,8 @@ EncodedResults StatelessLLMPipeline::generate( // Retrive only useful logits and work only with them here. auto sequence_group = std::make_shared( 0 /* request_id */, padded_input_ids, config, 1 /* block_size */); - sequence_group->update_processed_tokens_num(m_kvcache_desc.max_prompt_size - output_sequence_len); - sequence_group->schedule_tokens(output_sequence_len); + sequence_group->schedule_tokens(m_kvcache_desc.max_prompt_size); + sequence_group->set_output_seq_len(output_sequence_len); // NB: Controls what tokens are ready to be pushed into the streamer GenerationHandle handle = std::make_shared( diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index f93542c70f..e2ec3a1b33 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -138,13 +138,10 @@ std::pair> get_lm_encoded_results( auto logits = m_llm.get_tensor("logits"); - // since we have applied `Slice` operation to last MatMul, model output sequence lenght is 1 - // so, we need to update sequence groups to think that they already have processed all prompt tokens except last ones - // and schedule only `output_sequence_len` ones int64_t output_sequence_len = logits.get_shape().at(1); for (auto& sequence_group : sequence_groups) { - sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - output_sequence_len); - sequence_group->schedule_tokens(output_sequence_len); + sequence_group->schedule_tokens(sequence_group->get_prompt_len()); + sequence_group->set_output_seq_len(output_sequence_len); } std::map beam_offets; From e0488c87b6243cb411c150e488d506f5219fa80c Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Tue, 21 Jan 2025 21:32:06 +0400 Subject: [PATCH 5/7] [llm_bench] enable prompt permutations for prevent prefix caching and fix vlm image load (#1607) CVS-160892 --- tools/llm_bench/benchmark.py | 1 + .../llm_bench/llm_bench_utils/model_utils.py | 7 +++-- tools/llm_bench/task/text_generation.py | 30 +++++++++++++++++-- .../task/visual_language_generation.py | 7 +++-- 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py index 6fc135c4ef..5d4a1436a7 100644 --- a/tools/llm_bench/benchmark.py +++ b/tools/llm_bench/benchmark.py @@ -161,6 +161,7 @@ def get_argprser(): parser.add_argument("--num_steps", type=int, required=False, help="Number of inference steps for image generation") parser.add_argument("--height", type=int, required=False, help="Generated image height. Applicable only for Image Generation.") parser.add_argument("--width", type=int, required=False, help="Generated image width. 
Applicable only for Image Generation.") + parser.add_argument("--disable_prompt_permutation", action="store_true", help="Disable modification prompt from run to run for avoid prefix caching") return parser.parse_args() diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 324a67bc2a..51d77d3215 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -37,12 +37,12 @@ def get_param_from_file(args, input_key): if args["use_case"] != "vlm": raise RuntimeError("Multiple sources for benchmarking supported only for Visual Language Models") data_dict = {} - if args["media"] is None: + if args["media"] is None and args["images"] is None: log.warn("Input image is not provided. Only text generation part will be evaluated") else: - data_dict["media"] = args["media"] + data_dict["media"] = args["media"] if args["media"] is not None else args["images"] if args["prompt"] is None: - data_dict["prompt"] = "What is OpenVINO?" if args["media"] is None else "Describe image" + data_dict["prompt"] = "What is OpenVINO?" if data_dict["media"] is None else "Describe image" else: data_dict["prompt"] = args["prompt"] data_list.append(data_dict) @@ -113,6 +113,7 @@ def analyze_args(args): model_args['torch_compile_options'] = args.torch_compile_options model_args['torch_compile_input_module'] = args.torch_compile_input_module model_args['media'] = args.media + model_args["disable_prompt_permutation"] = args.disable_prompt_permutation optimum = args.optimum diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index de798f158f..372a034148 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -207,6 +207,19 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data tokenization_end = time.perf_counter() tokenization_time = [(tokenization_end - tokenization_start) * 1000] + enable_prompt_permutations = not args.get("disable_prompt_permutation", False) + if enable_prompt_permutations: + log.warning( + "Enabled input prompt permutations. It means that generation results can be vary on different steps. " + "If it does not expected please specify --disable_prompr_permutation in your benchmarking command to disable this behavior" + ) + from openvino_genai import TokenizedInputs + import openvino as ov + + input_ids = input_data.input_ids.data + input_ids[:, 0] = num + 1 + attention_mask = input_data.attention_mask + input_data = TokenizedInputs(input_ids=ov.Tensor(input_ids), attention_mask=attention_mask) num_input_tokens = input_data.input_ids.shape[1] if args['batch_size'] > 1: out_str = '[warm-up]' if num == 0 else '[{}]'.format(num) @@ -325,7 +338,7 @@ def token_printer(): batch_size=args['batch_size'], prompt_idx=prompt_index ) - if num > 0: + if num > 0 and not enable_prompt_permutations: prev_md5 = md5_list[num - 1][prompt_index] if result_md5_list != prev_md5: log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} " @@ -366,6 +379,19 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg gen_config.max_new_tokens = max_gen_tokens gen_config.num_beams = args["num_beams"] gen_config.do_sample = False + enable_prompt_permutations = not args.get("disable_prompt_permutation", False) + if enable_prompt_permutations: + log.warning( + "Enabled input prompt permutations. It means that generation results can be vary on different steps. 
" + "If it does not expected please specify --disable_prompr_permutation in your benchmarking command to disable this behavior" + ) + from openvino_genai import TokenizedInputs + import openvino as ov + + input_ids = input_data.input_ids.data + input_ids[:, 0] = num + 1 + attention_mask = input_data.attention_mask + input_data = TokenizedInputs(input_ids=ov.Tensor(input_ids), attention_mask=attention_mask) if args.get('draft_model', ''): config_info = "Speculative decoding config: " if args.get("num_assistant_tokens", None): @@ -439,7 +465,7 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg batch_size=args['batch_size'], prompt_idx=prompt_index ) - if num > 0: + if num > 0 and not enable_prompt_permutations: prev_md5 = md5_list[num - 1][prompt_index] if result_md5_list != prev_md5: log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} " diff --git a/tools/llm_bench/task/visual_language_generation.py b/tools/llm_bench/task/visual_language_generation.py index a5fb0ecc0c..a02b16b2bb 100644 --- a/tools/llm_bench/task/visual_language_generation.py +++ b/tools/llm_bench/task/visual_language_generation.py @@ -44,7 +44,7 @@ def run_visual_language_generation_optimum( for bs_index, in_text in enumerate(prompts): llm_bench_utils.output_file.output_input_text(in_text, args, model_precision, prompt_index, bs_index, proc_id) tok_encode_start = time.perf_counter() - input_data = model.preprocess_inputs(text=prompts[0], image=images[0], **processor) + input_data = model.preprocess_inputs(text=prompts[0], image=images[0] if images else None, **processor) tok_encode_end = time.perf_counter() tok_encode_time = (tok_encode_end - tok_encode_start) * 1000 # Remove `token_type_ids` from inputs @@ -211,8 +211,11 @@ def run_visual_language_generation_genai( gen_config.max_new_tokens = max_gen_tokens gen_config.num_beams = args["num_beams"] gen_config.do_sample = False + kwargs = {} + if len(images) >= 1: + kwargs["images"] = images[0] start = time.perf_counter() - generation_result = model.generate(prompts[0], images=images[0], generation_config=gen_config) + generation_result = model.generate(prompts[0], generation_config=gen_config, **kwargs) end = time.perf_counter() generated_text = generation_result.texts perf_metrics = generation_result.perf_metrics From aa552d1330dfdc073a4b14e6f8d3467d2ecacbc8 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Wed, 22 Jan 2025 09:00:16 +0400 Subject: [PATCH 6/7] CB: support different number of K and V heads per layer (#1610) CVS-160810 --- src/cpp/src/device_config.hpp | 76 ++++++++++--------- .../utils/paged_attention_transformations.cpp | 40 +++++----- tests/cpp/cache_manager.cpp | 25 +++--- tests/cpp/device_config.cpp | 10 ++- tests/cpp/scheduler.cpp | 4 +- 5 files changed, 81 insertions(+), 74 deletions(-) diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index cbf3fe70c5..52789ff013 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -10,22 +10,27 @@ #include "openvino/genai/scheduler_config.hpp" namespace ov::genai { + +/** + * Per layer KV cache size configuration + */ +struct KVHeadConfig { + size_t num_v_heads, num_k_heads; + size_t v_head_size, k_head_size; +}; + class DeviceConfig { ov::element::Type m_kv_cache_type; std::vector m_key_cache_shape, m_value_cache_shape; - std::vector m_num_kv_heads; - ov::Shape::value_type m_head_size, m_num_decoder_layers; - size_t m_num_kv_blocks = 0; - size_t m_block_size = 0; - size_t m_cache_size = 0; + std::vector 
m_kv_heads_config; + size_t m_num_decoder_layers = 0; + size_t m_num_kv_blocks = 0, m_cache_size = 0; // KV cache sizes in either blocks or GBs + size_t m_block_size = 0; // block size is per inference device std::string m_device; size_t get_block_size_by_device(const std::string& device) const { - const size_t cpu_block_size = 32; - const size_t gpu_block_size = 16; - - bool is_gpu = device.find("GPU") != std::string::npos; - + const size_t cpu_block_size = 32, gpu_block_size = 16; + const bool is_gpu = device.find("GPU") != std::string::npos; return is_gpu ? gpu_block_size : cpu_block_size; } @@ -83,17 +88,14 @@ class DeviceConfig { if (scheduling_config.num_kv_blocks > 0) { m_num_kv_blocks = scheduling_config.num_kv_blocks; - } - else if (scheduling_config.cache_size > 0) { + } else if (scheduling_config.cache_size > 0) { m_cache_size = scheduling_config.cache_size; } } - void set_model_params(std::vector num_kv_heads, size_t head_size, size_t num_decoder_layers) { - m_head_size = head_size; - m_num_decoder_layers = num_decoder_layers; - - m_num_kv_heads.assign(num_kv_heads.begin(), num_kv_heads.end()); + void set_kv_head_configs(std::vector kv_heads_config) { + m_kv_heads_config = kv_heads_config; + m_num_decoder_layers = m_kv_heads_config.size(); m_key_cache_shape.reserve(m_num_decoder_layers); m_value_cache_shape.reserve(m_num_decoder_layers); @@ -103,35 +105,37 @@ class DeviceConfig { // |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| // so, we have to extend head_size by 8, which is sizeof(float) // for scale and sizeof(float) for zeropoint - if (m_kv_cache_type == ov::element::u8) - m_head_size += 8; + if (m_kv_cache_type == ov::element::u8) { + for (size_t layer_id = 0; layer_id < m_num_decoder_layers; ++layer_id) { + m_kv_heads_config[layer_id].k_head_size += 8; + m_kv_heads_config[layer_id].v_head_size += 8; + } + } } if (m_num_kv_blocks == 0 && m_cache_size > 0) { - size_t block_size = 0; - size_t size_in_bytes = m_cache_size * 1024 * 1024 * 1024; - for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) { - block_size += 2 * m_num_kv_heads[layer_id] * m_block_size * m_head_size * m_kv_cache_type.size(); - } - m_num_kv_blocks = size_in_bytes / block_size; + size_t size_in_bytes = m_cache_size * 1024 * 1024 * 1024; // convert GBs to bytes + m_num_kv_blocks = size_in_bytes / get_block_size_in_bytes(); } for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) { + const KVHeadConfig& config = m_kv_heads_config[layer_id]; + m_value_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(), - ov::Dimension(m_num_kv_heads[layer_id]), + ov::Dimension(config.num_v_heads), ov::Dimension(m_block_size), - ov::Dimension(m_head_size)}); + ov::Dimension(config.v_head_size)}); if (m_device.find("GPU") == std::string::npos) { m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(), - ov::Dimension(m_num_kv_heads[layer_id]), + ov::Dimension(config.num_k_heads), ov::Dimension(m_block_size), - ov::Dimension(m_head_size)}); - } else if (m_device.find("GPU") != std::string::npos) { + ov::Dimension(config.k_head_size)}); + } else if (m_device.find("GPU") != std::string::npos) { // Update key shape, as the key's shape is different from the value's shape m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(), - ov::Dimension(m_num_kv_heads[layer_id]), - ov::Dimension(m_head_size), + ov::Dimension(config.num_k_heads), + ov::Dimension(config.k_head_size), 
ov::Dimension(m_block_size)}); } } @@ -168,11 +172,13 @@ class DeviceConfig { } size_t get_block_size_in_bytes() const { - size_t block_size = 0; + size_t block_size_in_bytes = 0; for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) { - block_size += 2 * m_num_kv_heads[layer_id] * m_block_size * m_head_size * get_cache_precision().size(); + const KVHeadConfig& config = m_kv_heads_config[layer_id]; + block_size_in_bytes += config.k_head_size * config.num_k_heads + config.v_head_size * config.num_v_heads; } - return block_size; + block_size_in_bytes *= get_block_size() * get_cache_precision().size(); + return block_size_in_bytes; } }; } diff --git a/src/cpp/src/utils/paged_attention_transformations.cpp b/src/cpp/src/utils/paged_attention_transformations.cpp index baef7d8dd6..0d62bb10e9 100644 --- a/src/cpp/src/utils/paged_attention_transformations.cpp +++ b/src/cpp/src/utils/paged_attention_transformations.cpp @@ -31,37 +31,35 @@ void apply_paged_attention_transformations(std::shared_ptr model, boo } void set_kv_cache_type_and_shape(std::shared_ptr model, DeviceConfig& device_config) { - const ov::ParameterVector& parameters = model->get_parameters(); - std::map> key_cache_params, value_cache_params; - for (const auto& param_ptr : parameters) { + for (const auto& param_ptr : model->get_parameters()) { const auto& name = param_ptr->get_friendly_name(); if (name.find("key_cache.") == 0) { key_cache_params[name] = param_ptr; - } - else if (name.find("value_cache.") == 0) { + } else if (name.find("value_cache.") == 0) { value_cache_params[name] = param_ptr; } } - OPENVINO_ASSERT(key_cache_params.size() > 0); - OPENVINO_ASSERT(key_cache_params.size() == value_cache_params.size()); + OPENVINO_ASSERT(key_cache_params.size() == value_cache_params.size() && key_cache_params.size() > 0); - size_t num_layers = key_cache_params.size(); - // extract num_kv_heads and head_size - std::string key_cache_param_name = "key_cache.0"; - OPENVINO_ASSERT(key_cache_params.count(key_cache_param_name) != 0, "key_cache.0 tensor not found among model parameters"); - ov::PartialShape k_shape = key_cache_params[key_cache_param_name]->get_partial_shape(); - OPENVINO_ASSERT(k_shape.rank().get_length() == 3, "KV cache shape is expected to have rank 3, while shape is ", k_shape); - size_t head_size = k_shape[2].get_length(); - std::vector num_kv_heads(num_layers); - for (size_t idx = 0; idx < num_layers; idx++) { - size_t num_heads = key_cache_params[std::string("key_cache.") + std::to_string(idx)]->get_partial_shape()[1].get_length(); - num_kv_heads[idx] = num_heads; + size_t num_decoder_layers = key_cache_params.size(); + std::vector kv_heads_config(num_decoder_layers); + + for (size_t idx = 0; idx < num_decoder_layers; idx++) { + KVHeadConfig& config = kv_heads_config[idx]; + + auto key_shape = key_cache_params[std::string("key_cache.") + std::to_string(idx)]->get_partial_shape(); + config.num_k_heads = key_shape[1].get_length(); + config.k_head_size = key_shape[2].get_length(); + + auto value_shape = value_cache_params[std::string("value_cache.") + std::to_string(idx)]->get_partial_shape(); + config.num_v_heads = value_shape[1].get_length(); + config.v_head_size = value_shape[2].get_length(); } - device_config.set_model_params(num_kv_heads, head_size, num_layers); + device_config.set_kv_head_configs(kv_heads_config); - for (size_t idx = 0; idx < num_layers; idx++) { + for (size_t idx = 0; idx < num_decoder_layers; idx++) { auto k = key_cache_params[std::string("key_cache.") + std::to_string(idx)]; auto 
v = value_cache_params[std::string("value_cache.") + std::to_string(idx)]; k->set_element_type(device_config.get_cache_precision()); @@ -80,4 +78,4 @@ void apply_paged_attention_transformations(std::shared_ptr model, Dev } // namespace utils } // namespace genai -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/tests/cpp/cache_manager.cpp b/tests/cpp/cache_manager.cpp index 7d855ded12..0c483f0ec1 100644 --- a/tests/cpp/cache_manager.cpp +++ b/tests/cpp/cache_manager.cpp @@ -56,9 +56,9 @@ TEST(TestCacheManager, test_cache_size_param) { const std::string device = "CPU"; ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU"); - size_t num_decoder_layers = 12; - std::vector num_kv_heads(12, 12); - device_config.set_model_params(num_kv_heads, 64, num_decoder_layers); + const size_t num_decoder_layers = 12; + const std::vector kv_heads_config(num_decoder_layers, KVHeadConfig { 12, 12, 64, 64 }); + device_config.set_kv_head_configs(kv_heads_config); ov::InferRequest request = core.compile_model(get_dummy_model(core, num_decoder_layers)).create_infer_request(); auto cache_manager = std::make_shared(device_config, request, core); @@ -79,9 +79,9 @@ TEST(TestCacheManager, test_kv_blocks_param) { const std::string device = "CPU"; ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU"); - size_t num_decoder_layers = 12; - std::vector num_kv_heads(12, 12); - device_config.set_model_params(num_kv_heads, 64, num_decoder_layers); + const size_t num_decoder_layers = 12; + const std::vector kv_heads_config(num_decoder_layers, KVHeadConfig { 12, 12, 64, 64 }); + device_config.set_kv_head_configs(kv_heads_config); ov::InferRequest request = core.compile_model(get_dummy_model(core, num_decoder_layers)).create_infer_request(); auto cache_manager = std::make_shared(device_config, request, core); @@ -100,15 +100,16 @@ TEST(TestCacheManager, test_dynamic_cache_increase) { const std::string device = "CPU"; ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU"); - size_t num_decoder_layers = 12; - size_t head_size = 64; - std::vector num_kv_heads(12, 12); - device_config.set_model_params(num_kv_heads, head_size, num_decoder_layers); + const size_t num_decoder_layers = 12; + const std::vector kv_heads_config(num_decoder_layers, KVHeadConfig { 12, 12, 64, 64 }); + device_config.set_kv_head_configs(kv_heads_config); + size_t block_size_in_bytes = 0; for (size_t layer_id = 0; layer_id < num_decoder_layers; layer_id++) { - block_size_in_bytes += 2 * num_kv_heads[layer_id] * device_config.get_block_size() * head_size * device_config.get_cache_precision().size(); + KVHeadConfig config = kv_heads_config[layer_id]; + block_size_in_bytes += config.k_head_size * config.num_k_heads + config.v_head_size * config.num_v_heads; } - + block_size_in_bytes *= device_config.get_block_size() * device_config.get_cache_precision().size(); ov::InferRequest request = core.compile_model(get_dummy_model(core, num_decoder_layers)).create_infer_request(); auto cache_manager = std::make_shared(device_config, request, core); diff --git a/tests/cpp/device_config.cpp b/tests/cpp/device_config.cpp index 93e06f02e7..a97037b1e8 100644 --- a/tests/cpp/device_config.cpp +++ b/tests/cpp/device_config.cpp @@ -18,13 +18,15 @@ TEST(TestDeviceConfig, kv_cache_precision_u8) { const std::string device = "CPU"; size_t num_decoder_layers = 12; size_t head_size = 64, head_size_u8 = head_size + 8; - std::vector num_kv_heads(12, 12); - ov::genai::DeviceConfig device_config_default(core, 
scheduler_config, "CPU"); - device_config_default.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers); + ov::genai::KVHeadConfig kv_head_config { 12, 12, head_size_u8, head_size_u8 }; + ov::genai::KVHeadConfig kv_head_config_u8 { 12, 12, head_size, head_size }; + ov::genai::DeviceConfig device_config_default(core, scheduler_config, "CPU"); ov::genai::DeviceConfig device_config_u8(core, scheduler_config, "CPU", { ov::hint::kv_cache_precision(ov::element::u8) }); - device_config_u8.set_model_params(num_kv_heads, head_size, num_decoder_layers); + + device_config_default.set_kv_head_configs(std::vector(num_decoder_layers, kv_head_config)); + device_config_u8.set_kv_head_configs(std::vector(num_decoder_layers, kv_head_config_u8)); const auto ratio = ov::element::f16.size() / ov::element::u8.size(); ASSERT_EQ(device_config_default.get_num_kv_blocks() * ratio, device_config_u8.get_num_kv_blocks()); diff --git a/tests/cpp/scheduler.cpp b/tests/cpp/scheduler.cpp index ecd53fa665..201318347a 100644 --- a/tests/cpp/scheduler.cpp +++ b/tests/cpp/scheduler.cpp @@ -47,9 +47,9 @@ std::shared_ptr init_cache_manager(SchedulerConfig scheduler_confi size_t num_decoder_layers = 12; ov::InferRequest request = core.compile_model(get_model(core, num_decoder_layers)).create_infer_request(); size_t head_size = 64, head_size_u8 = head_size + 8; - std::vector num_kv_heads(12, 12); + std::vector kv_head_configs(num_decoder_layers, KVHeadConfig { 12, 12, head_size_u8, head_size_u8 }); ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU"); - device_config.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers); + device_config.set_kv_head_configs(kv_head_configs); return std::make_shared(device_config, request, core); } From d3bf47bca6805cc0b01c18bc9e4bf7b4d5f03ddc Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Wed, 22 Jan 2025 15:46:41 +0400 Subject: [PATCH 7/7] LLM: fixed Slice / Gather of last MatMul (#1616) CVS-160884 --- src/cpp/src/utils.cpp | 54 +++++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 5b17b2eacf..2d6dfd2ae5 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -233,50 +233,74 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token } namespace { -std::shared_ptr find_llm_matmul(const std::shared_ptr& model) { + +bool has_op_with_type(const std::shared_ptr& function, const std::string& type_name) { + for (const auto& op : function->get_ops()) { + if (op->get_type_name() == type_name) { + return true; + } + } + return false; +} + +std::tuple, int64_t> find_llm_matmul(const std::shared_ptr& model) { auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr(); - std::shared_ptr matmul = std::dynamic_pointer_cast(last_node); + std::shared_ptr matmul = ov::as_type_ptr(last_node); + + // in case of PA all tokens are moved to batch dimension and we have to slice / gather accordingly + const bool pa_based_model = has_op_with_type(model, "PagedAttentionExtension"); + int64_t slice_gather_dim = pa_based_model ? 
0 : 1; + // There are several patterns for matmul we are looking for: // Matmul -> Result // Matmul -> Add -> Result // Matmul -> Transpose -> Result // MatMul -> Divide -> Tanh -> Multiply -> Result if (!matmul) { - if(auto add = std::dynamic_pointer_cast(last_node)) { - matmul = std::dynamic_pointer_cast(add->input_value(0).get_node_shared_ptr()); - } else if (auto transpose = std::dynamic_pointer_cast(last_node)) { - matmul = std::dynamic_pointer_cast(transpose->input_value(0).get_node_shared_ptr()); - } else if (auto multiply = std::dynamic_pointer_cast(last_node)) { - if (auto tanh = std::dynamic_pointer_cast(multiply->input_value(0).get_node_shared_ptr())) { - if (auto divide = std::dynamic_pointer_cast(tanh->input_value(0).get_node_shared_ptr())) { - matmul = std::dynamic_pointer_cast(divide->input_value(0).get_node_shared_ptr()); + if (auto add = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(add->input_value(0).get_node_shared_ptr()); + } else if (auto transpose = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(transpose->input_value(0).get_node_shared_ptr()); + auto order = ov::as_type_ptr(transpose->input_value(1).get_node_shared_ptr())->get_axis_vector_val(); + slice_gather_dim = order[slice_gather_dim]; + } else if (auto multiply = ov::as_type_ptr(last_node)) { + if (auto tanh = ov::as_type_ptr(multiply->input_value(0).get_node_shared_ptr())) { + if (auto divide = ov::as_type_ptr(tanh->input_value(0).get_node_shared_ptr())) { + matmul = as_type_ptr(divide->input_value(0).get_node_shared_ptr()); } } } } - return matmul; + return std::make_tuple(matmul, slice_gather_dim); } + } // namespace void apply_slice_before_matmul_transformation(std::shared_ptr model) { - auto matmul = find_llm_matmul(model); + std::shared_ptr matmul = nullptr; + int64_t slice_gather_dim = -1; + std::tie(matmul, slice_gather_dim) = find_llm_matmul(model); + if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) { auto start = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); auto stop = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-2}); auto step = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); - auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{slice_gather_dim}); auto slice = std::make_shared(matmul->input_value(0), start, stop, step, axis); matmul->input(0).replace_source_output(slice); } } void apply_gather_before_matmul_transformation(std::shared_ptr model) { - auto matmul = ov::genai::utils::find_llm_matmul(model); + std::shared_ptr matmul = nullptr; + int64_t slice_gather_dim = -1; + std::tie(matmul, slice_gather_dim) = find_llm_matmul(model); + if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) { auto indices = std::make_shared(ov::element::i64, ov::PartialShape{-1}); indices->set_friendly_name("sampled_tokens_indices"); indices->output(0).get_tensor().set_names({"sampled_tokens_indices"}); - auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{0}); + auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{slice_gather_dim}); auto gather = std::make_shared(matmul->input_value(0), indices, axis); matmul->input(0).replace_source_output(gather); model->add_parameters({indices});
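Patch 2/7 above makes the same `PerfMetrics` meaningful when an assisting pipeline is used. A minimal sketch of exercising that path from user code, assuming the standard `draft_model` property and placeholder model paths:

```cpp
// Sketch: speculative decoding via a draft model, then reading the per-token metrics
// that the assisting pipelines now populate (simplified, no error handling).
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>

int main(int argc, char* argv[]) {
    std::string main_model_path = argv[1], draft_model_path = argv[2];

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;
    config.num_assistant_tokens = 5;  // candidate tokens proposed by the draft model per step

    ov::genai::LLMPipeline pipe(main_model_path, "CPU",
                                ov::genai::draft_model(draft_model_path, "CPU"));

    ov::genai::DecodedResults result = pipe.generate("The Sun is yellow because", config);
    std::cout << "TPOT: " << result.perf_metrics.get_tpot().mean << " ms/token" << std::endl;
}
```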