Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for MPS (Apple metal) GPU #839

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
Binary file added kilosort/.DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion kilosort/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,4 @@ def load_phy(filename, fpath, ops):

yclu_new, Wsub = clu_ypos(filename, ops, st_new - 20, clu_new)

return st_new, clu_new, yclu_new, Wsub
return st_new, clu_new, yclu_new, Wsub
92 changes: 67 additions & 25 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
logger = logging.getLogger(__name__)


def neigh_mat(Xd, nskip=10, n_neigh=30):
def neigh_mat(Xd, nskip=10, n_neigh=30, device=torch.device('cuda')):
# Xd is spikes by PCA features in a local neighborhood
# finding n_neigh neighbors of each spike to a subset of every nskip spike

Expand All @@ -36,6 +36,9 @@ def neigh_mat(Xd, nskip=10, n_neigh=30):

# exact neighbor search ("brute force")
# results is dn and kn, kn is n_samples by n_neigh, contains integer indices into Xsub
if device == torch.device('mps'):
# Limit faiss to a single thread on MPS; otherwise `index.search(Xd, n_neigh)` causes a segmentation fault.
faiss.omp_set_num_threads(1)
index = faiss.IndexFlatL2(dim) # build the index
index.add(Xsub) # add vectors to the index
_, kn = index.search(Xd, n_neigh) # actual search
Expand All @@ -53,19 +56,30 @@ def neigh_mat(Xd, nskip=10, n_neigh=30):
return kn, M


def assign_mu(iclust, Xg, cols_mu, tones, nclust = None, lpow = 1, device=torch.device('cuda')):
    """Average the (optionally powered) features of the rows in each cluster.

    Parameters
    ----------
    iclust : torch.Tensor
        Cluster index for each row of `Xg`; used as the row index of the
        accumulated outputs.
    Xg : torch.Tensor
        2-D feature matrix; its shape is unpacked as (n_rows, nfeat).
    cols_mu : torch.Tensor
        2-D tensor with as many elements as `Xg`; entry (i, j) is the output
        column that `Xg[i, j]` is accumulated into. Its first column is also
        used as the column index of `N`, which has a single column, so it is
        presumably all zeros -- confirm against the caller.
    tones : torch.Tensor
        One weight per row, summed per cluster to form `N` (presumably all
        ones, making `N` a count -- confirm against the caller).
    nclust : int
        Number of rows of the outputs. Defaults to None, but it is used
        directly as a tensor size, so callers are expected to pass it.
    lpow : int or float
        Exponent applied elementwise to `Xg` before accumulation; 1 skips
        the power.
    device : torch.device or str
        Device the dense outputs should live on. For an MPS device the
        sparse tensors are assembled on the CPU and the dense results moved
        to `device` (presumably because sparse COO tensors are not usable on
        the MPS backend -- TODO confirm).

    Returns
    -------
    mu : torch.Tensor
        Dense (nclust, nfeat) tensor, `C / (1e-6 + N)`, where `C` is the
        per-cluster sum of `Xg**lpow`.
    N : torch.Tensor
        Dense (nclust, 1) tensor of per-cluster summed `tones`.
    """
    _, nfeat = Xg.shape

    # One (cluster, column) index pair per element of Xg.
    rows = iclust.unsqueeze(-1).tile((1,nfeat))
    ii = torch.vstack((rows.flatten(), cols_mu.flatten()))
    # One (cluster, column) index pair per row, for the weight accumulator.
    iin = torch.vstack((rows[:,0], cols_mu[:,0]))

    # The values must be defined for every lpow; a bare `C_values` expression
    # in the lpow != 1 branch would raise a NameError further down.
    C_values = Xg.flatten() if lpow == 1 else (Xg**lpow).flatten()

    # Compare on the device *type* so 'mps', 'mps:0' and tensor.device all match.
    if torch.device(device).type == 'mps':
        cpu = torch.device('cpu')
        C = coo(ii, C_values, (nclust, nfeat), device=cpu).to_dense().to(device)
        N = coo(iin, tones, (nclust, 1), device=cpu).to_dense().to(device)
    else:
        # Duplicate indices are summed when the COO tensor is densified.
        C = coo(ii, C_values, (nclust, nfeat)).to_dense()
        N = coo(iin, tones, (nclust, 1)).to_dense()

    # The small constant keeps empty clusters (N == 0) from dividing by zero.
    mu = C / (1e-6 + N)

    return mu, N
Expand All @@ -75,14 +89,25 @@ def assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=t
NN = kn.shape[0]

