Unofficial implementation of "Attention as an RNN" (https://arxiv.org/pdf/2405.13956), provided both as a recurrent layer and as an efficient prefix-sum (scan) layer.
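The core trick: with a fixed query (as in the paper's Aaren module), the softmax numerator and denominator of causal attention are running sums, so the same outputs can be computed step by step like an RNN or in parallel with prefix sums. A minimal NumPy sketch of that equivalence (independent of the layers in src; single head, no running-max stabilisation):
import numpy as np

def causal_attention_recurrent(q, K, V):
    # q: (d,) fixed query, K: (T, d) keys, V: (T, d_v) values
    num = np.zeros(V.shape[1])   # running numerator  sum_i exp(q.k_i) * v_i
    den = 0.0                    # running denominator sum_i exp(q.k_i)
    out = []
    for k_t, v_t in zip(K, V):
        w = np.exp(q @ k_t)
        num = num + w * v_t
        den = den + w
        out.append(num / den)
    return np.stack(out)

def causal_attention_scan(q, K, V):
    # Same outputs via prefix sums, computable in parallel
    w = np.exp(K @ q)                        # (T,)
    num = np.cumsum(w[:, None] * V, axis=0)  # (T, d_v)
    den = np.cumsum(w)[:, None]              # (T, 1)
    return num / den

T, d = 16, 4
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=d), rng.normal(size=(T, d)), rng.normal(size=(T, d))
assert np.allclose(causal_attention_recurrent(q, K, V), causal_attention_scan(q, K, V))
The example below fits the prefix-sum model to a sine-wave dataset and plots the reconstruction and training curves: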
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from src import models, layers, utils
EPOCHS = 400
BATCH_SIZE = 50
# Init sin dataset
x = utils.generate_sin()
x = tf.transpose(x, (2, 0, 1))
# Init Model
model = models.ScanRNNAttentionModel(heads=[10, 5], dims=[5, 2], activation="silu", concat_heads=False)
_ = model(x)
model.compile("adam", "mse")
# Train
history = model.fit(x, x, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_split=0.1)  # hold-out split so val_loss exists for the plot below
# Visualize Result
o = model(x)
plt.plot(o[30, :, 0], label="RNN-Attention")
plt.plot(x[30, :, 0], label="True")
plt.legend()
plt.show()
# Visualize training curves
plt.plot(history.history["loss"], label="train loss")
plt.plot(history.history["val_loss"], label="validation loss")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend()
You can also run the model in training mode, which makes the output stochastic via dropout and lets you approximate the epistemic uncertainty:
from src import plotter
fig, ax = plotter.plot_hist2d(model, x[0], axis=0)
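Under the hood this is Monte-Carlo dropout. A minimal sketch without the plotting helper (assumes the model was built with nonzero dropout rates and forwards training=True to its dropout layers):
# Repeated stochastic forward passes; the spread approximates epistemic uncertainty.
samples = np.stack([model(x, training=True).numpy() for _ in range(100)])
mean, std = samples.mean(axis=0), samples.std(axis=0)
plt.plot(x[30, :, 0], label="True")
plt.plot(mean[30, :, 0], label="mean prediction")
plt.fill_between(np.arange(mean.shape[1]),
                 mean[30, :, 0] - 2 * std[30, :, 0],
                 mean[30, :, 0] + 2 * std[30, :, 0],
                 alpha=0.3, label="±2 std")
plt.legend()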
You can train the model using the prefix-sum implementation and then transfer the weights to the recurrent implementation to run inference recurrently:
config = model.get_config()
heads = config["heads"]
dims = config["dims"]
activation = config["activation"]
dropout = config["dropout"]
recurrent_dropout = config["recurrent_dropout"]
rnn_model = models.AttentionRNN(
heads,
dims,
activation=activation,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
)
rnn_model.build(x.shape)
rnn_model.set_weights(model.get_weights())
rnn_model.compile("adam", "mse")
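As a sanity check (a sketch, assuming both models accept a training flag), the prefix-sum and recurrent forward passes should agree up to numerical precision once dropout is disabled:
o_scan = model(x, training=False)
o_rnn = rnn_model(x, training=False)
print(np.abs(o_scan.numpy() - o_rnn.numpy()).max())  # expected to be near zero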
# Load dataset (!pip install aeon)
from aeon.datasets import load_classification
X, y = load_classification("ECG200")
y = y.astype(float)
y = np.where(y>0, 1., 0.)
plt.plot(X[:10, 0, :].T)
# Construct model
ki = tf.keras.Input(shape=(None, 1))
scan = models.ScanRNNAttentionModel([10, 10], [10, 10])
avg_pool = tf.keras.layers.GlobalAveragePooling1D()
max_pool = tf.keras.layers.GlobalMaxPooling1D()
min_pool = tf.keras.layers.Lambda(lambda x: tf.reduce_min(x, -2))
conc = tf.keras.layers.Concatenate()
dense = tf.keras.layers.Dense(1, "sigmoid")
h = scan(ki)
avgp = avg_pool(h)
maxp = max_pool(h)
minp = min_pool(h)
mix = conc([avgp, maxp, minp])
o = dense(mix)
classification_model = tf.keras.Model(ki, o)
classification_model.compile("adam", "bce")
# Train
hist = classification_model.fit(X[:, 0, :, None], y, epochs=1000)  # add a channel dimension: (batch, time, 1)
# Score
pred = classification_model.predict(X[:, 0, :, None])
tf.keras.metrics.Accuracy()(y, pred[:, 0] > 0.5)
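For comparison, the same pooling head can be attached to a single LinearSelfAttention layer: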
ki = tf.keras.Input(shape=(None, 1))
lin_att = layers.LinearSelfAttention(5, 10, "linear")
avg_pool = tf.keras.layers.GlobalAveragePooling1D()
max_pool = tf.keras.layers.GlobalMaxPooling1D()
min_pool = tf.keras.layers.Lambda(lambda x: tf.reduce_min(x, -2))
conc = tf.keras.layers.Concatenate()
dense = tf.keras.layers.Dense(1, "sigmoid")
h = lin_att(ki)
avgp = avg_pool(h)
maxp = max_pool(h)
minp = min_pool(h)
mix = conc([avgp, maxp, minp])
o = dense(mix)
classification_model_linear = tf.keras.Model(ki, o)
classification_model_linear.compile("adam", "bce")
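The same comparison can be made against a plain Conv1D baseline: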
ki = tf.keras.Input(shape=(None, 1))
cnn = tf.keras.layers.Conv1D(32, 10)
avg_pool = tf.keras.layers.GlobalAveragePooling1D()
max_pool = tf.keras.layers.GlobalMaxPooling1D()
min_pool = tf.keras.layers.Lambda(lambda x: tf.reduce_min(x, -2))
conc = tf.keras.layers.Concatenate()
dense = tf.keras.layers.Dense(1, "sigmoid")
h = cnn(ki)
avgp = avg_pool(h)
maxp = max_pool(h)
minp = min_pool(h)
mix = conc([avgp, maxp, minp])
o = dense(mix)
classification_model_cnn = tf.keras.Model(ki, o)
classification_model_cnn.compile("adam", "bce")
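Both baselines can then be trained and scored the same way as the attention model; a sketch reusing the reshaped ECG200 inputs from above:
# Train and score the two baselines on the same data for comparison.
for name, m in [("linear attention", classification_model_linear),
                ("conv1d", classification_model_cnn)]:
    m.fit(X[:, 0, :, None], y, epochs=1000, verbose=0)
    pred = m.predict(X[:, 0, :, None])
    acc = tf.keras.metrics.Accuracy()(y, pred[:, 0] > 0.5)
    print(name, float(acc))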