implement correct cnn shorthands
BenjaminMidtvedt committed Nov 20, 2023
1 parent bf39b1d commit efa64d8
Showing 2 changed files with 28 additions and 13 deletions.
21 changes: 18 additions & 3 deletions deeplay/components/cnn/cnn.py
@@ -86,20 +86,35 @@ class ConvolutionalNeuralNetwork(DeeplayModule):
     blocks: LayerList[PoolLayerActivationNormalization]
 
     @property
-    def input_block(self):
+    def input(self):
         """Return the input layer of the network. Equivalent to `.blocks[0]`."""
         return self.blocks[0]
 
     @property
-    def hidden_blocks(self):
+    def hidden(self):
         """Return the hidden layers of the network. Equivalent to `.blocks[:-1]`"""
         return self.blocks[:-1]
 
     @property
-    def output_block(self):
+    def output(self):
         """Return the last layer of the network. Equivalent to `.blocks[-1]`."""
         return self.blocks[-1]
 
+    @property
+    def layer(self) -> LayerList[Layer]:
+        """Return the layers of the network. Equivalent to `.blocks.layer`."""
+        return self.blocks.layer
+
+    @property
+    def activation(self) -> LayerList[Layer]:
+        """Return the activations of the network. Equivalent to `.blocks.activation`."""
+        return self.blocks.activation
+
+    @property
+    def normalization(self) -> LayerList[Layer]:
+        """Return the normalizations of the network. Equivalent to `.blocks.normalization`."""
+        return self.blocks.normalization
+
     def __init__(
         self,
         in_channels: Optional[int],
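
For reference, a minimal usage sketch of the new shorthands (not part of this commit). It mirrors the constructor and .build() pattern used in the tests below; the top-level import path, the "one entry per block" behavior of the per-component properties, and the forward call are assumptions.

# Usage sketch (illustrative, not from the commit). The import path is an
# assumption; adjust it to match your installed deeplay layout.
import torch
from deeplay import ConvolutionalNeuralNetwork

cnn = ConvolutionalNeuralNetwork(3, [4], 1).build()

# Renamed block shorthands.
first = cnn.input    # equivalent to cnn.blocks[0]
middle = cnn.hidden  # equivalent to cnn.blocks[:-1]
last = cnn.output    # equivalent to cnn.blocks[-1]

# New per-component shorthands (assumed to aggregate one entry per block).
conv_layers = cnn.layer       # equivalent to cnn.blocks.layer
activations = cnn.activation  # equivalent to cnn.blocks.activation
norms = cnn.normalization     # equivalent to cnn.blocks.normalization

# Forward pass on a small batch, as in the tests.
x = torch.randn(2, 3, 5, 5)
y = cnn(x)
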
20 changes: 10 additions & 10 deletions deeplay/tests/test_cnn.py
@@ -15,8 +15,8 @@ def test_cnn_defaults(self):
         self.assertEqual(cnn.blocks[0].layer.in_channels, 3)
         self.assertEqual(cnn.blocks[0].layer.out_channels, 4)
 
-        self.assertEqual(cnn.output_block.layer.in_channels, 4)
-        self.assertEqual(cnn.output_block.layer.out_channels, 1)
+        self.assertEqual(cnn.output.layer.in_channels, 4)
+        self.assertEqual(cnn.output.layer.out_channels, 1)
 
         # test on a batch of 2
         x = torch.randn(2, 3, 5, 5)
@@ -27,10 +27,10 @@ def test_cnn_lazy_input(self):
         cnn = ConvolutionalNeuralNetwork(None, [4], 1).build()
         self.assertEqual(len(cnn.blocks), 2)
 
-        self.assertEqual(cnn.blocks[0].layer.in_channels, 0)
+        self.assertEqual(cnn.input.layer.in_channels, 0)
         self.assertEqual(cnn.blocks[0].layer.out_channels, 4)
-        self.assertEqual(cnn.output_block.layer.in_channels, 4)
-        self.assertEqual(cnn.output_block.layer.out_channels, 1)
+        self.assertEqual(cnn.output.layer.in_channels, 4)
+        self.assertEqual(cnn.output.layer.out_channels, 1)
 
         # test on a batch of 2
         x = torch.randn(2, 3, 5, 5)
@@ -48,21 +48,21 @@ def test_change_act(self):
         cnn.configure(out_activation=nn.Sigmoid)
         cnn.build()
         self.assertEqual(len(cnn.blocks), 2)
-        self.assertIsInstance(cnn.output_block.activation, nn.Sigmoid)
+        self.assertIsInstance(cnn.output.activation, nn.Sigmoid)
 
     def test_change_out_act_Layer(self):
         cnn = ConvolutionalNeuralNetwork(2, [4], 3)
         cnn.configure(out_activation=Layer(nn.Sigmoid))
         cnn.build()
         self.assertEqual(len(cnn.blocks), 2)
-        self.assertIsInstance(cnn.output_block.activation, nn.Sigmoid)
+        self.assertIsInstance(cnn.output.activation, nn.Sigmoid)
 
     def test_change_out_act_instance(self):
         cnn = ConvolutionalNeuralNetwork(2, [4], 3)
         cnn.configure(out_activation=nn.Sigmoid())
         cnn.build()
         self.assertEqual(len(cnn.blocks), 2)
-        self.assertIsInstance(cnn.output_block.activation, nn.Sigmoid)
+        self.assertIsInstance(cnn.output.activation, nn.Sigmoid)
 
     def test_default_values_initialization(self):
         cnn = ConvolutionalNeuralNetwork(
@@ -80,8 +80,8 @@ def test_empty_hidden_channels(self):
         self.assertEqual(cnn.blocks[0].layer.in_channels, 3)
         self.assertEqual(cnn.blocks[0].layer.out_channels, 1)
 
-        self.assertIs(cnn.blocks[0], cnn.output_block)
-        self.assertIs(cnn.blocks[0], cnn.input_block)
+        self.assertIs(cnn.blocks[0], cnn.input)
+        self.assertIs(cnn.blocks[0], cnn.output)
 
     def test_zero_out_channels(self):
         with self.assertRaises(ValueError):
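
The diff above only exercises the renamed block shorthands. A hypothetical follow-up test, sketched below, could cover the new per-component shorthands; it is not part of the commit, and the method name and assertions (one entry per block, identity with each block's own submodule) are assumptions.

# Hypothetical addition to the same TestCase class (not in this commit).
def test_component_shorthands(self):
    cnn = ConvolutionalNeuralNetwork(3, [4], 1).build()
    # Assumed: each shorthand exposes one entry per block.
    self.assertEqual(len(cnn.layer), len(cnn.blocks))
    self.assertEqual(len(cnn.activation), len(cnn.blocks))
    self.assertEqual(len(cnn.normalization), len(cnn.blocks))
    for block, layer in zip(cnn.blocks, cnn.layer):
        # Assumed: the aggregated entries are the blocks' own submodules.
        self.assertIs(block.layer, layer)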