ij = torch.vstack((rows_neigh.flatten(), isub[kn].flatten()))
xN = coo(ij, tones2.flatten(), (NN, nclust))
xN = xN.to_dense()
if device == torch.device('mps'):
xN_ = coo(ij, tones2.flatten(), (NN, nclust), device = torch.device('cpu'))
xN_ = xN_.to_dense()
xN = xN_.to(device)
else:
xN = coo(ij, tones2.flatten(), (NN, nclust))
xN = xN.to_dense()

if lam > 0:
tones = torch.ones(len(kj), device = device)
tzeros = torch.zeros(len(kj), device = device)
ij = torch.vstack((tzeros, isub))
kN = coo(ij, tones, (1, nclust))

if device == torch.device('mps'):
kN_ = coo(ij, tones, (1, nclust), device = torch.device('cpu'))
kN_ = kN_.to_dense()
kN = kN_.to(device)
else:
kN = coo(ij, tones, (1, nclust))

xN = xN - lam/m * (ki.unsqueeze(-1) * kN.to_dense())

Expand All @@ -95,16 +120,28 @@ def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.dev
n_neigh = kn.shape[1]
cols = iclust.unsqueeze(-1).tile((1, n_neigh))
iis = torch.vstack((kn.flatten(), cols.flatten()))

xS = coo(iis, tones2.flatten(), (nsub, nclust))
xS = xS.to_dense()

if device == torch.device('mps'):
xS_ = coo(iis, tones2.flatten(), (nsub, nclust), device = torch.device('cpu'))
xS_ = xS_.to_dense()
xS = xS_.to(device)
else:
xS = coo(iis, tones2.flatten(), (nsub, nclust))
xS = xS.to_dense()

if lam > 0:
tones = torch.ones(len(ki), device = device)
tzeros = torch.zeros(len(ki), device = device)
ij = torch.vstack((tzeros, iclust))
kN = coo(ij, tones, (1, nclust))
xS = xS - lam / m * (kj.unsqueeze(-1) * kN.to_dense())

if device == torch.device('mps'):
kN_ = coo(ij, tones, (1, nclust), device = torch.device('cpu'))
kN_ = kN_.to_dense()
kN = kN_.to(device)
xS = xS - lam / m * (kj.unsqueeze(-1) * kN)
else:
kN = coo(ij, tones, (1, nclust))
xS = xS - lam / m * (kj.unsqueeze(-1) * kN.to_dense())

isub = torch.argmax(xS, 1)
return isub
Expand All @@ -127,7 +164,7 @@ def cluster(Xd, iclust = None, kn = None, nskip = 20, n_neigh = 10, nclust = 200
seed = 1, niter = 200, lam = 0, device=torch.device('cuda')):

if kn is None:
kn, M = neigh_mat(Xd, nskip = nskip, n_neigh = n_neigh)
kn, M = neigh_mat(Xd, nskip = nskip, n_neigh = n_neigh, device = device)

m, ki, kj = Mstats(M, device=device)

Expand All @@ -140,22 +177,22 @@ def cluster(Xd, iclust = None, kn = None, nskip = 20, n_neigh = 10, nclust = 200
nsub = (NN-1)//nskip + 1

rows_neigh = torch.arange(NN, device = device).unsqueeze(-1).tile((1,n_neigh))

tones2 = torch.ones((NN, n_neigh), device = device)

if iclust is None:
iclust_init = kmeans_plusplus(Xg, niter = nclust, seed = seed, device=device)
iclust = iclust_init.clone()
else:
iclust_init = iclust.clone()

for t in range(niter):
# given iclust, reassign isub
isub = assign_isub(iclust, kn, tones2, nclust , nsub, lam, m,ki,kj, device=device)

# given mu and isub, reassign iclust
iclust = assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=device)

_, iclust = torch.unique(iclust, return_inverse=True)
nclust = iclust.max() + 1
isub = assign_isub(iclust, kn, tones2, nclust , nsub, lam, m,ki,kj, device=device)
Expand Down Expand Up @@ -384,10 +421,10 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
ops, xy, iC, iclust_template, tF, ycent[kk], xcent[jj],
dmin=dmin, dminx=dminx, ix=ix
)

if ii % 10 == 0:
log_performance(logger, header=f'Cluster center: {ii}')

if Xd is None:
nearby_chans_empty += 1
continue
Expand All @@ -404,18 +441,19 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
iclust, iclust0, M, _ = cluster(
Xd, nskip=nskip, lam=1, seed=5, device=device
)

