Marques-Smith neuropixel 384ch paired recording

Sorter comparison with paired (neuropixel - patch) recordings

Author : Samuel Garcia

André Marques-Smith made an open dataset with simultaneous patch-clamp and neuropixel probe extracellular recordings from the same cortical neuron in anaesthetized rats.

This is very useful for testing spike sorting engines.

The original dataset contains 42 recordings.

Here we select only a subset of 6 files. I keep a file only when the SNR in the extracellular trace is big enough for the spikes to be detected. One file (c24) was removed because the juxtacellular recording itself is ambiguous.

The patch recording will be the "Ground Truth". The neuropixel recording with 384 channels will be processed by 5 sorters to compare the results.

Please have a look at the paper:

https://www.biorxiv.org/content/10.1101/370080v2

The repository that explains everything:

https://github.com/kampff-lab/sc.io/tree/master/Paired%20Recordings

The dataset is available here:

http://crcns.org/data-sets/methods/spe-1

or here

https://drive.google.com/drive/folders/13GCOuWN4QMW6vQmlNIolUrxPy-4Wv1BC

Note :

  • I will not use the spike indexes provided by André because, for some files, small errors due to double peak detection can occur.
  • These results are also on SpikeForest, but SpikeForest keeps only 32 channels to reduce the computation. Here the computation is done on all 384 channels. Let's see if we get the same results.
In [7]:
# import everything
import os, getpass

kilosort2_path = '/home/samuel/Documents/Spikeinterface/Kilosort2'
os.environ["KILOSORT2_PATH"] = kilosort2_path

kilosort_path = '/home/samuel/Documents/Spikeinterface/KiloSort/'
os.environ["KILOSORT_PATH"] = kilosort_path

ironclust_path = '/home/samuel/Documents/Spikeinterface/ironclust'
os.environ["IRONCLUST_PATH"] = ironclust_path

from pathlib import Path
import scipy.signal
import scipy.io

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path


import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw

from spikeinterface.comparison import GroundTruthStudy
In [2]:
# Print the versions of spikeinterface (and its sub-packages) and of the sorters,
# so the results below can be tied to a given software state.
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

herdingspikes: 0.3.7+git.45665a2b6438
ironclust: 5.9.8
kilosort: git-cd040da1963d
kilosort2: git-e243c934339e
spykingcircus: 0.9.7
tridesclous: 1.6.1.dev

path and list

In [4]:
# Root folder of the dataset; every other path is derived from it.
p = Path('/media/samuel/dataspikesorting/DataSpikeSortingHD2/andre_paired_neuropixel/')

# Folder with the raw recordings and folder where the study is stored.
recordings_folder = p.joinpath('recordings')
study_folder = p.joinpath('study_paired_neuropixel')

# Names of the six recordings kept for the comparison.
rec_names = ['c14', 'c26', 'c28', 'c37', 'c45', 'c46']

Function to detect peaks on the patch recording

The files provided by André contain small errors in peak detection. Here we compute the peaks again.

In [6]:
def detect_peak_on_patch_sig(patch_sig, sample_rate, thresh_mad=12.,
                             highpass_freq=200., min_interval=0.001):
    """
    Detect negative peaks (spikes) on a patch recording trace.

    The trace is high-pass filtered, a threshold is set at
    ``median - thresh_mad * MAD`` of the filtered trace, and local minima
    below that threshold are returned.

    Parameters
    ----------
    patch_sig : array
        The patch trace (filtering is done along axis 0).
    sample_rate : float
        Sampling rate of ``patch_sig`` in Hz.
    thresh_mad : float
        Threshold in units of the robust noise estimate (MAD * 1.4826).
        Default 12.
    highpass_freq : float
        Cut-off in Hz of the high-pass filter. Default 200.
    min_interval : float
        Minimum time in seconds between two detected peaks. Default 0.001
        (1 ms).

    Returns
    -------
    spike_indexes : array of int
        Sample indexes of the detected peaks, in the clock of ``patch_sig``.
    """
    # filter because some traces have drift
    sos = scipy.signal.iirfilter(5, highpass_freq / sample_rate * 2, analog=False,
                                 btype='highpass', ftype='butter', output='sos')
    patch_sig_f = scipy.signal.sosfiltfilt(sos, patch_sig, axis=0)

    # robust estimate of baseline and noise level
    med = np.median(patch_sig_f)
    mad = np.median(np.abs(patch_sig_f - med)) * 1.4826
    thresh = med - thresh_mad * mad

    # minimum number of samples between two peaks;
    # find_peaks rejects a distance smaller than 1
    d = max(int(sample_rate * min_interval), 1)

    # spikes are negative deflections: search maxima on the inverted trace
    spike_indexes, _ = scipy.signal.find_peaks(-patch_sig_f, height=-thresh, distance=d)

    return spike_indexes
