Skip to content

Commit

Permalink
scripts/vsmlrt.py: remove int8 and fp8 support
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Oct 21, 2023
1 parent 75c6d6c commit daf9620
Showing 1 changed file with 2 additions and 22 deletions.
24 changes: 2 additions & 22 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.18.0"
__version__ = "3.18.1"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -139,9 +139,7 @@ class TRT:
builder_optimization_level: int = 3
max_aux_streams: typing.Optional[int] = None
short_path: typing.Optional[bool] = None # True on Windows by default, False otherwise
int8: bool = False
bf16: bool = False
fp8: bool = False
custom_env: typing.Dict[str, str] = field(default_factory=lambda: {})

# internal backend attributes
Expand Down Expand Up @@ -1088,9 +1086,7 @@ def get_engine_path(
builder_optimization_level: int,
max_aux_streams: typing.Optional[int],
short_path: typing.Optional[bool],
int8: bool,
bf16: bool,
fp8: bool
bf16: bool
) -> str:

with open(network_path, "rb") as file:
Expand All @@ -1117,9 +1113,7 @@ def get_engine_path(
shape_str +
("_fp16" if fp16 else "") +
("_tf32" if tf32 else "") +
("_int8" if int8 else "") +
("_bf16" if bf16 else "") +
("_fp8" if fp8 else "") +
(f"_workspace{workspace}" if workspace is not None else "") +
f"_opt{builder_optimization_level}" +
(f"_max-aux-streams{max_aux_streams}" if max_aux_streams is not None else "") +
Expand Down Expand Up @@ -1166,9 +1160,7 @@ def trtexec(
builder_optimization_level: int = 3,
max_aux_streams: typing.Optional[int] = None,
short_path: typing.Optional[bool] = None,
int8: bool = False,
bf16: bool = False,
fp8: bool = False,
custom_env: typing.Dict[str, str] = {}
) -> str:

Expand All @@ -1184,9 +1176,7 @@ def trtexec(
if force_fp16:
fp16 = True
tf32 = False
int8 = False
bf16 = False
fp8 = False

engine_path = get_engine_path(
network_path=network_path,
Expand All @@ -1205,9 +1195,7 @@ def trtexec(
builder_optimization_level=builder_optimization_level,
max_aux_streams=max_aux_streams,
short_path=short_path,
int8=int8,
bf16=bf16,
fp8=fp8
)

if os.access(engine_path, mode=os.R_OK):
Expand Down Expand Up @@ -1340,15 +1328,9 @@ def trtexec(
if max_aux_streams is not None:
args.append(f"--maxAuxStreams={max_aux_streams}")

if int8:
args.append("--int8")

if trt_version >= 9000:
if bf16:
args.append("--bf16")

if fp8:
args.append("--fp8")

if log:
env_key = "TRTEXEC_LOG_FILE"
Expand Down Expand Up @@ -1584,9 +1566,7 @@ def _inference(
builder_optimization_level=backend.builder_optimization_level,
max_aux_streams=backend.max_aux_streams,
short_path=backend.short_path,
int8=backend.int8,
bf16=backend.bf16,
fp8=backend.fp8,
custom_env=backend.custom_env
)
clip = core.trt.Model(
Expand Down

0 comments on commit daf9620

Please sign in to comment.