Skip to content

Commit

Permalink
Type mismatch in onnx quantize_with_accuracy_control validation_fn (#…
Browse files Browse the repository at this point in the history
…2963)

### Changes

- Align type hints for the `quantize_with_accuracy_control()` method

### Reason for changes

Issue: #2957
  • Loading branch information
andrey-churkin authored Sep 12, 2024
1 parent 4401d44 commit 6a8c366
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 61 deletions.
61 changes: 1 addition & 60 deletions nncf/openvino/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def native_quantize_impl(
@tracked_function(
NNCF_OV_CATEGORY, [CompressionStartedWithQuantizeApi(), "target_device", "preset", "max_drop", "drop_type"]
)
def native_quantize_with_accuracy_control_impl(
def quantize_with_accuracy_control_impl(
model: ov.Model,
calibration_dataset: Dataset,
validation_dataset: Dataset,
Expand Down Expand Up @@ -365,65 +365,6 @@ def quantize_impl(
)


def wrap_validation_fn(validation_fn):
"""
Wraps validation function to support case when it only returns metric value.
:param validation_fn: Validation function to wrap.
:return: Wrapped validation function.
"""

def wrapper(*args, **kwargs):
retval = validation_fn(*args, **kwargs)
if isinstance(retval, tuple):
return retval
return retval, None

return wrapper


def quantize_with_accuracy_control_impl(
model: ov.Model,
calibration_dataset: Dataset,
validation_dataset: Dataset,
validation_fn: Callable[[Any, Iterable[Any]], float],
max_drop: float = 0.01,
drop_type: DropType = DropType.ABSOLUTE,
preset: Optional[QuantizationPreset] = None,
target_device: TargetDevice = TargetDevice.ANY,
subset_size: int = 300,
fast_bias_correction: bool = True,
model_type: Optional[ModelType] = None,
ignored_scope: Optional[IgnoredScope] = None,
advanced_quantization_parameters: Optional[AdvancedQuantizationParameters] = None,
advanced_accuracy_restorer_parameters: Optional[AdvancedAccuracyRestorerParameters] = None,
) -> ov.Model:
"""
Implementation of the `quantize_with_accuracy_control()` method for the OpenVINO backend.
"""

quantize_with_accuracy_control_fn = native_quantize_with_accuracy_control_impl

val_func = wrap_validation_fn(validation_fn)

return quantize_with_accuracy_control_fn(
model,
calibration_dataset,
validation_dataset,
val_func,
max_drop,
drop_type,
preset,
target_device,
subset_size,
fast_bias_correction,
model_type,
ignored_scope,
advanced_quantization_parameters,
advanced_accuracy_restorer_parameters,
)


def compress_weights_impl(
model: ov.Model,
dataset: Dataset,
Expand Down
22 changes: 21 additions & 1 deletion nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,29 @@ def quantize(
raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}")


def wrap_validation_fn(validation_fn):
"""
Wraps validation function to support case when it only returns metric value.
:param validation_fn: Validation function to wrap.
:return: Wrapped validation function.
"""

def wrapper(*args, **kwargs):
retval = validation_fn(*args, **kwargs)
if isinstance(retval, tuple):
return retval
return retval, None

return wrapper


@api(canonical_alias="nncf.quantize_with_accuracy_control")
def quantize_with_accuracy_control(
model: TModel,
calibration_dataset: Dataset,
validation_dataset: Dataset,
validation_fn: Callable[[Any, Iterable[Any]], float],
validation_fn: Callable[[Any, Iterable[Any]], Tuple[float, Union[None, List[float], List[List[TTensor]]]]],
max_drop: float = 0.01,
drop_type: DropType = DropType.ABSOLUTE,
preset: Optional[QuantizationPreset] = None,
Expand Down Expand Up @@ -316,6 +333,9 @@ def quantize_with_accuracy_control(
)

backend = get_backend(model)

validation_fn = wrap_validation_fn(validation_fn)

if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.quantize_model import quantize_with_accuracy_control_impl

Expand Down

0 comments on commit 6a8c366

Please sign in to comment.