Pending tistdp exhibit #4

Open · wants to merge 3 commits into base: main
75 changes: 75 additions & 0 deletions exhibits/time-integrated-stdp/README.md
@@ -0,0 +1,75 @@
# Time-Integrated Spike-Timing-Dependent Plasticity

<b>Version</b>: ngclearn>=1.2.beta1, ngcsimlib==0.3.beta4

This exhibit contains an implementation of the spiking neuronal model
and credit assignment process proposed and studied in:

Gebhardt, William, and Alexander G. Ororbia. "Time-Integrated Spike-Timing-
Dependent Plasticity." arXiv preprint arXiv:2407.10028 (2024).

<p align="center">
<img height="250" src="fig/tistdp_snn.jpg"><br>
<i>Visual depiction of the TI-STDP-adapted SNN architecture.</i>
</p>

<!--
The model and its biological credit assignment process are also discussed in
the ngc-learn
<a href="https://ngc-learn.readthedocs.io/en/latest/museum/tistdp.html">documentation</a>.
-->
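For orientation, spike-timing-dependent plasticity (STDP) adjusts a synapse based on the relative timing of pre- and post-synaptic spikes. Below is a minimal, generic trace-based STDP step in NumPy, given purely as a textbook-style sketch; the actual TI-STDP rule differs (see the referenced paper and the `TI_STDP_Synapse` component in this exhibit's `custom/` package), and all parameter values here are illustrative assumptions:

```python
import numpy as np

def stdp_update(w, pre_trace, post_trace, pre_spk, post_spk,
                A_plus=0.01, A_minus=0.012, tau=20.0, dt=1.0):
    """One step of a generic trace-based STDP rule (illustrative only;
    the actual TI-STDP update is defined in the referenced paper).

    w:         (n_post, n_pre) synaptic weight matrix
    pre_spk:   (n_pre,)  binary pre-synaptic spike vector at this step
    post_spk:  (n_post,) binary post-synaptic spike vector at this step
    """
    ## exponentially decaying eligibility traces of recent spiking
    pre_trace = pre_trace * np.exp(-dt / tau) + pre_spk
    post_trace = post_trace * np.exp(-dt / tau) + post_spk
    ## potentiate where post fires shortly after pre activity;
    ## depress where pre fires shortly after post activity
    dw = A_plus * np.outer(post_spk, pre_trace) \
         - A_minus * np.outer(post_trace, pre_spk)
    return w + dw, pre_trace, post_trace
```

Applied step by step over a simulated stimulus window, a pre-spike followed a few steps later by a post-spike yields a net positive weight change on that synapse.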

## Running and Analyzing the Model Simulations

### Unsupervised Digit-Level Biophysical Model

To run the main TI-STDP SNN experiments of the paper, simply execute:

```console
$ ./train_models.sh 0 tistdp snn_case1
```

which will trigger three experimental trials for adaptation of the
`Case 1` model described in the paper on MNIST. If you want to train
the online, `Case 2` model described in the paper on MNIST, you simply
need to change the third argument to the Bash script like so:

```console
$ ./train_models.sh 0 tistdp snn_case2
```

Regardless of which case you select above, you can analyze the resulting
trained models, in line with what was done in the paper, by executing the
analysis Bash script as follows:

```console
$ ./analyze_models.sh 0 tistdp ## run on GPU 0 the "tistdp" config
```

<i>Task</i>: Models under this section engage in unsupervised representation
learning, jointly learning, through spike-timing-driven credit assignment,
low-level and higher-level abstract, distributed, discrete representations
of sensory input data. In this exhibit, the focus is on patterns from the
MNIST database.
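Because training is fully unsupervised, evaluation hinges on post-hoc label binding: each output neuron is assigned the digit class whose stimuli elicit its strongest accumulated spike response, which is what this exhibit's `bind_labels.py` computes. A minimal NumPy sketch of the idea follows, with array shapes assumed for illustration:

```python
import numpy as np

def bind_labels(spike_counts, labels):
    """Assign each output neuron the class that drives it most strongly.

    spike_counts: (n_samples, n_units) summed output spikes per stimulus
    labels:       (n_samples, n_classes) one-hot class labels
    Returns a (n_units,) vector of per-neuron class assignments.
    """
    ## accumulate each neuron's total response under every class
    class_responses = labels.T @ spike_counts  # (n_classes, n_units)
    ## the binding is the argmax over classes for each neuron
    return np.argmax(class_responses, axis=0)
```

At test time, a prediction can then be read out by summing the spikes of neurons bound to each class and taking the class with the largest total.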

### Part-Whole Assembly SNN Model

To run the patch-model SNN adapted with TI-STDP, enter the `patch_model/`
sub-directory and then execute:

```console
$ ./train_patch_models.sh
```

which will run a single trial to produce the SNN generative assembly
(or part-whole hierarchical) model constructed in the paper.

<i>Task</i>: This biophysical model engages in a form of unsupervised
representation learning focused on acquiring a simple bi-level
part-whole hierarchy of sensory input data (in this exhibit, the focus is on
data from the MNIST database).
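As a rough illustration of the part-whole setup, one way to expose a lower SNN level to local "parts" is to tile each digit image into non-overlapping patches. The helper below is a hypothetical NumPy sketch of such preprocessing, not code taken from this exhibit's `patch_model/` scripts:

```python
import numpy as np

def to_patches(img, patch=7):
    """Split a square image into non-overlapping (patch x patch) tiles,
    returned as an (n_patches, patch * patch) matrix of flattened parts."""
    h, w = img.shape
    assert h % patch == 0 and w % patch == 0
    ## carve the image into a grid of tiles, then flatten each tile row-major
    tiles = img.reshape(h // patch, patch, w // patch, patch)
    tiles = tiles.transpose(0, 2, 1, 3).reshape(-1, patch * patch)
    return tiles
```

For a 28x28 MNIST digit with `patch=7`, this yields sixteen 49-dimensional "part" vectors that a lower level could encode before a higher level composes them into a "whole".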

## Model Descriptions, Hyperparameters, and Configuration Details

Model explanations, meta-parameter settings, and experimental details are
provided in the reference paper above.
74 changes: 74 additions & 0 deletions exhibits/time-integrated-stdp/analyze_models.sh
@@ -0,0 +1,74 @@
#!/bin/bash

## read in user-provided program args
GPU_ID=$1 #1
MODEL=$2 # evstdp trstdp tistdp

if [[ "$MODEL" != "evstdp" && "$MODEL" != "trstdp" && "$MODEL" != "tistdp" ]]; then
echo "Invalid Arg: $MODEL -- only 'evstdp', 'trstdp', 'tistdp' models supported!"
exit 1
fi
echo " >>>> Setting up $MODEL on GPU $GPU_ID"

SEEDS=(1234 77 811)

PARAM_SUBDIR="/custom"
DISABLE_ADAPT_AT_EVAL=False ## set to true to turn off eval-time adaptive thresholds
MAKE_CLUSTER_PLOT=False #True
REBIND_LABELS=0 ## rebind labels to train model?

N_SAMPLES=50000
DATA_X="../../data/mnist/trainX.npy"
DATA_Y="../../data/mnist/trainY.npy"
DEV_X="../../data/mnist/testX.npy" # validX.npy
DEV_Y="../../data/mnist/testY.npy" # validY.npy
EXTRACT_TRAINING_SPIKES=0 # set to 1 if you want to extract training set codes

for seed in "${SEEDS[@]}"
do
EXP_DIR="exp_$MODEL""_$seed/"
echo " > Running Simulation/Model: $EXP_DIR"

CODEBOOK=$EXP_DIR"training_codes.npy"
TEST_CODEBOOK=$EXP_DIR"test_codes.npy"
PLOT_FNAME=$EXP_DIR"codes.jpg"

if [[ $REBIND_LABELS == 1 ]]; then
CUDA_VISIBLE_DEVICES=$GPU_ID python bind_labels.py --dataX=$DATA_X --dataY=$DATA_Y \
--model_type=$MODEL \
--model_dir=$EXP_DIR$MODEL \
--n_samples=$N_SAMPLES \
--exp_dir=$EXP_DIR \
--disable_adaptation=$DISABLE_ADAPT_AT_EVAL \
--param_subdir=$PARAM_SUBDIR
fi

## eval model
# CUDA_VISIBLE_DEVICES=$GPU_ID python eval.py --dataX=$DEV_X --dataY=$DEV_Y \
# --model_type=$MODEL --model_dir=$EXP_DIR$MODEL \
# --label_fname=$EXP_DIR"binded_labels.npy" \
# --exp_dir=$EXP_DIR \
# --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \
# --param_subdir=$PARAM_SUBDIR \
# --make_cluster_plot=$MAKE_CLUSTER_PLOT
## call codebook extraction processes
if [[ $EXTRACT_TRAINING_SPIKES == 1 ]]; then
CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DATA_X \
--n_samples=$N_SAMPLES \
--codebook_fname=$CODEBOOK \
--model_type=$MODEL \
--model_fname=$EXP_DIR$MODEL \
--disable_adaptation=False \
--param_subdir=$PARAM_SUBDIR
fi
CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DEV_X \
--codebook_fname=$TEST_CODEBOOK \
--model_type=$MODEL \
--model_fname=$EXP_DIR$MODEL \
--disable_adaptation=False \
--param_subdir=$PARAM_SUBDIR
## visualize latent codes
CUDA_VISIBLE_DEVICES=$GPU_ID python viz_codes.py --plot_fname=$PLOT_FNAME \
--codes_fname=$TEST_CODEBOOK \
--labels_fname=$DEV_Y
done
23 changes: 23 additions & 0 deletions exhibits/time-integrated-stdp/bind.sh
@@ -0,0 +1,23 @@
#!/bin/bash
GPU_ID=1 #0

N_SAMPLES=10000
DISABLE_ADAPT_AT_EVAL=False

EXP_DIR="exp_trstdp/"
MODEL="trstdp"
#EXP_DIR="exp_evstdp/"
#MODEL="evstdp"
DEV_X="../../data/mnist/trainX.npy" # validX.npy
DEV_Y="../../data/mnist/trainY.npy" # validY.npy
PARAM_SUBDIR="/custom_snapshot2"
#PARAM_SUBDIR="/custom"

## eval model
CUDA_VISIBLE_DEVICES=$GPU_ID python bind_labels.py --dataX=$DEV_X --dataY=$DEV_Y \
--model_type=$MODEL \
--model_dir=$EXP_DIR$MODEL \
--n_samples=$N_SAMPLES \
--exp_dir=$EXP_DIR \
--disable_adaptation=$DISABLE_ADAPT_AT_EVAL \
--param_subdir=$PARAM_SUBDIR
137 changes: 137 additions & 0 deletions exhibits/time-integrated-stdp/bind_labels.py
@@ -0,0 +1,137 @@
from jax import numpy as jnp, random
import sys, getopt as gopt, optparse, time

import matplotlib
matplotlib.use('Agg') ## select a non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
cmap = plt.cm.jet

################################################################################
# read in general program arguments
options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=",
"model_dir=", "model_type=",
"exp_dir=", "n_samples=",
"disable_adaptation=",
"param_subdir="])

model_case = "snn_case1"
disable_adaptation = True
exp_dir = "exp/"
param_subdir = "/custom"
model_type = "tistdp"
model_dir = "exp/tistdp"
dataX = "../../data/mnist/trainX.npy"
dataY = "../../data/mnist/trainY.npy"
n_samples = 10000
verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O)
for opt, arg in options:
    if opt == "--dataX":
        dataX = arg.strip()
    elif opt == "--dataY":
        dataY = arg.strip()
    elif opt == "--model_dir":
        model_dir = arg.strip()
    elif opt == "--model_type":
        model_type = arg.strip()
    elif opt == "--exp_dir":
        exp_dir = arg.strip()
    elif opt == "--param_subdir":
        param_subdir = arg.strip()
    elif opt == "--n_samples":
        n_samples = int(arg.strip())
    elif opt == "--disable_adaptation":
        disable_adaptation = (arg.strip().lower() == "true")
        print(" > Disable short-term adaptation? ", disable_adaptation)

if model_case == "snn_case1":
    print(" >> Setting up Case 1 model!")
    from snn_case1 import load_from_disk, get_nodes
elif model_case == "snn_case2":
    print(" >> Setting up Case 2 model!")
    from snn_case2 import load_from_disk, get_nodes
else:
    print("Error: no other model case studies supported! (", model_case, " invalid)")
    exit()

print(">> X: {} Y: {}".format(dataX, dataY))

dkey = random.PRNGKey(1234)
dkey, *subkeys = random.split(dkey, 3)

## load dataset
_X = jnp.load(dataX)
_Y = jnp.load(dataY)
if 0 < n_samples < _X.shape[0]:
    ptrs = random.permutation(subkeys[0], _X.shape[0])[0:n_samples]
    _X = _X[ptrs, :]
    _Y = _Y[ptrs, :]
    # _X = _X[0:n_samples, :]
    # _Y = _Y[0:n_samples, :]
    print("-> Binding {} randomly selected samples to model".format(n_samples))
n_batches = _X.shape[0] ## num batches equals num samples (online learning)

## basic simulation hyper-parameter/configuration values go here
viz_mod = 1000 #10000
mb_size = 1 ## locked to batch sizes of 1
patch_shape = (28, 28)
in_dim = patch_shape[0] * patch_shape[1]

T = 250 #300 ## num time steps to simulate (stimulus presentation window length)
dt = 1. ## integration time step

################################################################################
print("--- Loading Model ---")

## Load in model
model = load_from_disk(model_dir, param_dir=param_subdir,
                       disable_adaptation=disable_adaptation)
nodes, node_map = get_nodes(model)

################################################################################
print("--- Starting Binding Process ---")

print("------------------------------------")
model.showStats(-1)

## enter main adaptation loop over data patterns
class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units))
num_bound = 0
n_total_samp_seen = 0
tstart = time.time()
n_samps_seen = 0
for j in range(n_batches):
    idx = j
    Xb = _X[idx: idx + mb_size, :]
    Yb = _Y[idx: idx + mb_size, :]

    model.reset()
    model.clamp(Xb)
    spikes1, spikes2 = model.infer(
        jnp.array([[dt * k, dt] for k in range(T)]))
    ## bind output spike train(s)
    responses = Yb.T * jnp.sum(spikes2, axis=0)
    class_responses = class_responses + responses
    num_bound += 1

    n_samps_seen += Xb.shape[0]
    n_total_samp_seen += Xb.shape[0]
    print("\r Binding {} images...".format(n_samps_seen), end="")
tend = time.time()
print()
sim_time = tend - tstart
sim_time_hr = sim_time / 3600.0 ## convert time to hours
print(" -> Binding.Time = {} s ({} h)".format(sim_time, sim_time_hr))
print("------------------------------------")

## compute max-frequency (~firing rate) spike responses
class_responses = jnp.argmax(class_responses, axis=0, keepdims=True)
print("---- Max Class Responses ----")
print(class_responses)
print(class_responses.shape)
bind_fname = "{}binded_labels.npy".format(exp_dir)
print(" >> Saving label bindings to: ", bind_fname)
jnp.save(bind_fname, class_responses)




1 change: 1 addition & 0 deletions exhibits/time-integrated-stdp/custom/__init__.py
@@ -0,0 +1 @@
from .ti_stdp_synapse import TI_STDP_Synapse