if clear_cache:
gc.collect()
torch.cuda.empty_cache()

xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)

xtree, tstat = swarmsplitter.split(
Xd.numpy(), xtree, tstat,iclust, my_clus, meta=st0
)

iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)

clu[igood] = iclust + nmax
Nfilt = int(iclust.max() + 1)
nmax += Nfilt
Expand Down Expand Up @@ -502,13 +540,17 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,



def assign_clust(rows_neigh, iclust, kn, tones2, nclust, device=torch.device('cuda')):
    """Reassign each row to the cluster most represented among its neighbors.

    Parameters
    ----------
    rows_neigh : torch.Tensor
        Row index for every (row, neighbor) pair; flattened alongside `kn`,
        so it must have the same number of elements as `kn`.
    iclust : torch.Tensor
        Current cluster index per row; its length sets the number of output
        rows, and it is indexed by `kn`.
    kn : torch.Tensor
        Integer neighbor indices into `iclust`.
    tones2 : torch.Tensor
        Weight of every (row, neighbor) pair, with the same number of
        elements as `kn` (presumably all ones -- confirm against the caller).
    nclust : int
        Number of clusters, i.e. number of columns of the vote matrix.
    device : torch.device or str
        Device the dense vote matrix should live on. For an MPS device the
        sparse tensor is assembled on the CPU and the dense result moved to
        `device` (presumably because sparse COO tensors are not usable on
        the MPS backend -- TODO confirm).

    Returns
    -------
    torch.Tensor
        For each row, the index of the cluster with the largest summed
        neighbor weight.
    """
    NN = len(iclust)

    # One (row, neighbor's cluster) index pair per neighbor relation.
    ij = torch.vstack((rows_neigh.flatten(), iclust[kn].flatten()))
    votes = tones2.flatten()

    # Compare on the device *type* so 'mps', 'mps:0' and tensor.device all match.
    if torch.device(device).type == 'mps':
        xN = coo(ij, votes, (NN, nclust), device=torch.device('cpu')).to_dense().to(device)
    else:
        # Duplicate indices are summed when the COO tensor is densified.
        xN = coo(ij, votes, (NN, nclust)).to_dense()

    iclust = torch.argmax(xN, 1)

    return iclust
Expand Down
4 changes: 2 additions & 2 deletions kilosort/datashift.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def align_block2(F, ysamp, ops, device=torch.device('cuda')):
dt = np.arange(-n,n+1,1)

# batch fingerprints are mean subtracted along depth
Fg = torch.from_numpy(F).to(device).float()
Fg = torch.from_numpy(F.astype('float32')).to(device)
Fg = Fg - Fg.mean(1).unsqueeze(1)

# the template fingerprint is initialized with batch 300 if that exists
Expand Down Expand Up @@ -156,7 +156,7 @@ def align_block2(F, ysamp, ops, device=torch.device('cuda')):
imin[:,j] = dall.sum(0)

# Fg gets reinitialized with the un-corrected F without subtracting the mean across depth.
Fg = torch.from_numpy(F).float()
Fg = torch.from_numpy(F.astype('float32'))
imax = dall[:niter-1].sum(0)

# Fg gets aligned again to compute the non-mean subtracted fingerprint
Expand Down
4 changes: 2 additions & 2 deletions kilosort/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_fwav(NT = 30122, fs = 30000, device=torch.device('cuda')):

# symmetric filter from scipy
wav = filtfilt(b,a , x).copy()
wav = torch.from_numpy(wav).to(device).float()
wav = torch.from_numpy(wav.astype('float32')).to(device)

# the filter will be used directly in the Fourier domain
fwav = fft(wav)
Expand Down Expand Up @@ -132,7 +132,7 @@ def get_highpass_filter(fs=30000, cutoff=300, device=torch.device('cuda')):
# symmetric filter from scipy
hp_filter = filtfilt(b, a , x).copy()

hp_filter = torch.from_numpy(hp_filter).to(device).float()
hp_filter = torch.from_numpy(hp_filter.astype('float32')).to(device)
return hp_filter

