Skip to content

Commit

Permalink
Add vista network (#7987)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

Add VISTA3D model architecture to MONAI core

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: Yufan He <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
5 people authored Aug 15, 2024
1 parent e85580a commit 77304dd
Show file tree
Hide file tree
Showing 7 changed files with 1,189 additions and 35 deletions.
10 changes: 10 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@ Nets
.. autoclass:: SegResNetDS
:members:

`SegResNetDS2`
~~~~~~~~~~~~~~
.. autoclass:: SegResNetDS2
:members:

`SegResNetVAE`
~~~~~~~~~~~~~~
.. autoclass:: SegResNetVAE
Expand Down Expand Up @@ -556,6 +561,11 @@ Nets
.. autoclass:: UNETR
:members:

`VISTA3D`
~~~~~~~~~
.. autoclass:: VISTA3D
:members:

`SwinUNETR`
~~~~~~~~~~~
.. autoclass:: SwinUNETR
Expand Down
3 changes: 2 additions & 1 deletion monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
resnet200,
)
from .segresnet import SegResNet, SegResNetVAE
from .segresnet_ds import SegResNetDS
from .segresnet_ds import SegResNetDS, SegResNetDS2
from .senet import (
SENet,
SEnet,
Expand Down Expand Up @@ -118,6 +118,7 @@
from .unet import UNet, Unet
from .unetr import UNETR
from .varautoencoder import VarAutoEncoder
from .vista3d import VISTA3D, vista3d132
from .vit import ViT
from .vitautoenc import ViTAutoEnc
from .vnet import VNet
Expand Down
128 changes: 127 additions & 1 deletion monai/networks/nets/segresnet_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import copy
from collections.abc import Callable
from typing import Union

Expand All @@ -23,7 +24,7 @@
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode, has_option

__all__ = ["SegResNetDS"]
__all__ = ["SegResNetDS", "SegResNetDS2"]


def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
Expand Down Expand Up @@ -425,3 +426,128 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens

def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
return self._forward(x)


class SegResNetDS2(SegResNetDS):
"""
SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D
<https://arxiv.org/abs/2406.05285>`_.
Args:
spatial_dims: spatial dimension of the input data. Defaults to 3.
init_filters: number of output channels for initial convolution layer. Defaults to 32.
in_channels: number of input channels for the network. Defaults to 1.
out_channels: number of output channels for the network. Defaults to 2.
act: activation type and arguments. Defaults to ``RELU``.
norm: feature normalization type and arguments. Defaults to ``BATCH``.
blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
blocks_up: number of upsample blocks (optional).
dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
At dsdepth==1,only a single output is returned.
preprocess: optional callable function to apply before the model's forward pass
resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
image spacing into an approximately isotropic space.
Otherwise, by default, the kernel size and downsampling is always isotropic.
"""

def __init__(
self,
spatial_dims: int = 3,
init_filters: int = 32,
in_channels: int = 1,
out_channels: int = 2,
act: tuple | str = "relu",
norm: tuple | str = "batch",
blocks_down: tuple = (1, 2, 2, 4),
blocks_up: tuple | None = None,
dsdepth: int = 1,
preprocess: nn.Module | Callable | None = None,
upsample_mode: UpsampleMode | str = "deconv",
resolution: tuple | None = None,
):
super().__init__(
spatial_dims=spatial_dims,
init_filters=init_filters,
in_channels=in_channels,
out_channels=out_channels,
act=act,
norm=norm,
blocks_down=blocks_down,
blocks_up=blocks_up,
dsdepth=dsdepth,
preprocess=preprocess,
upsample_mode=upsample_mode,
resolution=resolution,
)

self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers])

def forward( # type: ignore
self, x: torch.Tensor, with_point: bool = True, with_label: bool = True
) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]:
"""
Args:
x: input tensor.
with_point: if true, return the point branch output.
with_label: if true, return the label branch output.
"""
if self.preprocess is not None:
x = self.preprocess(x)

if not self.is_valid_shape(x):
raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")

x_down = self.encoder(x)

x_down.reverse()
x = x_down.pop(0)

if len(x_down) == 0:
x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]

outputs: list[torch.Tensor] = []
outputs_auto: list[torch.Tensor] = []
x_ = x.clone()
if with_point:
i = 0
for level in self.up_layers:
x = level["upsample"](x)
x = x + x_down[i]
x = level["blocks"](x)

if len(self.up_layers) - i <= self.dsdepth:
outputs.append(level["head"](x))
i = i + 1

outputs.reverse()
x = x_
if with_label:
i = 0
for level in self.up_layers_auto:
x = level["upsample"](x)
x = x + x_down[i]
x = level["blocks"](x)

if len(self.up_layers) - i <= self.dsdepth:
outputs_auto.append(level["head"](x))
i = i + 1

outputs_auto.reverse()

return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto

def set_auto_grad(self, auto_freeze=False, point_freeze=False):
"""
Args:
auto_freeze: if true, freeze the image encoder and the auto-branch.
point_freeze: if true, freeze the image encoder and the point-branch.
"""
for param in self.encoder.parameters():
param.requires_grad = (not auto_freeze) and (not point_freeze)

for param in self.up_layers_auto.parameters():
param.requires_grad = not auto_freeze

for param in self.up_layers.parameters():
param.requires_grad = not point_freeze
Loading

0 comments on commit 77304dd

Please sign in to comment.