In [ ]:
## create the study
In [ ]:
# Probe geometry comes from chanMap.mat: x and y coordinates for the 384 channels.
chan_map = scipy.io.loadmat(str(p / 'chanMap.mat'))
locations = np.zeros((384, 2))
for col, key in enumerate(('xcoords', 'ycoords')):
    locations[:, col] = chan_map[key][:, 0]

# Sampling rates are identical for every recording.
sample_rate = 30000.
# Patch sampling rate: this is not the exact frequency but it does not matter here.
sr = 50023.

gt_dict = {}
for rec_name in rec_names:
    print(rec_name)
    rec_folder = recordings_folder / rec_name

    # Neuropixel signals, memory-mapped and viewed as (n_samples, 384).
    raw_bin_filename = rec_folder / (rec_name + '_npx_raw.bin')
    mea_sigs = np.memmap(raw_bin_filename, dtype='int16', mode='r').reshape(-1, 384)

    # Patch recording, read fully as float64.
    filename = rec_folder / (rec_name + '_patch_ch1.bin')
    patch_sig = np.fromfile(str(filename), dtype='float64')

    # Spike indexes in the patch clock reference.
    gt_spike_indexes_patch = detect_peak_on_patch_sig(patch_sig, sr)

    # Time stretch factor between the two recordings (neuropixel and patch),
    # taken as the ratio of their lengths in samples.
    n_mea_samples = mea_sigs.shape[0]
    n_patch_samples = patch_sig.shape[0]
    time_factor = n_mea_samples / n_patch_samples
    print('time_factor', time_factor)

    # Spike indexes converted to the neuropixel clock reference.
    gt_spike_indexes = (gt_spike_indexes_patch * time_factor).astype('int64')

    # Recording extractor on the raw binary file, with the probe geometry attached.
    rec = se.BinDatRecordingExtractor(raw_bin_filename, sample_rate, 384, 'int16',
                                      offset=0, time_axis=0)
    rec.set_channel_locations(locations)

    # Ground-truth sorting: every spike gets the label 0.
    unit_labels = np.zeros(gt_spike_indexes.size, dtype='int64')
    sorting_gt = se.NumpySortingExtractor()
    sorting_gt.set_times_labels(gt_spike_indexes, unit_labels)
    sorting_gt.set_sampling_frequency(sample_rate)

    gt_dict[rec_name] = (rec, sorting_gt)


study = GroundTruthStudy.create(study_folder, gt_dict)

Get signal to noise ratio for all units

In [ ]:
# Re-open the study from its folder.
study = GroundTruthStudy(study_folder)
# Table of SNR values; it is later indexed by (rec_name, gt_id) and read
# through its 'snr' column.
snr = study.concat_all_snr()
# Last expression of the cell: displays the table in the notebook.
snr
In [ ]:
# Histogram of the ground-truth unit SNRs, with bin edges every 5 from 0 to 35.
fig, ax = plt.subplots()
snr_values = snr['snr'].values
ax.hist(snr_values, bins=np.arange(0, 40, 5))
ax.set_xlabel('GT units SNR')

Run all sorters

In [9]:
# Sorters to run; this order is also used as the plotting order.
sorter_list = [
    'herdingspikes',
    'ironclust',
    'kilosort2',
    'spykingcircus',
    'tridesclous',
]
In [ ]:
# Re-open the study from its folder.
study = GroundTruthStudy(study_folder)

