From 75c6d6cf43ba7ae6891d746acc0251ea51bd196a Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Sat, 21 Oct 2023 08:54:52 +0800 Subject: [PATCH] scripts/vsmlrt.py: add `custom_env` for user-specified environment in trtexec execution --- scripts/vsmlrt.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index 822e44b..d9ecfd6 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -12,7 +12,7 @@ ] import copy -from dataclasses import dataclass +from dataclasses import dataclass, field import enum import math import os @@ -142,6 +142,7 @@ class TRT: int8: bool = False bf16: bool = False fp8: bool = False + custom_env: typing.Dict[str, str] = field(default_factory=lambda: {}) # internal backend attributes supports_onnx_serialization: bool = False @@ -1167,7 +1168,8 @@ def trtexec( short_path: typing.Optional[bool] = None, int8: bool = False, bf16: bool = False, - fp8: bool = False + fp8: bool = False, + custom_env: typing.Dict[str, str] = {} ) -> str: # tensort runtime version, e.g. 8401 => 8.4.1 @@ -1355,6 +1357,7 @@ def trtexec( if prev_env_value is not None and len(prev_env_value) > 0: # env_key has been set, no extra action env = {env_key: prev_env_value, "CUDA_MODULE_LOADING": "LAZY"} + env.update(**custom_env) subprocess.run(args, env=env, check=True, stdout=sys.stderr) else: time_str = time.strftime('%y%m%d_%H%M%S', time.localtime()) @@ -1365,6 +1368,7 @@ def trtexec( ) env = {env_key: log_filename, "CUDA_MODULE_LOADING": "LAZY"} + env.update(**custom_env) completed_process = subprocess.run(args, env=env, check=False, stdout=sys.stderr) @@ -1380,7 +1384,9 @@ def trtexec( else: raise RuntimeError(f"trtexec execution fails but no log is found") else: - subprocess.run(args, env={"CUDA_MODULE_LOADING": "LAZY"}, check=True, stdout=sys.stderr) + env = {"CUDA_MODULE_LOADING": "LAZY"} + env.update(**custom_env) + subprocess.run(args, env=custom_env, check=True, stdout=sys.stderr) return engine_path @@ -1580,7 +1586,8 @@ def _inference( short_path=backend.short_path, int8=backend.int8, bf16=backend.bf16, - fp8=backend.fp8 + fp8=backend.fp8, + custom_env=backend.custom_env ) clip = core.trt.Model( clips, engine_path,