Skip to content
This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Add functional layer tools #312

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

crowsonkb
Copy link

Adds two functional tools for working with layers, sequence and repeat. They are simple composable functions that allow you to write cleaner code when using the Functional API. Instead of ugly code like:

input_layer = Input(shape=(16,))
x = Dense(8, activation='relu')(input_layer)
x = Dense(8, activation='relu')(x)
x = Dense(8, activation='relu')(x)
x = Dense(1)(x)
model = Model(input_layer, x)

applying the current layer to the previous layer manually on each line, sequence allows you to write:

input_layer = Input(shape=(16,))
output = sequence(
    Dense(8, activation='relu'),
    Dense(8, activation='relu'),
    Dense(8, activation='relu'),
    Dense(1),
)(input_layer)
model = Model(input_layer, output)

sequence takes any number of callables mapping a tensor to a tensor and returns a callable mapping a tensor to a tensor. So you can logically nest sequence calls, as in this example:

def residual_block(units):
    return lambda a: sequence(
        Activation('relu'),
        Dense(units),
        Activation('relu'),
        Dense(units),
        lambda b: Add()([a, b]),
    )(a)

output = sequence(
    residual_block(16),
    residual_block(16),
    Dense(1),
)(input_layer)

repeat, which takes a repeat count and a layer factory function, can be used to simplify creation of a deep model, as follows:

output = sequence(
    repeat(4, lambda: residual_block(16)),
    Dense(8),
    repeat(4, lambda: residual_block(8)),
    Dense(1),
)(input_layer)

As you can see, these functions allow you to easily create deep models without repetitive code.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant