Skip to content

Commit

Permalink
Added TRTWrapper (#7990)
Browse files Browse the repository at this point in the history
### Description

Added alternative class to ONNX->TRT export and wrap TRT engines for
inference.
It encapsulates filesystem persistence and does not rely on
torch-tensortrt for execution.
Also can be used to run ONNX with onnxruntime.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: binliunls <[email protected]>
  • Loading branch information
6 people authored Sep 1, 2024
1 parent fa1ef8b commit c9f8d32
Show file tree
Hide file tree
Showing 17 changed files with 1,121 additions and 51 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/*
# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
ENV PATH=${PATH}:/opt/tools
ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
WORKDIR /opt/monai
42 changes: 42 additions & 0 deletions docs/source/config_syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Content:
- [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)
- [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)
- [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object)
- [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files)
- [The command line interface](#the-command-line-interface)
- [Recommendations](#recommendations)

Expand Down Expand Up @@ -175,6 +176,47 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).

## Multiple config files

_Description:_ Multiple config files may be specified on the command line.
The content of those config files is being merged. When same keys are specifiled in more than one config file,
the value associated with the key is being overridden, in the order config files are specified.
If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with `+`.
The value types for the merged contents must match and be both of `dict` or both of `list` type.
`dict` values will be merged via update(), `list` values - concatenated via extend().
Here's an example. In this case, "amp" value will be overridden by extra_config.json.
`imports` and `preprocessing#transforms` lists will be merged. An error would be thrown if the value type in `"+imports"` is not `list`:

config.json:
```json
{
"amp": "$True"
"imports": [
"$import torch"
],
"preprocessing": {
"_target_": "Compose",
"transforms": [
"$@t1",
"$@t2"
]
},
}
```

extra_config.json:
```json
{
"amp": "$False"
"+imports": [
"$from monai.networks import trt_compile"
],
"+preprocessing#transforms": [
"$@t3"
]
}
```

## The command line interface

In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.
Expand Down
8 changes: 5 additions & 3 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.reference_resolver import ReferenceResolver
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv
from monai.config import PathLike
from monai.utils import ensure_tuple, look_up_option, optional_import
from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates
Expand Down Expand Up @@ -423,8 +423,10 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs
if isinstance(files, str) and not Path(files).is_file() and "," in files:
files = files.split(",")
for i in ensure_tuple(files):
for k, v in (cls.load_config_file(i, **kwargs)).items():
parser[k] = v
config_dict = cls.load_config_file(i, **kwargs)
for k, v in config_dict.items():
merge_kv(parser, k, v)

return parser.get() # type: ignore

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
from monai.config import IgniteInfo, PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
Expand Down Expand Up @@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
if isinstance(v, dict) and isinstance(args_.get(k), dict):
args_[k] = update_kwargs(args_[k], ignore_none, **v)
else:
args_[k] = v
merge_kv(args_, k, v)
return args_


Expand Down
36 changes: 35 additions & 1 deletion monai/bundle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import json
import os
import warnings
import zipfile
from typing import Any

Expand All @@ -21,12 +22,21 @@

yaml, _ = optional_import("yaml")

__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"]
__all__ = [
"ID_REF_KEY",
"ID_SEP_KEY",
"EXPR_KEY",
"MACRO_KEY",
"MERGE_KEY",
"DEFAULT_MLFLOW_SETTINGS",
"DEFAULT_EXP_MGMT_SETTINGS",
]

ID_REF_KEY = "@" # start of a reference to a ConfigItem
ID_SEP_KEY = "::" # separator for the ID of a ConfigItem
EXPR_KEY = "$" # start of a ConfigExpression
MACRO_KEY = "%" # start of a macro of a config
MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs.

_conf_values = get_config_values()

Expand Down Expand Up @@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
parser.read_config(f=cdata)

return parser


def merge_kv(args: dict | Any, k: str, v: Any) -> None:
"""
Update the `args` dict-like object with the key/value pair `k` and `v`.
"""
if k.startswith(MERGE_KEY):
"""
Both values associated with `+`-prefixed key pair must be of `dict` or `list` type.
`dict` values will be merged, `list` values - concatenated.
"""
id = k[1:]
if id in args:
if isinstance(v, dict) and isinstance(args[id], dict):
args[id].update(v)
elif isinstance(v, list) and isinstance(args[id], list):
args[id].extend(v)
else:
raise ValueError(ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}."))
else:
warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.")
args[id] = v
else:
args[k] = v
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
from .trt_handler import TrtHandler
from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
from .validation_handler import ValidationHandler
61 changes: 61 additions & 0 deletions monai/handlers/trt_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from monai.config import IgniteInfo
from monai.networks import trt_compile
from monai.utils import min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")


class TrtHandler:
"""
TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.
Usage example::
handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
handler.attach(engine)
engine.run()
"""

def __init__(self, model, base_path, args=None, submodule=None):
"""
Args:
base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
args: passed to trt_compile(). See trt_compile() for details.
submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'
"""
self.model = model
self.base_path = base_path
self.args = args
self.submodule = submodule

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
self.logger = engine.logger
engine.add_event_handler(Events.STARTED, self)

def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger)
2 changes: 2 additions & 0 deletions monai/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

from __future__ import annotations

from .trt_compiler import trt_compile
from .utils import (
add_casts_around_norms,
convert_to_onnx,
convert_to_torchscript,
convert_to_trt,
Expand Down
8 changes: 4 additions & 4 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _check_input_size(self, spatial_shape):
)

def forward(self, x_in):
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
self._check_input_size(x_in.shape[2:])
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
Expand Down Expand Up @@ -1046,14 +1046,14 @@ def __init__(

def proj_out(self, x, normalize=False):
if normalize:
x_shape = x.size()
x_shape = x.shape
# Force trace() to generate a constant by casting to int
ch = int(x_shape[1])
if len(x_shape) == 5:
n, ch, d, h, w = x_shape
x = rearrange(x, "n c d h w -> n d h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n d h w c -> n c d h w")
elif len(x_shape) == 4:
n, ch, h, w = x_shape
x = rearrange(x, "n c h w -> n h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n h w c -> n c h w")
Expand Down
Loading

0 comments on commit c9f8d32

Please sign in to comment.