Skip to content

Commit

Permalink
scripts/vsmlrt.py: add tf32 flag to the ort_cuda backend
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Apr 20, 2024
1 parent b83941d commit e1826de
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.20.8"
__version__ = "3.20.9"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -101,6 +101,7 @@ class ORT_CUDA:
fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None
prefer_nhwc: bool = False
output_format: int = 0 # 0: fp32, 1: fp16
tf32: bool = False

# internal backend attributes
supports_onnx_serialization: bool = True
Expand Down Expand Up @@ -2057,6 +2058,7 @@ def _inference(
if version >= (1, 18, 0):
kwargs["prefer_nhwc"] = backend.prefer_nhwc
kwargs["output_format"] = backend.output_format
kwargs["tf32"] = backend.tf32

clip = core.ort.Model(
clips, network_path,
Expand Down

0 comments on commit e1826de

Please sign in to comment.