Marques-Smith neuropixel 384ch paired recording
Sorter comparison with paired (neuropixel - patch) recordings
Author: Samuel Garcia
André Marques-Smith made an open dataset with simultaneous patch-clamp and Neuropixels probe extracellular recordings from the same cortical neuron in anaesthetized rats.
This is very useful for testing spike sorting engines.
The original dataset contains 42 recordings.
Here we select only a subset of 6 files: I keep only the recordings in which the SNR of the extracellular trace is large enough for the neuron to be detected. One file (c24) was removed because the juxtacellular signal itself is ambiguous.
The patch recording is used as the "ground truth", and the 384-channel Neuropixels recording is processed by 5 sorters to compare their results.
Please have a look at the paper:
https://www.biorxiv.org/content/10.1101/370080v2
The repository that explains everything:
https://github.com/kampff-lab/sc.io/tree/master/Paired%20Recordings
The dataset is available here:
http://crcns.org/data-sets/methods/spe-1
or here
https://drive.google.com/drive/folders/13GCOuWN4QMW6vQmlNIolUrxPy-4Wv1BC
Notes:
- I do not use the spike indexes provided by André, because for some files small errors in double-peak detection can occur.
- These results are also available on SpikeForest, but SpikeForest keeps only 32 channels to reduce computation. Here the computation is done on all 384 channels. Let's see if we get the same results.
# import everything
import os, getpass

# Install locations of the MATLAB-based sorters. They are exported as
# environment variables, presumably so the spikeinterface sorter wrappers can
# find them (confirm against the installed spikeinterface version).
kilosort2_path = '/home/samuel/Documents/Spikeinterface/Kilosort2'
kilosort_path = '/home/samuel/Documents/Spikeinterface/KiloSort/'
ironclust_path = '/home/samuel/Documents/Spikeinterface/ironclust'
os.environ.update(
    KILOSORT2_PATH=kilosort2_path,
    KILOSORT_PATH=kilosort_path,
    IRONCLUST_PATH=ironclust_path,
)
from pathlib import Path
import scipy.signal
import scipy.io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
from spikeinterface.comparison import GroundTruthStudy
# Print the installed spikeinterface version and the versions of the available
# sorters, so the results of this notebook can be tied to a software setup.
si.print_spikeinterface_version()
ss.print_sorter_versions()
Paths and recording list
# Root folder of the dataset on disk; the raw recordings and the study folder
# both live below it.
p = Path('/media/samuel/dataspikesorting/DataSpikeSortingHD2/andre_paired_neuropixel/')
recordings_folder = p / 'recordings'
study_folder = p / 'study_paired_neuropixel'

# The 6 paired recordings retained for this comparison.
rec_names = ['c14', 'c26', 'c28', 'c37', 'c45', 'c46']
Function to detect peaks in the patch recording
The spike indexes in the files provided by André contain small peak-detection errors, so here we compute them again.
def detect_peak_on_patch_sig(patch_sig, sample_rate, highpass_freq=200., filter_order=5,
                             thresh_factor=12., min_interval=0.001):
    """Detect negative spike peaks in a patch recording trace.

    The trace is high-pass filtered with a zero-phase Butterworth filter
    (``sosfiltfilt``) to remove slow drift. Local minima of the filtered trace
    that go below ``median - thresh_factor * MAD`` are then reported. The MAD is
    scaled by 1.4826 so that it approximates a standard deviation. Among peaks
    closer than ``min_interval`` seconds, ``scipy.signal.find_peaks`` keeps only
    the larger ones.

    Parameters
    ----------
    patch_sig : 1D array
        Raw patch trace.
    sample_rate : float
        Sampling rate of ``patch_sig`` in Hz.
    highpass_freq : float
        High-pass cutoff frequency in Hz (default 200).
    filter_order : int
        Butterworth filter order (default 5).
    thresh_factor : float
        Detection threshold, in robust standard deviations below the median
        (default 12).
    min_interval : float
        Minimum time in seconds between two detected peaks (default 1 ms).

    Returns
    -------
    spike_indexes : ndarray of int
        Sample indexes of the detected peaks in ``patch_sig``.
    """
    # high-pass filter because some traces have drift
    sos = scipy.signal.iirfilter(filter_order, highpass_freq / sample_rate * 2, analog=False,
                                 btype='highpass', ftype='butter', output='sos')
    patch_sig_f = scipy.signal.sosfiltfilt(sos, patch_sig, axis=0)

    # robust noise estimate: median absolute deviation scaled to a std
    med = np.median(patch_sig_f)
    mad = np.median(np.abs(patch_sig_f - med)) * 1.4826
    thresh = med - thresh_factor * mad

    # minimum spacing between peaks; find_peaks requires distance >= 1, which
    # int(sample_rate * min_interval) alone does not guarantee for low rates
    distance = max(1, int(sample_rate * min_interval))

    # spikes are negative deflections: look for maxima of the inverted signal
    spike_indexes, _ = scipy.signal.find_peaks(-patch_sig_f, height=-thresh, distance=distance)
    return spike_indexes