# Run every sorter of sorter_list on the study recordings.
# NOTE(review): mode='keep' presumably reuses sorter outputs that already
# exist instead of recomputing them - confirm against the GroundTruthStudy docs.
study.run_sorters(sorter_list, mode='keep', verbose=False)

Run comparison with ground truth and retrieve result tables

In [ ]:
# Re-open the study from its folder before collecting the results.
study = GroundTruthStudy(study_folder)
In [ ]:
# Copying the sortings is necessary: it gathers the results of every sorter
# into one centralized folder holding all results.
study.copy_sortings()
In [ ]:
# Run all comparisons against the ground truth.
# exhaustive_gt=False because this is a paired recording, so there is only
# one ground-truth unit per recording.
study.run_comparisons(exhaustive_gt=False, match_score=0.1, overmerged_score=0.2)
In [ ]:
# Retrieve the results: the comparison objects and the aggregated tables.
# The tables used below are dataframes['run_times'] and
# dataframes['perf_by_units'].
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()

Run times

In [ ]:
# Pivot the run-time table: one row per recording, one column per sorter.
dataframes['run_times'].set_index(['rec_name', 'sorter_name']).unstack('sorter_name')
In [ ]:
sns.set_palette(sns.color_palette("Set1"))

# Bar plot of the run time of each sorter, one bar per recording.
# catplot creates its own figure, reachable through g.fig.
g = sns.catplot(data=dataframes['run_times'], x='sorter_name', y='run_time',
                hue="rec_name",
                order=sorter_list,
                kind='bar',
                legend=False)
g.fig.set_size_inches(12, 5)

# Rotate the tick labels on the axes of this plot. A bare `fig` would refer to
# a figure created in another cell, leaving these labels untouched.
g.fig.axes[0].set_xticklabels(sorter_list, rotation=40, ha='right');

 Accuracy/precision/recall scores per sorter

In [ ]:
sns.set_palette(sns.color_palette("Set1"))

# Long-format table: each row of perf_by_units is repeated once per metric,
# with the metric name in 'metric' and its value in 'score'.
score_columns = ('accuracy', 'precision', 'recall')
df = pd.melt(dataframes['perf_by_units'],
             id_vars=['rec_name', 'sorter_name'],
             value_vars=score_columns,
             var_name='metric',
             value_name='score')
display(df)

# Swarm plot of the scores, grouped by sorter and coloured by metric.
swarm_kwargs = dict(kind='swarm', dodge=True, order=sorter_list, legend_out=True, s=4)
g = sns.catplot(data=df, x='sorter_name', y='score', hue='metric', **swarm_kwargs)
g.fig.set_size_inches(12, 5)
In [ ]:
sns.set_palette(sns.color_palette("deep"))

# Recall against precision, one panel per sorter (three panels per row).
perf_by_units = dataframes['perf_by_units']
g = sns.relplot(data=perf_by_units, x='precision', y='recall',
                col='sorter_name', col_order=sorter_list, col_wrap=3, s=60)

 Accuracy vs SNR

In [ ]:
# Performance table with one row per (recording, sorter, ground-truth unit);
# it has 'rec_name', 'sorter_name' and 'gt_unit_id' columns plus the scores.
df = dataframes['perf_by_units']
# Last expression of the cell: displays the table in the notebook.
df
In [ ]:
df = dataframes['perf_by_units']

# Add the SNR of each ground-truth unit to the by-unit table.
# The column starts as NaN so that it stays numeric for plotting.
df['snr'] = np.nan
for rec_name, gt_id in snr.index:
    rows = (df['gt_unit_id'] == gt_id) & (df['rec_name'] == rec_name)
    # One single .loc call on the DataFrame: a chained `df['snr'].loc[...] = ...`
    # may write into a temporary copy and leave df unchanged.
    df.loc[rows, 'snr'] = snr.at[(rec_name, gt_id), 'snr']

# Accuracy against SNR, one panel per sorter (three panels per row).
sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=df, x='snr', y='accuracy', col='sorter_name',
        col_wrap=3, col_order=sorter_list, s=80)