Commit: Bm/fix-config-nested-new (#102)
* implement test behavior with no test inputs

* docs

* Refactor code in deeplay/activelearning/data.py, deeplay/activelearning/strategies/uncertainty.py, deeplay/applications/application.py, and deeplay/activelearning/strategies/strategy.py

* clear config before multi

* update to use absolute imports

* Refactor import statements in test_selectors.py

* Bm/fix-config-nested-new
Fixes an issue where multi-blocks did not have their configuration correctly cleared when created multiple times.

* remove test file

* Fix issue with clearing configuration before creating multiple blocks

* Fix issue with configuring upsample in Conv2dBlock

* Re-enable test_strided_multi in test_conv.py

* add stubs

* Refactor residual function in Conv2dBlock to support flexible layer order

* Refactor available_styles method in DeeplayModule to use classmethod

* Implement script to create stubs with style typing

* Refactor Conv2dBlock and related functions in conv2d.pyi

* Remove publish script from package.json

* Add .gitignore entry for package.json

---------

Co-authored-by: Giovanni Volpe <[email protected]>
BenjaminMidtvedt and giovannivolpe authored May 8, 2024
1 parent a4bf91a commit 0013dff
Showing 50 changed files with 3,972 additions and 174 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -19,4 +19,6 @@ dist/
 *.egg-info
 
 # Mac
-*.DS_Store
+*.DS_Store
+
+package.json
3 changes: 3 additions & 0 deletions deeplay/activelearning/data.py
@@ -56,6 +56,9 @@ def get_unannotated_labels(self):
     def get_unannotated_data(self):
         return torch.utils.data.Subset(self.dataset, np.where(~self.annotated)[0])
 
+    def get_num_annotated(self):
+        return np.sum(self.annotated)
+
 
 class JointDataset(torch.utils.data.Dataset):
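
Note: a standalone sketch of the boolean-mask bookkeeping these methods rely on (illustrative shapes only, not the actual ActiveLearningDataset class):

    import numpy as np

    annotated = np.zeros(10, dtype=bool)
    annotated[[0, 3]] = True                     # two samples have been labeled

    num_annotated = np.sum(annotated)            # 2, as in get_num_annotated
    unannotated_idx = np.where(~annotated)[0]    # indices used for the Subset
    assert num_annotated == 2 and len(unannotated_idx) == 8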
25 changes: 13 additions & 12 deletions deeplay/activelearning/strategies/strategy.py
@@ -3,6 +3,7 @@
 from deeplay.activelearning.data import ActiveLearningDataset
 
 import torch
+import copy
 
 
 class Strategy(Application):
@@ -11,7 +12,7 @@ def __init__(
         self,
         train_pool: ActiveLearningDataset,
         val_pool: Optional[ActiveLearningDataset] = None,
-        test: Optional[torch.utils.data.Dataset] = None,
+        test_data: Optional[torch.utils.data.Dataset] = None,
         batch_size: int = 32,
         val_batch_size: Optional[int] = None,
         test_batch_size: Optional[int] = None,
@@ -20,7 +21,7 @@
         super().__init__(**kwargs)
         self.train_pool = train_pool
         self.val_pool = val_pool
-        self.test = test
+        self.test_data = test_data
         self.initial_model_state: Optional[Dict[str, Any]] = None
         self.batch_size = batch_size
         self.val_batch_size = (
@@ -34,7 +35,7 @@ def on_train_start(self) -> None:
         # Save the initial model state before training
         # such that we can reset the model to its initial state
         # if needed.
-        self.initial_model_state = self.state_dict()
+        self.initial_model_state = copy.deepcopy(self.state_dict())
         self.train()
 
         return super().on_train_start()
@@ -91,19 +92,19 @@ def train_dataloader(self):
             data, batch_size=self.batch_size, shuffle=True
         )
 
-    def val_dataloader(self):
-        if self.val_pool is None:
-            return []
-        data = self.train_pool.get_unannotated_data()
-        return torch.utils.data.DataLoader(
-            data, batch_size=self.val_batch_size, shuffle=False
-        )
+    # def val_dataloader(self):
+    #     if self.val_pool is None:
+    #         return []
+    #     data = self.train_pool.get_unannotated_data()
+    #     return torch.utils.data.DataLoader(
+    #         data, batch_size=self.val_batch_size, shuffle=False
+    #     )
 
     def test_dataloader(self):
-        if self.test is None:
+        if self.test_data is None:
             return []
         return torch.utils.data.DataLoader(
-            self.test, batch_size=self.test_batch_size, shuffle=False
+            self.test_data, batch_size=self.test_batch_size, shuffle=False
         )
 
     def test_step(self, batch, batch_idx):
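
Note: why the deepcopy matters — state_dict() returns references to the live parameter tensors, so a snapshot taken without copying changes as soon as training updates the model. A minimal standalone PyTorch sketch (not deeplay code):

    import copy
    import torch

    model = torch.nn.Linear(2, 1)
    shallow = model.state_dict()                   # shares storage with the model
    snapshot = copy.deepcopy(model.state_dict())   # independent copy

    with torch.no_grad():
        model.weight.add_(1.0)                     # simulate a training update

    print(torch.equal(shallow["weight"], model.weight))   # True: mutated along with the model
    print(torch.equal(snapshot["weight"], model.weight))  # False: a true snapshot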
3 changes: 2 additions & 1 deletion deeplay/activelearning/strategies/uncertainty.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from deeplay.activelearning.strategies.strategy import Strategy
 from deeplay.activelearning.data import ActiveLearningDataset
 from deeplay.activelearning.criterion import ActiveLearningCriterion
@@ -15,7 +16,7 @@ def __init__(
         classifier: DeeplayModule,
         criterion: ActiveLearningCriterion,
         train_pool: ActiveLearningDataset,
-        val_pool: ActiveLearningDataset = None,
+        val_pool: Optional[ActiveLearningDataset] = None,
         test: torch.utils.data.Dataset = None,
         batch_size: int = 32,
         val_batch_size: int = None,
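
Note: the Optional annotation is what strict type checkers require here — a default of None is not itself a valid ActiveLearningDataset. A generic sketch of the rule (hypothetical function names):

    from typing import Optional

    def f(pool: Optional[int] = None): ...  # accepted
    def g(pool: int = None): ...            # flagged under strict / no-implicit-optional mode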
5 changes: 2 additions & 3 deletions deeplay/applications/application.py
@@ -156,7 +156,7 @@ def test(
             Tuple[str, tm.Metric],
             Sequence[Union[tm.Metric, Tuple[str, tm.Metric]]],
             Dict[str, tm.Metric],
-            None
+            None,
         ] = None,
         batch_size: int = 32,
         reset_metrics: bool = True,
@@ -187,9 +187,8 @@
         )
 
         dict_metrics: Dict[str, tm.Metric]
-
         if metrics is None:
-            return self.trainer.test(self, test_dataloader)
+            return self.trainer.test(self, test_dataloader)[0]
 
         if isinstance(metrics, tm.Metric):
             dict_metrics = {metrics._get_name(): metrics}
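
Note: the [0] relies on Lightning's Trainer.test returning a list of metric dicts, one per test dataloader; with a single dataloader, indexing unwraps it. Illustrative values only:

    results = [{"test_loss": 0.12, "test_acc": 0.95}]  # shape of Trainer.test output
    metrics = results[0]                               # the dict callers actually want
    assert metrics["test_acc"] == 0.95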
6 changes: 3 additions & 3 deletions deeplay/applications/autoencoders/vae.py
@@ -1,8 +1,8 @@
 from typing import Optional, Sequence, Callable, List
 
-from ...components import ConvolutionalEncoder2d, ConvolutionalDecoder2d
-from ..application import Application
-from ...external import External, Optimizer, Adam
+from deeplay.components import ConvolutionalEncoder2d, ConvolutionalDecoder2d
+from deeplay.applications import Application
+from deeplay.external import External, Optimizer, Adam
 
 
 import torch
7 changes: 3 additions & 4 deletions deeplay/applications/autoencoders/wae.py
@@ -1,9 +1,8 @@
 from typing import Optional, Sequence, Callable, List
 
-from ...components import ConvolutionalEncoder2d, ConvolutionalDecoder2d
-from ..application import Application
-from ...external import External, Optimizer, Adam
-from ... import Layer
+from deeplay.components import ConvolutionalEncoder2d, ConvolutionalDecoder2d
+from deeplay.applications import Application
+from deeplay.external import External, Optimizer, Adam, Layer
 
 import torch
 import torch.nn as nn
6 changes: 2 additions & 4 deletions deeplay/applications/classification/binary.py
@@ -1,7 +1,7 @@
 from typing import Optional, Sequence
 
-from ..application import Application
-from ...external import Optimizer, Adam
+from deeplay.applications import Application
+from deeplay.external import Optimizer, Adam
 
 
 import torch
@@ -36,12 +36,10 @@ def __init__(
     def params(self):
         return self.model.parameters()
 
-
     def compute_loss(self, y_hat, y):
         if isinstance(self.loss, (torch.nn.BCELoss, torch.nn.BCEWithLogitsLoss)):
             y = y.float()
         return super().compute_loss(y_hat, y)
 
-
     def forward(self, x):
         return self.model(x)
4 changes: 2 additions & 2 deletions deeplay/applications/classification/categorical.py
@@ -1,7 +1,7 @@
 from typing import Optional, Sequence
 
-from ..application import Application
-from ...external import Optimizer, Adam
+from deeplay.applications import Application
+from deeplay.external import Optimizer, Adam
 
 
 import torch
4 changes: 2 additions & 2 deletions deeplay/applications/classification/classifier.py
@@ -1,7 +1,7 @@
 from typing import Optional, Sequence
 
-from ..application import Application
-from ...external import External, Optimizer, Adam
+from deeplay.applications import Application
+from deeplay.external import External, Optimizer, Adam
 
 
 import torch
4 changes: 2 additions & 2 deletions deeplay/applications/classification/multilabel.py
@@ -1,7 +1,7 @@
 from typing import Optional, Sequence
 
-from ..application import Application
-from ...external import Optimizer, Adam
+from deeplay.applications import Application
+from deeplay.external import Optimizer, Adam
 
 
 import torch
4 changes: 2 additions & 2 deletions deeplay/applications/detection/lodestar/lodestar.py
@@ -7,8 +7,8 @@
 import torch.nn as nn
 from skimage import morphology
 
-from ....components import ConvolutionalNeuralNetwork
-from ...application import Application
+from deeplay.components import ConvolutionalNeuralNetwork
+from deeplay.applications import Application
 from .transforms import RandomRotation2d, RandomTranslation2d, Transforms
4 changes: 2 additions & 2 deletions deeplay/applications/regression/regressor.py
@@ -1,7 +1,7 @@
 from typing import Optional, Sequence
 
-from ..application import Application
-from ...external import External, Optimizer, Adam
+from deeplay.applications import Application
+from deeplay.external import External, Optimizer, Adam
 
 
 import torch
7 changes: 7 additions & 0 deletions deeplay/blocks/base.py
@@ -43,6 +43,13 @@ def __init__(self, *args, **kwargs):
         super(BaseBlock, self).__init__(*args, **kwargs)
 
     def multi(self, n=1) -> Self:
+
+        # Remove configurations before making new blocks
+        tags = self.tags
+        for key, vlist in self._user_config.items():
+            if key[:-1] in tags:
+                vlist.clear()
+
         def make_new_self():
            args, kwargs = self.get_init_args()
            args = list(args) + list(self._args)
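
Note: a standalone sketch of the clearing logic above (hypothetical data, not the actual deeplay classes). Config entries are keyed by a tag tuple plus an attribute name; entries scoped to this block — key minus its last element matching one of the block's tags — are cleared so a repeated multi() call starts from a clean slate:

    user_config = {
        ("block", "0", "out_channels"): [64],   # stale per-child configuration
        ("other", "lr"): [1e-3],                # unrelated entry, must survive
    }
    tags = [("block", "0"), ("block", "1")]     # tags of the block being rebuilt

    for key, vlist in user_config.items():
        if key[:-1] in tags:
            vlist.clear()

    assert user_config[("block", "0", "out_channels")] == []
    assert user_config[("other", "lr")] == [1e-3]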
2 changes: 1 addition & 1 deletion deeplay/blocks/block.py
@@ -12,7 +12,7 @@
 
 import warnings
 
-from ..module import DeeplayModule
+from deeplay import DeeplayModule
 from deeplay.external import Layer
33 changes: 26 additions & 7 deletions deeplay/blocks/conv/conv2d.py
@@ -12,6 +12,7 @@
 from deeplay.ops.shape import Permute
 from deeplay.blocks.base import DeferredConfigurableLayer
 
+
 class Conv2dBlock(BaseBlock):
     """Convolutional block with optional normalization and activation."""
@@ -106,9 +107,10 @@ def upsampled(
         after=None,
     ) -> Self:
         upsample = upsample.new()
-        upsample.configure(
-            in_channels=self.out_channels, out_channels=self.out_channels
-        )
+        if "in_channels" in upsample.configurables:
+            upsample.configure(in_channels=self.out_channels)
+        if "out_channels" in upsample.configurables:
+            upsample.configure(out_channels=self.out_channels)
         self.set("upsample", upsample, mode=mode, after=after)
         return self
@@ -194,10 +196,27 @@ def _assert_valid_configurable(self, *args):
 def residual(
     block: Conv2dBlock,
     order: str = "lanlan|",
-    activation=nn.ReLU,
-    normalization=nn.BatchNorm2d,
-    dropout=0.1,
+    activation: Union[Type[nn.Module], Layer] = nn.ReLU,
+    normalization: Union[Type[nn.Module], Layer] = nn.BatchNorm2d,
+    dropout: float = 0.1,
 ):
+    """Make a residual block with the given order of layers.
+
+    Parameters
+    ----------
+    order : str
+        The order of layers in the residual block. The shorthand is a string of 'l', 'a', 'n', 'd' and '|'.
+        'l' stands for layer, 'a' stands for activation, 'n' stands for normalization, 'd' stands for dropout,
+        and '|' stands for the skip connection. The order of the characters in the string determines the order
+        of the layers in the residual block. The characters after the '|' determine the order of the layers after
+        the skip connection.
+    activation : Union[Type[nn.Module], Layer]
+        The activation function to use in the residual block.
+    normalization : Union[Type[nn.Module], Layer]
+        The normalization layer to use in the residual block.
+    dropout : float
+        The dropout rate to use in the residual block.
+    """
     order = order.lower()
     if "|" not in order:
         order += "|"
@@ -324,7 +343,7 @@ def spatial_cross_attention(
 
 
 @Conv2dBlock.register_style
-def spatial_tranformer(
+def spatial_transformer(
     block: Conv2dBlock,
     to_channel_last: bool = False,
     normalization: Union[Layer, Type[nn.Module]] = nn.LayerNorm,
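
Note: a hedged usage sketch for the residual style (assumes the package layout shown in this commit; the constructor signature follows the stub below):

    from deeplay.blocks.conv.conv2d import Conv2dBlock

    block = Conv2dBlock(in_channels=64, out_channels=64, kernel_size=3, padding=1)
    # "lanlan|": layer, activation, norm, layer, activation, norm, then the
    # skip connection is added; nothing follows the '|'.
    block.style("residual", order="lanlan|", dropout=0.1)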
63 changes: 63 additions & 0 deletions deeplay/blocks/conv/conv2d.pyi
@@ -0,0 +1,63 @@
+from typing import Literal, Type, Union, Optional, overload
+import torch.nn as nn
+from _typeshed import Incomplete
+from deeplay.blocks.base import BaseBlock as BaseBlock, DeferredConfigurableLayer as DeferredConfigurableLayer
+from deeplay.external import Layer as Layer
+from deeplay.module import DeeplayModule as DeeplayModule
+from deeplay.ops.logs import FromLogs as FromLogs
+from deeplay.ops.merge import Add as Add, MergeOp as MergeOp
+from deeplay.ops.shape import Permute as Permute
+from typing import Literal, Type
+from typing_extensions import Self
+
+class Conv2dBlock(BaseBlock):
+    pool: DeferredConfigurableLayer | nn.Module
+    in_channels: Incomplete
+    out_channels: Incomplete
+    kernel_size: Incomplete
+    stride: Incomplete
+    padding: Incomplete
+    def __init__(self, in_channels: int | None, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0, **kwargs) -> None: ...
+    def normalized(self, normalization: Type[nn.Module] | DeeplayModule = ..., mode: str = 'append', after: Incomplete | None = None) -> Self: ...
+    def pooled(self, pool: Layer = ..., mode: str = 'prepend', after: Incomplete | None = None) -> Self: ...
+    def upsampled(self, upsample: Layer = ..., mode: str = 'append', after: Incomplete | None = None) -> Self: ...
+    def transposed(self, transpose: Layer = ..., mode: str = 'prepend', after: Incomplete | None = None, remove_upsample: bool = True, remove_layer: bool = True) -> Self: ...
+    def strided(self, stride: int | tuple[int, ...], remove_pool: bool = True) -> Self: ...
+    def multi(self, n: int = 1) -> Self: ...
+    def shortcut(self, merge: MergeOp = ..., shortcut: Literal['auto'] | Type[nn.Module] | DeeplayModule | None = 'auto') -> Self: ...
+    @overload
+    def style(self, style: Literal["residual"], order: str = "lanlan|", activation: Union[Type[nn.Module], Layer] = ..., normalization: Union[Type[nn.Module], Layer] = ..., dropout: float = 0.1) -> Self:
+        """Make a residual block with the given order of layers.
+
+        Parameters
+        ----------
+        order : str
+            The order of layers in the residual block. The shorthand is a string of 'l', 'a', 'n', 'd' and '|'.
+            'l' stands for layer, 'a' stands for activation, 'n' stands for normalization, 'd' stands for dropout,
+            and '|' stands for the skip connection. The order of the characters in the string determines the order
+            of the layers in the residual block. The characters after the '|' determine the order of the layers after
+            the skip connection.
+        activation : Union[Type[nn.Module], Layer]
+            The activation function to use in the residual block.
+        normalization : Union[Type[nn.Module], Layer]
+            The normalization layer to use in the residual block.
+        dropout : float
+            The dropout rate to use in the residual block.
+        """
+    @overload
+    def style(self, style: Literal["spatial_self_attention"], to_channel_last: bool = False, normalization: Union[Layer, Type[nn.Module]] = ...) -> Self: ...
+    @overload
+    def style(self, style: Literal["spatial_cross_attention"], to_channel_last: bool = False, normalization: Union[Layer, Type[nn.Module]] = ..., condition_name: str = "condition") -> Self: ...
+    @overload
+    def style(self, style: Literal["spatial_transformer"], to_channel_last: bool = False, normalization: Union[Layer, Type[nn.Module]] = ..., condition_name: Optional[str] = "condition") -> Self: ...
+    @overload
+    def style(self, style: Literal["resnet"], stride: int = 1) -> Self: ...
+    @overload
+    def style(self, style: Literal["resnet18_input"]) -> Self: ...
+    def style(self, style: str, **kwargs) -> Self: ...
+
+def residual(block: Conv2dBlock, order: str = 'lanlan|', activation: Type[nn.Module] | Layer = ..., normalization: Type[nn.Module] | Layer = ..., dropout: float = 0.1): ...
+def spatial_self_attention(block: Conv2dBlock, to_channel_last: bool = False, normalization: Layer | Type[nn.Module] = ...): ...
+def spatial_cross_attention(block: Conv2dBlock, to_channel_last: bool = False, normalization: Layer | Type[nn.Module] = ..., condition_name: str = 'condition'): ...
+def spatial_transformer(block: Conv2dBlock, to_channel_last: bool = False, normalization: Layer | Type[nn.Module] = ..., condition_name: str | None = 'condition'): ...
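
Note: with the stub in place, type checkers and IDEs can verify style-specific keyword arguments against the overloads above (hedged sketch, same assumptions as the previous example):

    from deeplay.blocks.conv.conv2d import Conv2dBlock

    block = Conv2dBlock(in_channels=32, out_channels=32)
    block.style("resnet", stride=2)      # OK: matches the Literal["resnet"] overload
    # block.style("resnet", strides=2)   # a type checker would flag this misspelled kwarg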
14 changes: 5 additions & 9 deletions deeplay/blocks/la.py
@@ -9,7 +9,7 @@
 
 import torch.nn as nn
 
-from ..module import DeeplayModule
+from deeplay import DeeplayModule
 from .sequential import SequentialBlock
@@ -34,20 +34,16 @@ def configure(
         layer: Optional[DeeplayModule] = None,
         activation: Optional[DeeplayModule] = None,
         **kwargs: DeeplayModule,
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @overload
-    def configure(self, name: Literal["layer"], *args, **kwargs) -> None:
-        ...
+    def configure(self, name: Literal["layer"], *args, **kwargs) -> None: ...
 
     @overload
-    def configure(self, name: Literal["activation"], *args, **kwargs) -> None:
-        ...
+    def configure(self, name: Literal["activation"], *args, **kwargs) -> None: ...
 
     @overload
-    def configure(self, name: str, *args, **kwargs: Any) -> None:
-        ...
+    def configure(self, name: str, *args, **kwargs: Any) -> None: ...
 
     def configure(self, *args, **kwargs):  # type: ignore
         super().configure(*args, **kwargs)