Skip to content

Commit

Permalink
scripts/vsmlrt.py: add custom_env for user-specified environment in trtexec execution
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Oct 21, 2023
1 parent 1549bdf commit 75c6d6c
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
]

import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
import enum
import math
import os
Expand Down Expand Up @@ -142,6 +142,7 @@ class TRT:
int8: bool = False
bf16: bool = False
fp8: bool = False
custom_env: typing.Dict[str, str] = field(default_factory=lambda: {})

# internal backend attributes
supports_onnx_serialization: bool = False
Expand Down Expand Up @@ -1167,7 +1168,8 @@ def trtexec(
short_path: typing.Optional[bool] = None,
int8: bool = False,
bf16: bool = False,
fp8: bool = False
fp8: bool = False,
custom_env: typing.Dict[str, str] = {}
) -> str:

# TensorRT runtime version, e.g. 8401 => 8.4.1
Expand Down Expand Up @@ -1355,6 +1357,7 @@ def trtexec(
if prev_env_value is not None and len(prev_env_value) > 0:
# env_key has been set, no extra action
env = {env_key: prev_env_value, "CUDA_MODULE_LOADING": "LAZY"}
env.update(**custom_env)
subprocess.run(args, env=env, check=True, stdout=sys.stderr)
else:
time_str = time.strftime('%y%m%d_%H%M%S', time.localtime())
Expand All @@ -1365,6 +1368,7 @@ def trtexec(
)

env = {env_key: log_filename, "CUDA_MODULE_LOADING": "LAZY"}
env.update(**custom_env)

completed_process = subprocess.run(args, env=env, check=False, stdout=sys.stderr)

Expand All @@ -1380,7 +1384,9 @@ def trtexec(
else:
raise RuntimeError(f"trtexec execution fails but no log is found")
else:
subprocess.run(args, env={"CUDA_MODULE_LOADING": "LAZY"}, check=True, stdout=sys.stderr)
env = {"CUDA_MODULE_LOADING": "LAZY"}
env.update(**custom_env)
subprocess.run(args, env=custom_env, check=True, stdout=sys.stderr)

return engine_path

Expand Down Expand Up @@ -1580,7 +1586,8 @@ def _inference(
short_path=backend.short_path,
int8=backend.int8,
bf16=backend.bf16,
fp8=backend.fp8
fp8=backend.fp8,
custom_env=backend.custom_env
)
clip = core.trt.Model(
clips, engine_path,
Expand Down

0 comments on commit 75c6d6c

Please sign in to comment.