Skip to content

Commit

Permalink
User deterministic core from MATLAB Kilosort2 (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmorley authored Jan 4, 2021
1 parent 10e1c3a commit 5318929
Show file tree
Hide file tree
Showing 8 changed files with 1,063 additions and 371 deletions.
5 changes: 0 additions & 5 deletions pykilosort/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import logging
import os

# TODO: move this to config
# TODO: setting to true not yet implemented
ENABLE_STABLEMODE = False #True
ENSURE_DETERM = False

if os.getenv('MOCK_CUPY', False):
from pykilosort.testing.mock_cupy import cupy
from pykilosort.testing.mock_cupyx import cupyx
Expand Down
4 changes: 2 additions & 2 deletions pykilosort/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ def mexClustering2(Params, uproj, W, mu, call, iMatch, iC):
bestFilter((40,), (256,), (d_Params, d_iMatch, d_iC, d_call, d_cmax, d_id, d_x))

# average all spikes for same template -- ORIGINAL
average_snips = cp.RawKernel(code, 'average_snips')
average_snips(
average_snips_v2 = cp.RawKernel(code, 'average_snips_v2')
average_snips_v2(
(Nfilters,), (NrankPC, NchanNear), (d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU))

count_spikes = cp.RawKernel(code, 'count_spikes')
Expand Down
54 changes: 49 additions & 5 deletions pykilosort/cuda/mexClustering2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ __global__ void bestFilter(const double *Params, const bool *match,
__global__ void average_snips(const double *Params, const int *iC, const int *call,
const int *id, const float *uproj, const float *cmax, float *WU){

//Nfilt blocks
//Thread grid = (NrankPC, NchanNear)
//This implementation does not work correctly for real data!
//Since this_chan is function of the spike -- spikes assigned to a given template
//will have max channels that span a 2-3 channel range -- different (tidx, tidy)
//pairs can wind up trying to add to the same element of dWU, resulting in
//collisions and incorrect results. Use the single-threaded version
//average_snips_v2 instead. Speed hit is only ~ 5-6 seconds out of 360 sec for a
//typical 2 hour Neuropixels 1.0 dataset.
int my_chan, this_chan, tidx, tidy, bid, ind, Nspikes, NrankPC, NchanNear, Nchan;
float xsum = 0.0f;

Expand All @@ -85,15 +94,51 @@ __global__ void average_snips(const double *Params, const int *iC, const int *ca
tidy = threadIdx.y;
bid = blockIdx.x;

for(ind=0; ind<Nspikes;ind++)
for(ind=0; ind<Nspikes;ind++) {
if (id[ind]==bid){
my_chan = call[ind];
this_chan = iC[tidy + NchanNear * my_chan];
xsum = uproj[tidx + NrankPC*tidy + NrankPC*NchanNear * ind];
WU[tidx + NrankPC*this_chan + NrankPC*Nchan * bid] += xsum;
}
}
}

}

//////////////////////////////////////////////////////////////////////////////////////////
__global__ void average_snips_v2(const double *Params, const int *iC, const int *call,
const int *id, const float *uproj, const float *cmax, float *WU){


// jic, version with no threading over features, to avoid
// collisions when summing WU
// run

int my_chan, this_chan, bid, ind, Nspikes, NrankPC, NchanNear, Nchan;
float xsum = 0.0f;
int chanIndex, pcIndex;

Nspikes = (int) Params[0];
NrankPC = (int) Params[1];
Nchan = (int) Params[7];
NchanNear = (int) Params[6];


bid = blockIdx.x;

for(ind=0; ind<Nspikes;ind++)
if (id[ind]==bid){
my_chan = call[ind];
for (chanIndex = 0; chanIndex < NchanNear; ++chanIndex) {
this_chan = iC[chanIndex + NchanNear * my_chan];
for (pcIndex = 0; pcIndex < NrankPC; ++pcIndex) {
xsum = uproj[pcIndex + NrankPC*chanIndex + NrankPC*NchanNear * ind];
WU[pcIndex + NrankPC*this_chan + NrankPC*Nchan * bid] += xsum;
}
}

}
}


//////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -141,8 +186,8 @@ __global__ void sum_dWU(const double *Params, const float *bigArray, float *WU)
int tid,bid, ind, Nfilters, Nthreads, Nfeatures, Nblocks, NfeatW, nWU, nElem;
float sum = 0.0f;

Nfeatures = (int) Params[1];
NfeatW = (int) Params[4];
Nfeatures = (int) Params[1]; //NrankPC, number of pcs
NfeatW = (int) Params[4]; //Nchan*nPC
Nfilters = (int) Params[2];
Nthreads = blockDim.x;
Nblocks = gridDim.x;
Expand Down Expand Up @@ -213,5 +258,4 @@ __global__ void count_spikes(const double *Params, const int *id, int *nsp, cons
tind += NthreadsMe * Nblocks;
}


}
2 changes: 1 addition & 1 deletion pykilosort/cuda/mexDistances2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ __global__ void computeCost(const double *Params, const float *Ws, const float *
int j, tid, bid, Nspikes, my_chan, this_chan, Nchan, NrankPC, NchanNear, Nthreads, k;
float xsum = 0.0f, Ci;

Nspikes = (int) Params[0];
Nspikes = (int) Params[0]; //more accurately, number of comparisons, Nfilt*Nbatch
Nchan = (int) Params[7];
NrankPC = (int) Params[1];
NchanNear = (int) Params[6];
Expand Down
Loading

0 comments on commit 5318929

Please sign in to comment.