From 68bcc97ad7fab8a34030a70e280f054f1e64a105 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 10 Oct 2018 02:45:58 -0700 Subject: [PATCH] Add functional layer tools --- keras_contrib/layers/__init__.py | 1 + keras_contrib/layers/functional.py | 88 +++++++++++++++++++ tests/keras_contrib/layers/functional_test.py | 35 ++++++++ 3 files changed, 124 insertions(+) create mode 100644 keras_contrib/layers/functional.py create mode 100644 tests/keras_contrib/layers/functional_test.py diff --git a/keras_contrib/layers/__init__.py b/keras_contrib/layers/__init__.py index 415cb2e2a..818318f92 100644 --- a/keras_contrib/layers/__init__.py +++ b/keras_contrib/layers/__init__.py @@ -11,3 +11,4 @@ from .wrappers import * from .convolutional_recurrent import * from .crf import * +from .functional import * diff --git a/keras_contrib/layers/functional.py b/keras_contrib/layers/functional.py new file mode 100644 index 000000000..c48ebef54 --- /dev/null +++ b/keras_contrib/layers/functional.py @@ -0,0 +1,88 @@ +"""Functional tools for working with layers.""" + +from functools import reduce + +__all__ = ['sequence', 'repeat'] + + +def sequence(*layers): + """Composes layers sequentially. + + # Arguments + *layers: Layers, or other callables that map a tensor to a tensor. + + # Returns + A callable that maps a tensor to the output tensor of the last layer. + + # Examples + + ```python + from keras.layers import Dense, Input + from keras.models import Model + from keras_contrib.layers import sequence + + input_layer = Input(shape=(16,)) + + output = sequence( + Dense(8, activation='relu'), + Dense(8, activation='relu'), + Dense(8, activation='relu'), + Dense(1), + )(input_layer) + + model = Model(input_layer, output) + ``` + """ + return reduce(lambda f, g: lambda x: g(f(x)), layers, lambda x: x) + + +def repeat(n, layer_factory): + """Constructs a sequence of repeated layers. + + # Arguments + n: int. The number of times to repeat the layer. + layer_factory: A function taking no arguments that returns a layer or + another callable that maps a tensor to a tensor. + + # Returns + A callable that maps a tensor to the output tensor of the last layer. + + # Examples + + ```python + from keras.layers import Dense, Input + from keras.models import Model + from keras_contrib.layers import repeat, sequence + + input_layer = Input(shape=(16,)) + + output = sequence( + repeat(3, lambda: Dense(8, activation='relu')), + Dense(1), + )(input_layer) + + model = Model(input_layer, output) + ``` + + `sequence` and `repeat` can be freely intermixed with layers, since they + both map a tensor to a tensor: + + ```python + from keras.layers import Activation, Dense, Input + from keras.models import Model + from keras_contrib.layers import repeat, sequence + + input_layer = Input(shape=(16,)) + + output = sequence( + repeat(3, lambda: sequence( + Dense(8), + Activation('relu'), + )), + Dense(1), + )(input_layer) + + model = Model(input_layer, output) + ``` + """ + return sequence(*(layer_factory() for _ in range(n))) diff --git a/tests/keras_contrib/layers/functional_test.py b/tests/keras_contrib/layers/functional_test.py new file mode 100644 index 000000000..a749133f3 --- /dev/null +++ b/tests/keras_contrib/layers/functional_test.py @@ -0,0 +1,35 @@ +"""Tests for functions in keras_contrib/layers/functional.py.""" + +import pytest + +from keras.layers import Dense, Input +from keras.models import Model +from keras_contrib.layers import repeat, sequence + + +def test_sequence(): + input_layer = Input(shape=(16,)) + output = sequence( + Dense(8), + Dense(1), + )(input_layer) + model = Model(input_layer, output) + assert len(model.layers) == 3 + assert model.layers[1].__class__.__name__ == 'Dense' + assert model.layers[2].__class__.__name__ == 'Dense' + assert model.layers[1].get_output_shape_at(0) == (None, 8) + assert model.layers[2].get_output_shape_at(0) == (None, 1) + + +def test_repeat(): + input_layer = Input(shape=(16,)) + output = repeat(2, lambda: Dense(8))(input_layer) + model = Model(input_layer, output) + assert len(model.layers) == 3 + assert model.layers[1].__class__.__name__ == 'Dense' + assert model.layers[2].__class__.__name__ == 'Dense' + assert id(model.layers[1]) != id(model.layers[2]) + + +if __name__ == '__main__': + pytest.main([__file__])