## create the study
# The file chanMap.mat contains the probe geometry: 'xcoords' and 'ycoords'
# become columns 0 and 1 of the (384, 2) channel location array.
d = scipy.io.loadmat(str(p / 'chanMap.mat'))
locations = np.zeros((384, 2))
for axis_index, coord_key in enumerate(('xcoords', 'ycoords')):
    locations[:, axis_index] = d[coord_key][:, 0]
# Build {rec_name: (neuropixel recording, ground-truth sorting)} and create the study.

# Sampling rate of the neuropixel recordings (Hz).
sample_rate = 30000.
# Approximate sampling rate of the patch recordings (Hz). This is not the exact
# frequency. It only sets the patch high-pass cutoff and the 1 ms peak spacing in
# samples. Spike times are mapped to the neuropixel clock with the sample-count
# ratio below, not with this value.
sr = 50023.
# Number of channels interleaved in each neuropixel raw binary file.
n_channels = 384

gt_dict = {}
for rec_name in rec_names:
    print(rec_name)
    rec_folder = recordings_folder / rec_name

    # neuropixel sigs: memory-mapped, only the number of samples is used here
    raw_bin_filename = rec_folder / (rec_name + '_npx_raw.bin')
    mea_sigs = np.memmap(raw_bin_filename, dtype='int16', mode='r').reshape(-1, n_channels)
    n_samples_mea = mea_sigs.shape[0]

    # patch recording (float64 samples)
    filename = rec_folder / (rec_name + '_patch_ch1.bin')
    patch_sig = np.fromfile(str(filename), dtype='float64')

    # spike indexes in the patch clock reference
    gt_spike_indexes_patch = detect_peak_on_patch_sig(patch_sig, sr)

    # time stretch factor between the 2 recordings (neuropixel and patch);
    # assumes both files cover the same time span
    time_factor = n_samples_mea / patch_sig.shape[0]
    print('time_factor', time_factor)

    # spike indexes in the neuropixel clock reference. Round to the nearest
    # sample instead of truncating, which would bias every GT spike time by
    # about -0.5 sample. Clip so that rounding cannot step past the last sample.
    gt_spike_indexes = np.rint(gt_spike_indexes_patch * time_factor).astype('int64')
    gt_spike_indexes = np.clip(gt_spike_indexes, 0, n_samples_mea - 1)

    # recording
    rec = se.BinDatRecordingExtractor(raw_bin_filename, sample_rate, n_channels, 'int16',
                                      offset=0, time_axis=0)
    rec.set_channel_locations(locations)

    # gt sorting: a single unit (label 0) per recording
    sorting_gt = se.NumpySortingExtractor()
    sorting_gt.set_times_labels(gt_spike_indexes, np.zeros(gt_spike_indexes.size, dtype='int64'))
    sorting_gt.set_sampling_frequency(sample_rate)

    gt_dict[rec_name] = (rec, sorting_gt)

study = GroundTruthStudy.create(study_folder, gt_dict)
Get the signal-to-noise ratio for all units
# Reopen the study from its folder on disk and collect the SNR of every GT unit.
study = GroundTruthStudy(study_folder)
snr = study.concat_all_snr()
snr

