diff --git a/Dockerfile b/Dockerfile
index e97836e3ce..e45932c6bb 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -56,4 +56,5 @@ RUN apt-get update \
     && rm -rf /var/lib/apt/lists/*
 # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
 ENV PATH=${PATH}:/opt/tools
+ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
 WORKDIR /opt/monai
diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md
index c932879b5a..742841acca 100644
--- a/docs/source/config_syntax.md
+++ b/docs/source/config_syntax.md
@@ -16,6 +16,7 @@ Content:
   - [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)
   - [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)
   - [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object)
+  - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files)
   - [The command line interface](#the-command-line-interface)
   - [Recommendations](#recommendations)
@@ -175,6 +176,47 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
 - `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``, see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).
 
+## Multiple config files
+
+_Description:_ Multiple config files may be specified on the command line.
+The contents of those config files are merged: when the same key is specified in more than one config file,
+the value associated with the key is overridden, in the order the config files are specified.
+If the desired behaviour is to merge values from both files instead, the key in the second config file should be prefixed with `+`.
+The value types of the two entries being merged must match and must both be of `dict` or both of `list` type:
+`dict` values are merged via `update()`, `list` values are concatenated via `extend()`.
+Here is an example. In this case, the `"amp"` value is overridden by extra_config.json,
+and the `imports` and `preprocessing#transforms` lists are merged. An error is thrown if the value type of `"+imports"` is not `list`:
+
+config.json:
+```json
+{
+    "amp": "$True",
+    "imports": [
+        "$import torch"
+    ],
+    "preprocessing": {
+        "_target_": "Compose",
+        "transforms": [
+            "$@t1",
+            "$@t2"
+        ]
+    }
+}
+```
+
+extra_config.json:
+```json
+{
+    "amp": "$False",
+    "+imports": [
+        "$from monai.networks import trt_compile"
+    ],
+    "+preprocessing#transforms": [
+        "$@t3"
+    ]
+}
+```
+
 ## The command line interface
 
 In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.
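The merge behaviour documented above can also be exercised programmatically. A minimal sketch using the `ConfigParser.load_config_files()` API touched by this diff, with the two hypothetical file names from the example:

```python
from monai.bundle import ConfigParser

# Files are parsed in order: plain keys override earlier values, while
# "+"-prefixed keys merge into the earlier value (dict.update() for dicts,
# list.extend() for lists).
merged = ConfigParser.load_config_files(["config.json", "extra_config.json"])

print(merged["amp"])      # "$False" -- overridden by extra_config.json
print(merged["imports"])  # two entries -- the lists were concatenated
```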
diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index a2ffeedc92..1d9920a230 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -20,7 +20,7 @@ from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver -from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY +from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv from monai.config import PathLike from monai.utils import ensure_tuple, look_up_option, optional_import from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates @@ -423,8 +423,10 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs if isinstance(files, str) and not Path(files).is_file() and "," in files: files = files.split(",") for i in ensure_tuple(files): - for k, v in (cls.load_config_file(i, **kwargs)).items(): - parser[k] = v + config_dict = cls.load_config_file(i, **kwargs) + for k, v in config_dict.items(): + merge_kv(parser, k, v) + return parser.get() # type: ignore @classmethod diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 142a366669..f1d1286e4b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -32,7 +32,7 @@ from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser -from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA +from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata @@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw if isinstance(v, dict) and isinstance(args_.get(k), dict): args_[k] = update_kwargs(args_[k], ignore_none, **v) else: - args_[k] = v + merge_kv(args_, k, v) return args_ diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index 50d2608f4c..53d619f234 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -13,6 +13,7 @@ import json import os +import warnings import zipfile from typing import Any @@ -21,12 +22,21 @@ yaml, _ = optional_import("yaml") -__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"] +__all__ = [ + "ID_REF_KEY", + "ID_SEP_KEY", + "EXPR_KEY", + "MACRO_KEY", + "MERGE_KEY", + "DEFAULT_MLFLOW_SETTINGS", + "DEFAULT_EXP_MGMT_SETTINGS", +] ID_REF_KEY = "@" # start of a reference to a ConfigItem ID_SEP_KEY = "::" # separator for the ID of a ConfigItem EXPR_KEY = "$" # start of a ConfigExpression MACRO_KEY = "%" # start of a macro of a config +MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs. _conf_values = get_config_values() @@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any parser.read_config(f=cdata) return parser + + +def merge_kv(args: dict | Any, k: str, v: Any) -> None: + """ + Update the `args` dict-like object with the key/value pair `k` and `v`. + """ + if k.startswith(MERGE_KEY): + """ + Both values associated with `+`-prefixed key pair must be of `dict` or `list` type. + `dict` values will be merged, `list` values - concatenated. 
+ """ + id = k[1:] + if id in args: + if isinstance(v, dict) and isinstance(args[id], dict): + args[id].update(v) + elif isinstance(v, list) and isinstance(args[id], list): + args[id].extend(v) + else: + raise ValueError(ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")) + else: + warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.") + args[id] = v + else: + args[k] = v diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 641f9aae7d..fa6e158be8 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -40,5 +40,6 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler +from .trt_handler import TrtHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py new file mode 100644 index 0000000000..0e36b59d8c --- /dev/null +++ b/monai/handlers/trt_handler.py @@ -0,0 +1,61 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from monai.config import IgniteInfo +from monai.networks import trt_compile +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + + +class TrtHandler: + """ + TrtHandler acts as an Ignite handler to apply TRT acceleration to the model. + Usage example:: + handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"}) + handler.attach(engine) + engine.run() + """ + + def __init__(self, model, base_path, args=None, submodule=None): + """ + Args: + base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan" + args: passed to trt_compile(). See trt_compile() for details. + submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' + """ + self.model = model + self.base_path = base_path + self.args = args + self.submodule = submodule + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + self.logger = engine.logger + engine.add_event_handler(Events.STARTED, self) + + def __call__(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. 
+ """ + trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 4c429ae813..5a240021d6 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -11,7 +11,9 @@ from __future__ import annotations +from .trt_compiler import trt_compile from .utils import ( + add_casts_around_norms, convert_to_onnx, convert_to_torchscript, convert_to_trt, diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 3900c866b3..714d986f4b 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -320,7 +320,7 @@ def _check_input_size(self, spatial_shape): ) def forward(self, x_in): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): self._check_input_size(x_in.shape[2:]) hidden_states_out = self.swinViT(x_in, self.normalize) enc0 = self.encoder1(x_in) @@ -1046,14 +1046,14 @@ def __init__( def proj_out(self, x, normalize=False): if normalize: - x_shape = x.size() + x_shape = x.shape + # Force trace() to generate a constant by casting to int + ch = int(x_shape[1]) if len(x_shape) == 5: - n, ch, d, h, w = x_shape x = rearrange(x, "n c d h w -> n d h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n d h w c -> n c d h w") elif len(x_shape) == 4: - n, ch, h, w = x_shape x = rearrange(x, "n c h w -> n h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n h w c -> n c h w") diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py new file mode 100644 index 0000000000..a9dd0d9e9b --- /dev/null +++ b/monai/networks/trt_compiler.py @@ -0,0 +1,565 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import inspect +import os +import tempfile +import threading +from collections import OrderedDict +from pathlib import Path +from types import MethodType +from typing import Any, Dict, List, Union + +import torch + +from monai.apps.utils import get_logger +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes +from monai.utils.module import optional_import + +polygraphy, polygraphy_imported = optional_import("polygraphy") +if polygraphy_imported: + from polygraphy.backend.common import bytes_from_path + from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_bytes_from_network, + engine_from_bytes, + network_from_onnx_path, + ) + +trt, trt_imported = optional_import("tensorrt") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") +cudart, _ = optional_import("cuda.cudart") + + +lock_sm = threading.Lock() + + +# Map of TRT dtype -> Torch dtype +def trt_to_torch_dtype_dict(): + return { + trt.int32: torch.int32, + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, + } + + +def get_dynamic_axes(profiles): + """ + This method calculates dynamic_axes to use in onnx.export(). + Args: + profiles: [[min,opt,max],...] list of profile dimensions + """ + dynamic_axes: dict[str, list[int]] = {} + if not profiles: + return dynamic_axes + for profile in profiles: + for key in profile: + axes = [] + vals = profile[key] + for i in range(len(vals[0])): + if vals[0][i] != vals[2][i]: + axes.append(i) + if len(axes) > 0: + dynamic_axes[key] = axes + return dynamic_axes + + +def cuassert(cuda_ret): + """ + Error reporting method for CUDA calls. + Args: + cuda_ret: CUDA return code. + """ + err = cuda_ret[0] + if err != 0: + raise RuntimeError(f"CUDA ERROR: {err}") + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class ShapeError(Exception): + """ + Exception class to report errors from setting TRT plan input shapes + """ + + pass + + +class TRTEngine: + """ + An auxiliary class to implement running of TRT optimized engines + + """ + + def __init__(self, plan_path, logger=None): + """ + Loads serialized engine, creates execution context and activates it + Args: + plan_path: path to serialized TRT engine. 
+ logger: optional logger object + """ + self.plan_path = plan_path + self.logger = logger or get_logger("trt_compile") + self.logger.info(f"Loading TensorRT engine: {self.plan_path}") + self.engine = engine_from_bytes(bytes_from_path(self.plan_path)) + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + self.context = self.engine.create_execution_context() + self.input_names = [] + self.output_names = [] + self.dtypes = [] + self.cur_profile = 0 + dtype_dict = trt_to_torch_dtype_dict() + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: + self.input_names.append(binding) + elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: + self.output_names.append(binding) + dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] + self.dtypes.append(dtype) + + def allocate_buffers(self, device): + """ + Allocates outputs to run TRT engine + Args: + device: GPU device to allocate memory on + """ + ctx = self.context + + for i, binding in enumerate(self.output_names): + shape = list(ctx.get_tensor_shape(binding)) + if binding not in self.tensors or list(self.tensors[binding].shape) != shape: + t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) + + def set_inputs(self, feed_dict, stream): + """ + Sets input bindings for TRT engine according to feed_dict + Args: + feed_dict: a dictionary [str->Tensor] + stream: CUDA stream to use + """ + e = self.engine + ctx = self.context + + last_profile = self.cur_profile + + def try_set_inputs(): + for binding, t in feed_dict.items(): + if t is not None: + t = t.contiguous() + shape = t.shape + ctx.set_input_shape(binding, shape) + ctx.set_tensor_address(binding, t.data_ptr()) + + while True: + try: + try_set_inputs() + break + except ShapeError: + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles + if next_profile == last_profile: + raise + self.cur_profile = next_profile + ctx.set_optimization_profile_async(self.cur_profile, stream) + + left = ctx.infer_shapes() + assert len(left) == 0 + + def infer(self, stream, use_cuda_graph=False): + """ + Runs TRT engine. + Args: + stream: CUDA stream to run on + use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls. 
+ """ + if use_cuda_graph: + if self.cuda_graph_instance is not None: + cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + cuassert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + cuassert( + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal) + ) + self.context.execute_async_v3(stream) + graph = cuassert(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0)) + self.logger.info("CUDA Graph captured!") + else: + noerror = self.context.execute_async_v3(stream) + cuassert(cudart.cudaStreamSynchronize(stream)) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TrtCompiler: + """ + This class implements: + - TRT lazy persistent export + - Running TRT with optional fallback to Torch + (for TRT engines with limited profiles) + """ + + def __init__( + self, + model, + plan_path, + precision="fp16", + method="onnx", + input_names=None, + output_names=None, + export_args=None, + build_args=None, + input_profiles=None, + dynamic_batchsize=None, + use_cuda_graph=False, + timestamp=None, + fallback=False, + logger=None, + ): + """ + Initialization method: + Tries to load persistent serialized TRT engine + Saves its arguments for lazy TRT build on first forward() call + Args: + model: Model to "wrap". + plan_path : Path where to save persistent serialized TRT engine. + precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. + method: One of 'onnx'|'torch_trt'. + Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. + 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. + input_names: Optional list of input names. If None, will be read from the function signature. + output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. + build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. + input_profiles: Optional list of profiles for TRT builder and ONNX export. + Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}. + dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be + converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. + [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine. + use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls! + timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). + fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). 
+ """ + + method_vals = ["onnx", "torch_trt"] + if method not in method_vals: + raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.") + precision_vals = ["fp32", "tf32", "fp16", "bf16"] + if precision not in precision_vals: + raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.") + + self.plan_path = plan_path + self.precision = precision + self.method = method + self.return_dict = output_names is not None + self.output_names = output_names or [] + self.profiles = input_profiles or [] + self.dynamic_batchsize = dynamic_batchsize + self.export_args = export_args or {} + self.build_args = build_args or {} + self.engine: TRTEngine | None = None + self.use_cuda_graph = use_cuda_graph + self.fallback = fallback + self.disabled = False + + self.logger = logger or get_logger("trt_compile") + + # Normally we read input_names from forward() but can be overridden + if input_names is None: + argspec = inspect.getfullargspec(model.forward) + input_names = argspec.args[1:] + self.input_names = input_names + self.old_forward = model.forward + + # Force engine rebuild if older than the timestamp + if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp: + os.remove(self.plan_path) + + def _inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name = self.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs + + def _load_engine(self): + """ + Loads TRT plan from disk and activates its execution context. + """ + try: + self.engine = TRTEngine(self.plan_path, self.logger) + self.input_names = self.engine.input_names + except Exception as e: + self.logger.debug(f"Exception while loading the engine:\n{e}") + + def forward(self, model, argv, kwargs): + """ + Main forward method: + Builds TRT engine if not available yet. 
+        Tries to run the TRT engine.
+        If an exception is thrown and self.fallback==True, falls back to the original Pytorch forward().
+
+        Args: passes through whatever args the wrapped module's forward() has.
+        Returns: passes through the wrapped module's forward() return value(s).
+
+        """
+        if self.engine is None and not self.disabled:
+            # Restore original forward for export
+            new_forward = model.forward
+            model.forward = self.old_forward
+            try:
+                self._load_engine()
+                if self.engine is None:
+                    build_args = kwargs.copy()
+                    if len(argv) > 0:
+                        build_args.update(self._inputs_to_dict(argv))
+                    self._build_and_save(model, build_args)
+                    # This will reassign input_names from the engine
+                    self._load_engine()
+            except Exception as e:
+                if self.fallback:
+                    self.logger.info(f"Failed to build engine: {e}")
+                    self.disabled = True
+                else:
+                    raise e
+            if not self.disabled and not self.fallback:
+                # Delete all parameters
+                for param in model.parameters():
+                    del param
+                # Call empty_cache to release GPU memory
+                torch.cuda.empty_cache()
+            model.forward = new_forward
+        # Run the engine
+        try:
+            if len(argv) > 0:
+                kwargs.update(self._inputs_to_dict(argv))
+                argv = ()
+
+            if self.engine is not None:
+                # forward_trt is not thread safe as we do not use per-thread execution contexts
+                with lock_sm:
+                    device = torch.cuda.current_device()
+                    stream = torch.cuda.Stream(device=device)
+                    self.engine.set_inputs(kwargs, stream.cuda_stream)
+                    self.engine.allocate_buffers(device=device)
+                    # Need this to synchronize with Torch stream
+                    stream.wait_stream(torch.cuda.current_stream())
+                    ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph)
+                    # if output_names is not None, return dictionary
+                    if not self.return_dict:
+                        ret = list(ret.values())
+                        if len(ret) == 1:
+                            ret = ret[0]
+                    return ret
+        except Exception as e:
+            if model is not None:
+                self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...")
+            else:
+                raise e
+        return self.old_forward(*argv, **kwargs)
+
+    def _onnx_to_trt(self, onnx_path):
+        """
+        Builds the TRT engine from the ONNX file at onnx_path and saves it to self.plan_path.
+        """
+
+        profiles = []
+        if self.profiles:
+            for input_profile in self.profiles:
+                if isinstance(input_profile, Profile):
+                    profiles.append(input_profile)
+                else:
+                    p = Profile()
+                    for name, dims in input_profile.items():
+                        assert len(dims) == 3
+                        p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+                    profiles.append(p)
+
+        build_args = self.build_args.copy()
+        build_args["tf32"] = self.precision != "fp32"
+        build_args["fp16"] = self.precision == "fp16"
+        build_args["bf16"] = self.precision == "bf16"
+
+        self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}")
+        network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
+        return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args))
+
+    def _build_and_save(self, model, input_example):
+        """
+        If the TRT engine is not ready, exports the model to ONNX,
+        builds the TRT engine, and saves the serialized TRT engine to disk.
+ Args: + input_example: passed to onnx.export() + """ + + if self.engine is not None: + return + + export_args = self.export_args + + add_casts_around_norms(model) + + if self.method == "torch_trt": + enabled_precisions = [torch.float32] + if self.precision == "fp16": + enabled_precisions.append(torch.float16) + elif self.precision == "bf16": + enabled_precisions.append(torch.bfloat16) + inputs = list(input_example.values()) + ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) + + def get_torch_trt_input(input_shape, dynamic_batchsize): + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) + return torch_tensorrt.Input( + min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape + ) + + tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] + engine_bytes = torch_tensorrt.convert_method_to_trt_engine( + ir_model, + "forward", + inputs=tt_inputs, + ir="torchscript", + enabled_precisions=enabled_precisions, + **export_args, + ) + else: + dbs = self.dynamic_batchsize + if dbs: + if len(self.profiles) > 0: + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") + if len(dbs) != 3: + raise ValueError("dynamic_batchsize has to have len ==3 ") + profiles = {} + for id, val in input_example.items(): + sh = val.shape[1:] + profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + self.profiles = [profiles] + + if len(self.profiles) > 0: + export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + + # Use temporary directory for easy cleanup in case of external weights + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = Path(tmpdir) / "model.onnx" + self.logger.info( + f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}" + ) + convert_to_onnx( + model, + input_example, + filename=str(onnx_path), + input_names=self.input_names, + output_names=self.output_names, + **export_args, + ) + self.logger.info("Export to ONNX successful.") + engine_bytes = self._onnx_to_trt(str(onnx_path)) + + open(self.plan_path, "wb").write(engine_bytes) + + +def trt_forward(self, *argv, **kwargs): + """ + Patch function to replace original model's forward() with. + Redirects to TrtCompiler.forward() + """ + return self._trt_compiler.forward(self, argv, kwargs) + + +def trt_compile( + model: torch.nn.Module, + base_path: str, + args: Dict[str, Any] | None = None, + submodule: Union[str, List[str]] | None = None, + logger: Any | None = None, +) -> torch.nn.Module: + """ + Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Args: + model: module to patch with TrtCompiler object. + base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. + dirname(base_path) must exist, base_path does not have to. + If base_path does point to existing file (e.g. associated checkpoint), + that file becomes a dependency - its mtime is added to args["timestamp"]. + args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details. + submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] + If None, TrtCompiler patch is applied to the whole model. + Otherwise, submodule (or list of) is being patched. + logger: Optional logger for diagnostics. + Returns: + Always returns same model passed in as argument. This is for ease of use in configs. 
+ """ + + default_args: Dict[str, Any] = { + "method": "onnx", + "precision": "fp16", + "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"}, + } + + default_args.update(args or {}) + args = default_args + + if trt_imported and polygraphy_imported and torch.cuda.is_available(): + # if "path" filename point to existing file (e.g. checkpoint) + # it's also treated as dependency + if os.path.exists(base_path): + timestamp = int(os.path.getmtime(base_path)) + if "timestamp" in args: + timestamp = max(int(args["timestamp"]), timestamp) + args["timestamp"] = timestamp + + def wrap(model, path): + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) + + def find_sub(parent, submodule): + idx = submodule.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = submodule[:idx] + parent = getattr(parent, parent_name) + submodule = submodule[idx + 1 :] + return find_sub(parent, submodule) + return parent, submodule + + if submodule is not None: + if isinstance(submodule, str): + submodule = [submodule] + for s in submodule: + parent, sub = find_sub(model, s) + wrap(getattr(parent, sub), base_path + "." + s) + else: + wrap(model, base_path) + else: + logger = logger or get_logger("trt_compile") + logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.") + + return model diff --git a/monai/networks/utils.py b/monai/networks/utils.py index bd65ffa33e..d0150b4e5b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -36,6 +36,8 @@ onnx, _ = optional_import("onnx") onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") +polygraphy, polygraphy_imported = optional_import("polygraphy") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") __all__ = [ "one_hot", @@ -61,6 +63,7 @@ "look_up_named_module", "set_named_module", "has_nvfuser_instance_norm", + "get_profile_shapes", ] logger = get_logger(module_name=__name__) @@ -68,6 +71,26 @@ _has_nvfuser = None +def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None): + """ + Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize. + """ + + def scale_batch_size(input_shape: Sequence[int], scale_num: int): + scale_shape = [*input_shape] + scale_shape[0] = scale_num + return scale_shape + + # Use the dynamic batchsize range to generate the min, opt and max model input shape + if dynamic_batchsize: + min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) + opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) + max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) + else: + min_input_shape = opt_input_shape = max_input_shape = input_shape + return min_input_shape, opt_input_shape, max_input_shape + + def has_nvfuser_instance_norm(): """whether the current environment has InstanceNorm3dNVFuser https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16 @@ -606,6 +629,9 @@ def convert_to_onnx( rtol: float = 1e-4, atol: float = 0.0, use_trace: bool = True, + do_constant_folding: bool = True, + constant_size_threshold: int = 16 * 1024 * 1024 * 1024, + dynamo=False, **kwargs, ): """ @@ -632,7 +658,10 @@ def convert_to_onnx( rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. 
         atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
         use_trace: whether to use `torch.jit.trace` to export the torchscript model.
-        kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
+        do_constant_folding: passed to onnx.export(). If True, an extra polygraphy folding pass is done.
+        constant_size_threshold: passed to polygraphy constant folding; defaults to 16 GiB.
+        kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export();
+            else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
             https://pytorch.org/docs/master/generated/torch.jit.script.html.
     """
@@ -642,6 +671,7 @@ def convert_to_onnx(
     if use_trace:
         # let torch.onnx.export to trace the model.
         mode_to_export = model
+        torch_versioned_kwargs = kwargs
     else:
         if not pytorch_after(1, 10):
             if "example_outputs" not in kwargs:
@@ -654,32 +684,37 @@ def convert_to_onnx(
             del kwargs["example_outputs"]
         mode_to_export = torch.jit.script(model, **kwargs)
 
+    if torch.is_tensor(inputs) or isinstance(inputs, dict):
+        onnx_inputs = (inputs,)
+    else:
+        onnx_inputs = tuple(inputs)
+
     if filename is None:
         f = io.BytesIO()
-        torch.onnx.export(
-            mode_to_export,
-            tuple(inputs),
-            f=f,
-            input_names=input_names,
-            output_names=output_names,
-            dynamic_axes=dynamic_axes,
-            opset_version=opset_version,
-            **torch_versioned_kwargs,
-        )
+    else:
+        f = filename
+
+    torch.onnx.export(
+        mode_to_export,
+        onnx_inputs,
+        f=f,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic_axes,
+        opset_version=opset_version,
+        do_constant_folding=do_constant_folding,
+        **torch_versioned_kwargs,
+    )
+    if filename is None:
         onnx_model = onnx.load_model_from_string(f.getvalue())
     else:
-        torch.onnx.export(
-            mode_to_export,
-            tuple(inputs),
-            f=filename,
-            input_names=input_names,
-            output_names=output_names,
-            dynamic_axes=dynamic_axes,
-            opset_version=opset_version,
-            **torch_versioned_kwargs,
-        )
         onnx_model = onnx.load(filename)
+    if do_constant_folding and polygraphy_imported:
+        from polygraphy.backend.onnx.loader import fold_constants
+
+        fold_constants(onnx_model, size_threshold=constant_size_threshold)
+
     if verify:
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -814,7 +849,6 @@ def _onnx_trt_compile(
     """
     trt, _ = optional_import("tensorrt", "8.5.3")
-    torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
 
     input_shapes = (min_shape, opt_shape, max_shape)
     # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
@@ -916,8 +950,6 @@ def convert_to_trt(
         to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
""" - torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0") - if not torch.cuda.is_available(): raise Exception("Cannot find any GPU devices.") @@ -935,23 +967,9 @@ def convert_to_trt( convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] - def scale_batch_size(input_shape: Sequence[int], scale_num: int): - scale_shape = [*input_shape] - scale_shape[0] *= scale_num - return scale_shape - - # Use the dynamic batchsize range to generate the min, opt and max model input shape - if dynamic_batchsize: - min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) - opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) - max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) - else: - min_input_shape = opt_input_shape = max_input_shape = input_shape - # convert the torch model to a TorchScript model on target device model = model.eval().to(target_device) - ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) - ir_model.eval() + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) if use_onnx: # set the batch dim as dynamic @@ -960,7 +978,6 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): ir_model = convert_to_onnx( model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes ) - # convert the model through the ONNX-TensorRT way trt_model = _onnx_trt_compile( ir_model, @@ -973,6 +990,8 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): output_names=onnx_output_names, ) else: + ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) + ir_model.eval() # convert the model through the Torch-TensorRT way ir_model.to(target_device) with torch.no_grad(): @@ -1189,3 +1208,168 @@ def forward(self, x): if dtype == self.initial_type: x = x.to(self.initial_type) return x + + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast a single tensor from from_dtype to to_dtype + """ + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast all tensors in a tuple from from_dtype to to_dtype + """ + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + +class CastToFloat(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with single return vaue + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, x): + dtype = x.dtype + with torch.amp.autocast("cuda", enabled=False): + ret = self.mod.forward(x.to(torch.float32)).to(dtype) + return ret + + +class CastToFloatAll(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with multiple return values + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, *args): + from_dtype = args[0].dtype + with torch.amp.autocast("cuda", enabled=False): + ret = 
+            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+
+
+def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+    """
+    Generic function generator to replace base_t module with dest_t wrapper.
+    Args:
+        base_t : module type to replace
+        dest_t : destination module type
+    Returns:
+        swap function to replace base_t module with dest_t
+    """
+
+    def expansion_fn(mod: nn.Module) -> nn.Module | None:
+        out = dest_t(mod)
+        return out
+
+    return expansion_fn
+
+
+def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+    """
+    Generic function generator to replace base_t module with dest_t.
+    base_t and dest_t should have the same attributes. No weights are copied.
+    Args:
+        base_t : module type to replace
+        dest_t : destination module type
+    Returns:
+        swap function to replace base_t module with dest_t
+    """
+
+    def expansion_fn(mod: nn.Module) -> nn.Module | None:
+        if not isinstance(mod, base_t):
+            return None
+        args = [getattr(mod, name, None) for name in mod.__constants__]
+        out = dest_t(*args)
+        return out
+
+    return expansion_fn
+
+
+def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module:
+    """
+    This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
+    for swapping nested modules through arbitrary levels of children.
+
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+
+    """
+    for path, new_mod in mapping.items():
+        expanded_path = path.split(".")
+        parent_mod = model
+        for sub_path in expanded_path[:-1]:
+            submod = parent_mod._modules[sub_path]
+            if submod is None:
+                break
+            else:
+                parent_mod = submod
+        parent_mod._modules[expanded_path[-1]] = new_mod
+
+    return model
+
+
+def replace_modules_by_type(
+    model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]
+) -> nn.Module:
+    """
+    Top-level function to replace modules in model, specified by class name with a desired replacement.
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+        expansions : replacement dictionary: module class name -> replacement function generator
+    Returns:
+        model, possibly modified in-place
+    """
+    mapping: dict[str, nn.Module] = {}
+    for name, m in model.named_modules():
+        m_type = type(m).__name__
+        if m_type in expansions:
+            # print (f"Found {m_type} in expansions ...")
+            swapped = expansions[m_type](m)
+            if swapped:
+                mapping[name] = swapped
+
+    print(f"Swapped {len(mapping)} modules")
+    _swap_modules(model, mapping)
+    return model
+
+
+def add_casts_around_norms(model: nn.Module) -> nn.Module:
+    """
+    Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+    Returns:
+        model, possibly modified in-place
+    """
+    print("Adding casts around norms...")
+    cast_replacements = {
+        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+        "BatchNorm3d": wrap_module(nn.BatchNorm3d, CastToFloat),
+        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+        "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+    }
+    replace_modules_by_type(model, cast_replacements)
+    return model
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 9aad0804e6..6d0ccd378a 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -59,3 +59,5 @@ nvidia-ml-py
 huggingface_hub
 pyamg>=5.0.0
 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
+onnx_graphsurgeon
+polygraphy
diff --git a/setup.cfg b/setup.cfg
index 1ce4a3f34c..c97118d43a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -160,6 +160,9 @@ lpips =
     lpips==0.1.4
 pynvml =
     nvidia-ml-py
+polygraphy =
+    polygraphy
+
 # # workaround https://github.com/Project-MONAI/MONAI/issues/5882
 # MetricsReloaded =
 #     MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
diff --git a/tests/min_tests.py b/tests/min_tests.py
index f80d06f5d3..632355b5c6 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -186,6 +186,7 @@ def run_testsuit():
         "test_torchvisiond",
         "test_transchex",
         "test_transformerblock",
+        "test_trt_compile",
         "test_unetr",
         "test_unetr_block",
         "test_vit",
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index cf1edc8f08..2b00c9f9d1 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -125,6 +125,22 @@ def __call__(self, a, b):
     [0, 4],
 ]
 
+TEST_CASE_MERGE_JSON = ["""{"key1": [0], "key2": [0] }""", """{"key1": [1], "+key2": [4] }""", "json", [1], [0, 4]]
+
+TEST_CASE_MERGE_YAML = [
+    """
+    key1: 0
+    key2: [0]
+    """,
+    """
+    key1: 1
+    +key2: [4]
+    """,
+    "yaml",
+    1,
+    [0, 4],
+]
+
 
 class TestConfigParser(unittest.TestCase):
 
@@ -357,6 +373,22 @@ def test_parse_json_warn(self, config_string, extension, expected_unique_val, ex
         self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val)
         self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals)
 
+    @parameterized.expand([TEST_CASE_MERGE_JSON, TEST_CASE_MERGE_YAML])
+    @skipUnless(has_yaml, "Requires pyyaml")
+    def test_load_configs(
+        self, config_string, config_string2, extension, expected_overridden_val, expected_merged_vals
+    ):
+        with tempfile.TemporaryDirectory() as tempdir:
+            config_path1 = Path(tempdir) / f"config1.{extension}"
+            config_path2 = Path(tempdir) / f"config2.{extension}"
+            config_path1.write_text(config_string)
+            config_path2.write_text(config_string2)
+
+            parser = ConfigParser.load_config_files([config_path1, config_path2])
+
+            self.assertEqual(parser["key1"], expected_overridden_val)
+            self.assertEqual(parser["key2"], expected_merged_vals)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py
index 903f9bd2ca..fb8f5dda72 100644
--- a/tests/test_sure_loss.py
+++ b/tests/test_sure_loss.py
@@ -65,7 +65,7 @@ def operator(x):
         loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False)
         loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True)
 
-        self.assertAlmostEqual(loss_real.item(),
loss_complex.abs().item(), places=6) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=5) if __name__ == "__main__": diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py new file mode 100644 index 0000000000..21125d203f --- /dev/null +++ b/tests/test_trt_compile.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import tempfile +import unittest + +import torch +from parameterized import parameterized + +from monai.handlers import TrtHandler +from monai.networks import trt_compile +from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 +from monai.utils import optional_import +from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows + +trt, trt_imported = optional_import("tensorrt") +polygraphy, polygraphy_imported = optional_import("polygraphy") +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +TEST_CASE_1 = ["fp32"] +TEST_CASE_2 = ["fp16"] + + +@skip_if_windows +@skip_if_no_cuda +@skip_if_quick +@unittest.skipUnless(trt_imported, "tensorrt is required") +@unittest.skipUnless(polygraphy_imported, "polygraphy is required") +class TestTRTCompile(unittest.TestCase): + + def setUp(self): + self.gpu_device = torch.cuda.current_device() + + def tearDown(self): + current_device = torch.cuda.current_device() + if current_device != self.gpu_device: + torch.cuda.set_device(self.gpu_device) + + def test_handler(self): + from ignite.engine import Engine + + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + net1.load_state_dict(data1) + net1.cuda() + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + args = {"method": "torch_trt"} + TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine) + engine.run([0] * 8, max_epochs=1) + self.assertIsNotNone(net1._trt_compiler) + self.assertIsNone(net1._trt_compiler.engine) + net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) + self.assertIsNotNone(net1._trt_compiler.engine) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_unet_value(self, precision): + model = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(2, 2, 4, 8, 4), + strides=(2, 2, 2, 2), + num_res_units=2, + norm="batch", + ).cuda() + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(2, 1, 96, 96, 96).cuda() + output_example = model(input_example) + args: dict = {"builder_optimization_level": 1} + trt_compile( + model, + f"{tmpdir}/test_unet_trt_compile", + args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + 
torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @unittest.skipUnless(has_sam, "Requires SAM installation") + def test_cell_sam_wrapper_value(self, precision): + model = cell_sam_wrapper.CellSamWrapper(checkpoint=None).to("cuda") + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(1, 3, 128, 128).to("cuda") + output_example = model(input_example) + trt_compile( + model, + f"{tmpdir}/test_cell_sam_wrapper_trt_compile", + args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_vista3d(self, precision): + model = vista3d132(in_channels=1).to("cuda") + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(1, 1, 64, 64, 64).to("cuda") + output_example = model(input_example) + model = trt_compile( + model, + f"{tmpdir}/test_vista3d_trt_compile", + args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + submodule=["image_encoder.encoder", "class_head"], + ) + self.assertIsNotNone(model.image_encoder.encoder._trt_compiler) + self.assertIsNotNone(model.class_head._trt_compiler) + trt_output = model.forward(input_example) + # Check that lazy TRT build succeeded + # TODO: set up input_example in such a way that image_encoder.encoder and class_head are called + # and uncomment the asserts below + # self.assertIsNotNone(model.image_encoder.encoder._trt_compiler.engine) + # self.assertIsNotNone(model.class_head._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + unittest.main()
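For reference, a minimal sketch of the new `trt_compile()` workflow, following the patterns in `tests/test_trt_compile.py` above (the UNet configuration and plan path are illustrative, not prescriptive):

```python
import torch

from monai.networks import trt_compile
from monai.networks.nets import UNet

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(2, 2, 4, 8, 4),
    strides=(2, 2, 2, 2),
).cuda().eval()

# Patches model.forward() in place; the TRT plan is built lazily on the first
# forward() call and cached at "/tmp/unet.plan" for subsequent runs.
trt_compile(model, "/tmp/unet", args={"precision": "fp16", "dynamic_batchsize": [1, 4, 8]})

with torch.no_grad():
    output = model(torch.randn(2, 1, 96, 96, 96).cuda())
```

If TensorRT or polygraphy are not available, `trt_compile()` logs a warning and returns the model unchanged, so the same call is safe in configs run on machines without TRT.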