scripts/vsmlrt.py: add fp16_blacklist_ops in ov and ort backends
WolframRhodium committed Jul 24, 2022
1 parent 7f41f48 · commit 99ebde4
Showing 1 changed file with 15 additions and 5 deletions.
scripts/vsmlrt.py (20 changes: 15 additions & 5 deletions)
@@ -51,7 +51,9 @@ class ORT_CPU:
     num_streams: int = 1
     verbosity: int = 2
     fp16: bool = False
+    fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None
 
+    # internal backend attributes
     supports_onnx_serialization: bool = True
 
 @dataclass(frozen=False)
@@ -63,15 +65,19 @@ class ORT_CUDA:
     verbosity: int = 2
     fp16: bool = False
     use_cuda_graph: bool = False # preview, not supported by all models
+    fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None
 
+    # internal backend attributes
     supports_onnx_serialization: bool = True
 
 @dataclass(frozen=False)
 class OV_CPU:
     fp16: bool = False
     num_streams: typing.Union[int, str] = 1
     bind_thread: bool = True
+    fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None
 
+    # internal backend attributes
     supports_onnx_serialization: bool = True
 
 @dataclass(frozen=False)
@@ -92,9 +98,9 @@ class TRT:
 
     # as of TensorRT 8.4, it can be turned off without performance penalty in most cases
     use_cudnn: bool = True
-
     use_edge_mask_convolutions: bool = True
 
+    # internal backend attributes
     _channels: int = field(init=False, repr=False, compare=False)
     supports_onnx_serialization: bool = False
 
@@ -895,7 +901,8 @@ def inference(
             num_streams=backend.num_streams,
             verbosity=backend.verbosity,
             fp16=backend.fp16,
-            path_is_serialization=path_is_serialization
+            path_is_serialization=path_is_serialization,
+            fp16_blacklist_ops=backend.fp16_blacklist_ops
         )
     elif isinstance(backend, Backend.ORT_CUDA):
         clip = core.ort.Model(
@@ -908,7 +915,8 @@
             cudnn_benchmark=backend.cudnn_benchmark,
             fp16=backend.fp16,
             path_is_serialization=path_is_serialization,
-            use_cuda_graph=backend.use_cuda_graph
+            use_cuda_graph=backend.use_cuda_graph,
+            fp16_blacklist_ops=backend.fp16_blacklist_ops
         )
     elif isinstance(backend, Backend.OV_CPU):
         config = lambda: dict(
@@ -921,7 +929,8 @@
             device="CPU", builtin=False,
             fp16=backend.fp16,
             config=config,
-            path_is_serialization=path_is_serialization
+            path_is_serialization=path_is_serialization,
+            fp16_blacklist_ops=backend.fp16_blacklist_ops
         )
     elif isinstance(backend, Backend.OV_GPU):
         config = lambda: dict(
@@ -933,7 +942,8 @@
             device=f"GPU.{backend.device_id}", builtin=False,
             fp16=backend.fp16,
             config=config,
-            path_is_serialization=path_is_serialization
+            path_is_serialization=path_is_serialization,
+            fp16_blacklist_ops=backend.fp16_blacklist_ops
         )
     elif isinstance(backend, Backend.TRT):
         if path_is_serialization:
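
For context, a minimal sketch of how a caller might use the new parameter from a VapourSynth script. The clip, the "model.onnx" path, and the "Softmax" op name are hypothetical placeholders, and it is an assumption (not stated in this commit) that fp16_blacklist_ops names ONNX operator types to keep out of the fp16 conversion when fp16=True; the inference() wrapper is taken to accept a clip, a network path, and a backend instance, as the hunk above suggests.

import vapoursynth as vs
from vsmlrt import Backend, inference

core = vs.core

# hypothetical input; real use would load a video and convert it to the
# format the model expects (e.g. RGBS)
clip = core.std.BlankClip(format=vs.RGBS, width=640, height=360)

# enable fp16, but keep "Softmax" ops in fp32
# (assumed semantics of fp16_blacklist_ops)
backend = Backend.OV_CPU(fp16=True, fp16_blacklist_ops=["Softmax"])

# "model.onnx" is a placeholder path to an ONNX model
clip = inference(clip, network_path="model.onnx", backend=backend)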