Skip to content

Commit

Permalink
scripts/vsmlrt.py: add max_tactics option to the TRT backend
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Sep 7, 2024
1 parent 5ae95bc commit 28325f3
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.21.19"
__version__ = "3.21.20"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -185,6 +185,7 @@ class TRT:
custom_env: typing.Dict[str, str] = field(default_factory=lambda: {})
custom_args: typing.List[str] = field(default_factory=lambda: [])
engine_folder: typing.Optional[str] = None
max_tactics: typing.Optional[int] = None

# internal backend attributes
supports_onnx_serialization: bool = False
Expand Down Expand Up @@ -1884,7 +1885,8 @@ def trtexec(
bf16: bool = False,
custom_env: typing.Dict[str, str] = {},
custom_args: typing.List[str] = [],
engine_folder: typing.Optional[str] = None
engine_folder: typing.Optional[str] = None,
max_tactics: typing.Optional[int] = None
) -> str:

# TensorRT runtime version
Expand Down Expand Up @@ -2062,6 +2064,10 @@ def trtexec(
if bf16:
args.append("--bf16")

if trt_version >= (10, 4, 0):
if max_tactics is not None:
args.append(f"--maxTactics={max_tactics}")

args.extend(custom_args)

if log:
Expand Down Expand Up @@ -2479,6 +2485,7 @@ def _inference(
custom_env=backend.custom_env,
custom_args=backend.custom_args,
engine_folder=backend.engine_folder,
max_tactics=backend.max_tactics,
)
ret = core.trt.Model(
clips, engine_path,
Expand Down

0 comments on commit 28325f3

Please sign in to comment.