# Histogram of GT unit SNRs in bins of width 5. The upper edge is extended past
# the largest SNR. With a fixed np.arange(0, 40, 5) the last edge is 35, and any
# unit with SNR > 35 would be silently left out of the plot.
max_snr = float(snr['snr'].max())
bin_stop = max(40, np.ceil(max_snr / 5) * 5 + 5)
fig, ax = plt.subplots()
ax.hist(snr['snr'].values, bins=np.arange(0, bin_stop, 5))
ax.set_xlabel('GT units SNR')
Run all sorters
# Sorters to run on every recording of the study; also used as the display
# order of the sorters in the plots.
sorter_list = [
    'herdingspikes',
    'ironclust',
    'kilosort2',
    'spykingcircus',
    'tridesclous',
]

# Reopen the study from disk and run each sorter on each recording.
# NOTE(review): mode='keep' presumably reuses existing sorter outputs; confirm
# against the GroundTruthStudy.run_sorters documentation.
study = GroundTruthStudy(study_folder)
study.run_sorters(sorter_list, mode='keep', verbose=False)
Run the comparison with ground truth and retrieve the result tables
# Reopen the study from its folder on disk.
study = GroundTruthStudy(study_folder)
# Copy each sorter's results into the study's centralized results folder.
# This has to happen before the comparisons are run.
study.copy_sortings()
# Run all comparisons against the ground truth.
# exhaustive_gt=False: this is a paired recording, so only one GT unit (the
# patched neuron) is known per recording and the GT is not exhaustive.
study.run_comparisons(exhaustive_gt=False, match_score=0.1, overmerged_score=0.2)
# Retrieve the comparison objects and the aggregated result tables.
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
Run times
# Run time table: one row per recording, one column per sorter.
dataframes['run_times'].set_index(['rec_name', 'sorter_name']).unstack('sorter_name')

sns.set_palette(sns.color_palette("Set1"))
# One bar per (sorter, recording), with sorters in the order of sorter_list.
g = sns.catplot(data=dataframes['run_times'], x='sorter_name', y='run_time',
                hue="rec_name",
                order=sorter_list,
                kind='bar',
                legend=False)
g.fig.set_size_inches(12, 5)
g.set_ylabels('Run time (s)')
# Rotate the sorter names on this catplot's own axes. A bare `fig` would refer
# to whatever figure was last assigned to that name (e.g. the SNR histogram),
# not to this plot.
g.set_xticklabels(sorter_list, rotation=40, ha='right')
Accuracy/precision/recall scores per sorter
sns.set_palette(sns.color_palette("Set1"))

# Long format: one row per (unit, metric), with the metric name in 'metric'
# and its value in 'score'.
score_names = ('accuracy', 'precision', 'recall')
df = pd.melt(
    dataframes['perf_by_units'],
    id_vars=['rec_name', 'sorter_name'],
    value_vars=score_names,
    var_name='metric',
    value_name='score',
)
display(df)

# Swarm plot of the three scores for each sorter, dodged by metric.
g = sns.catplot(
    data=df, kind='swarm',
    x='sorter_name', y='score', hue='metric',
    order=sorter_list, dodge=True, legend_out=True, s=4,
)
g.fig.set_size_inches(12, 5)

sns.set_palette(sns.color_palette("deep"))
# Precision vs recall, one panel per sorter, 3 panels per row.
g = sns.relplot(
    data=dataframes['perf_by_units'],
    x='precision', y='recall',
    col='sorter_name', col_order=sorter_list, col_wrap=3,
    s=60,
)
Accuracy vs SNR
df = dataframes['perf_by_units']
df

# Add the SNR of each GT unit to the by-unit table. df is the same object as
# dataframes['perf_by_units'], so that table gains the column too.
# Use a float NaN default rather than None so the column stays numeric.
# Write through df.loc[row_mask, column]: the chained form
# df['snr'].loc[mask] = ... may modify a temporary copy, and does under pandas
# Copy-on-Write, which would leave the column empty.
df['snr'] = np.nan
for rec_name, gt_id in snr.index:
    row_mask = (df['gt_unit_id'] == gt_id) & (df['rec_name'] == rec_name)
    df.loc[row_mask, 'snr'] = snr.at[(rec_name, gt_id), 'snr']

sns.set_palette(sns.color_palette("deep"))
# Accuracy vs SNR, one panel per sorter, 3 panels per row.
g = sns.relplot(data=df, x='snr', y='accuracy', col='sorter_name',
                col_wrap=3, col_order=sorter_list, s=80)