Skip to content

Commit

Permalink
Symbol use in optimum: fix misprint (#948)
Browse files Browse the repository at this point in the history
* Symbol use in optimum: fix misprint

* fix incorrect filling of chatglm position_ids input

* fix cutting position ids

---------

Co-authored-by: eaidova <[email protected]>
  • Loading branch information
jane-intel and eaidova authored Oct 22, 2024
1 parent 03a59aa commit dcb49ea
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _get_input_info(
symbol = name_to_symbol[dim_name]
else:
symbol = Symbol()
name_to_symbol[name] = symbol
name_to_symbol[dim_name] = symbol
dim = Dimension(-1)
dim.set_symbol(symbol)
shape[idx] = dim
Expand Down
26 changes: 22 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import openvino
Expand All @@ -31,7 +31,7 @@
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.utils import GenerateOutput, GenerationMode
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput

from optimum.utils.normalized_config import NormalizedConfigManager

Expand Down Expand Up @@ -504,8 +504,8 @@ def prepare_inputs(
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

inputs["position_ids"] = position_ids

Expand Down Expand Up @@ -604,6 +604,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg

return model_inputs

def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """Refresh generation kwargs after a decoding step, extending ``position_ids``.

    Delegates the standard bookkeeping to the base class, then — when the
    kwargs carry a ``position_ids`` tensor — appends one more position
    (last position + 1) along the final axis so the next forward pass
    receives position ids aligned with the newly generated token.
    """
    model_kwargs = super()._update_model_kwargs_for_generation(
        outputs=outputs,
        model_kwargs=model_kwargs,
        is_encoder_decoder=is_encoder_decoder,
        **kwargs,
    )

    if "position_ids" in model_kwargs:
        position_ids = model_kwargs["position_ids"]
        # The next token sits one position past the current last one;
        # clone() avoids mutating the existing tensor in place.
        next_position = position_ids[..., -1:].clone() + 1
        model_kwargs["position_ids"] = torch.cat((position_ids, next_position), dim=-1)
    return model_kwargs

def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
batch_size = logits.shape[0]
if indicies.shape[0] != 1:
Expand Down

0 comments on commit dcb49ea

Please sign in to comment.