def fft_highpass(hp_filter, NT=30122):
Expand Down
27 changes: 24 additions & 3 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import numpy as np
import torch

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import kilosort
from kilosort import (
preprocessing,
Expand Down Expand Up @@ -179,17 +182,31 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
logger.info(platform.processor())
if device is None:
if torch.cuda.is_available():
logger.info('Using GPU for PyTorch computations. '
logger.info('Using GPU (cuda) for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cuda')
elif torch.backends.mps.is_available():
logger.info('Using GPU (mps) for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('mps')
else:
logger.info('Using CPU for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cpu')

if device != torch.device('cpu'):
memory = torch.cuda.get_device_properties(device).total_memory/1024**3
logger.info(f'Using CUDA device: {torch.cuda.get_device_name()} {memory:.2f}GB')
if device == torch.device('cuda'):
memory = torch.cuda.get_device_properties(device).total_memory/1024**3
logger.info(f'Using CUDA device: {torch.cuda.get_device_name()} {memory:.2f}GB')
elif device == torch.device('mps'):
memory = torch.mps.recommended_max_memory()/1024**3
logger.info(f'Using MPS, recommended max memory: {memory:.2f}GB')
torch.mps.set_per_process_memory_fraction(1.0)
if settings.get('batch_size') > 65000:
settings['batch_size'] = 65000
logger.warning('Reducing batch size to 65000 for MPS.')
else:
raise ValueError(f'Invalid device: {device}, only cuda and mps are supported.')

logger.info('-'*40)
logger.info(f"Sorting {filename}")
Expand Down Expand Up @@ -225,6 +242,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.mps.manual_seed(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, device, tic0=tic0, progress_bar=progress_bar,
Expand Down Expand Up @@ -258,6 +276,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
logger.exception('Out of memory error, printing performance...')
log_performance(logger, level='info')
log_cuda_details(logger)
# No equivalent error code for mps

# This makes sure the full traceback is written to log file.
logger.exception('Encountered error in `run_kilosort`:')
Expand Down Expand Up @@ -905,6 +924,8 @@ def load_sorting(results_dir, device=None, load_extra_vars=False):
if device is None:
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')

Expand Down
26 changes: 16 additions & 10 deletions kilosort/spikedetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,24 @@ def extract_wPCA_wTEMP(ops, bfile, nt=61, twav_min=20, Th_single_ch=6, nskip=25,
clips /= (clips**2).sum(1, keepdims=True)**.5

model = TruncatedSVD(n_components=ops['settings']['n_pcs']).fit(clips)
wPCA = torch.from_numpy(model.components_).to(device).float()
wPCA = torch.from_numpy(model.components_.astype('float32')).to(device)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="")
# Prevents memory leak for KMeans when using MKL on Windows
msg = 'KMeans is known to have a memory leak on Windows with MKL'
nthread = os.environ.get('OMP_NUM_THREADS', msg)
os.environ['OMP_NUM_THREADS'] = '7'
model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
os.environ['OMP_NUM_THREADS'] = nthread
num_threads_ = '7'
if device == torch.device('mps'):
os.environ['OMP_NUM_THREADS'] = num_threads_
model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
wTEMP = torch.from_numpy(model.cluster_centers_.astype('float32')).to(device)
wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
else: # Prevents memory leak for KMeans when using MKL on Windows
msg = 'KMeans is known to have a memory leak on Windows with MKL'
nthread = os.environ.get('OMP_NUM_THREADS', msg)
os.environ['OMP_NUM_THREADS'] = num_threads_
model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
wTEMP = torch.from_numpy(model.cluster_centers_.astype('float32')).to(device)
wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
os.environ['OMP_NUM_THREADS'] = nthread

return wPCA, wTEMP

Expand Down Expand Up @@ -228,7 +234,7 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,

iC2, _ = nearest_chans(ys, ys, xs, xs, nC2, device=device)

ds_torch = torch.from_numpy(ds).to(device).float()
ds_torch = torch.from_numpy(ds.astype('float32')).to(device)
template_sizes = sig * (1+torch.arange(nsizes, device=device))
weigh = torch.exp(-ds_torch.unsqueeze(-1) / template_sizes**2)
weigh = torch.permute(weigh, (2, 0, 1)).contiguous()
Expand Down
2 changes: 1 addition & 1 deletion kilosort/template_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
st = np.concatenate((st, np.zeros_like(st)), 0)
tF = torch.cat((tF, torch.zeros_like(tF)), 0)

stt = stt.double()
stt = stt.cpu().double()
st[k:k+nsp,0] = ((stt[:,0]-nt) + ibatch * (ops['batch_size'])).cpu().numpy() - nt//2 + ops['nt0min']
st[k:k+nsp,1] = stt[:,1].cpu().numpy()
st[k:k+nsp,2] = amps[:,0].cpu().numpy()
Expand Down
Loading