Example of ground-truth comparison with SpikeInterface¶
This notebook shows how SpikeInterface can be used to perform a ground-truth comparison of multiple spike sorters.
The dataset is simulated using MEArec and can be downloaded from Zenodo: https://doi.org/10.5281/zenodo.4058272
The dataset name is: recordings_50cells_SqMEA-10-15_600.0_10.0uV_21-01-2020_18-12.h5. It contains 50 neurons recorded on a 10x10 MEA with 15 µm pitch. The duration is 600 s and the noise level is 10 µV.
Author: Samuel Garcia, CRNL, Lyon
Requirements¶
For this example you will need the following Python packages:
- numpy
- pandas
- matplotlib
- seaborn
- spikeinterface
To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.
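The Python dependencies can be installed with pip, as sketched below (pin versions to match the API used in this notebook if needed):
# install the Python requirements (run once, here or in a terminal)
!pip install numpy pandas matplotlib seaborn spikeinterface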
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
from spikeinterface.comparison import GroundTruthStudy
# clone and install MATLAB sorters
# kilosort2
!git clone https://github.com/MouseLand/Kilosort2.git
kilosort2_path = './Kilosort2'
ss.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)
# kilosort
!git clone https://github.com/cortex-lab/KiloSort.git
kilosort_path = './KiloSort'
ss.KilosortSorter.set_kilosort_path(kilosort_path)
# ironclust
!git clone https://github.com/flatironinstitute/ironclust.git
ironclust_path = './ironclust'
ss.IronclustSorter.set_ironclust_path(ironclust_path)
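After cloning and registering the sorters, it is worth checking which ones SpikeInterface can actually find on this machine. spikeinterface.sorters provides a helper for this:
# sanity check: list the sorters that are installed and runnable
ss.installed_sorters()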
%matplotlib inline
# a few matplotlib tweaks to prettify figures
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16
plt.rc('font', size=SMALL_SIZE) # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
plt.rc('figure', figsize=(10.0, 8.0)) # figsize
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
def clear_axes(ax):
    # hide the top and right spines for a cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
Check spikeinterface version and sorter version¶
To make this a bit more reproducible, let's check the versions of each sorter and of the spikeinterface subpackages.
si.print_spikeinterface_version()
ss.print_sorter_versions()
Setup global path¶
# Change this path to point to where you downloaded the dataset
p = Path('/home/samuel/Documents/DataSpikeSorting/mearec/')
study_folder = p / 'study_mearec_SqMEA1015um/'
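If you have not downloaded the dataset yet, it can be fetched directly from Zenodo. Below is a sketch; the file URL is an assumption based on Zenodo's usual record/files layout for this record:
# download the simulated dataset into the data folder (run once, the file is large)
!wget -P {p} https://zenodo.org/record/4058272/files/recordings_50cells_SqMEA-10-15_600.0_10.0uV_21-01-2020_18-12.h5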
Setup ground truth study¶
# load the recording (traces) and the true sorting (spike trains and units) from the MEArec file
mearec_filename = p / 'recordings_50cells_SqMEA-10-15_600.0_10.0uV_21-01-2020_18-12.h5'
rec0 = se.MEArecRecordingExtractor(mearec_filename)
gt_sorting0 = se.MEArecSortingExtractor(mearec_filename)
# a study can have several pairs of recording/GT sorting
# here we have only one
gt_dict = {'rec0' : (rec0, gt_sorting0) }
study = GroundTruthStudy.create(study_folder, gt_dict)
Run all sorters¶
sorter_list = ['herdingspikes', 'ironclust', 'kilosort2', 'kilosort',
               'spykingcircus', 'tridesclous']
study = GroundTruthStudy(study_folder)
# here we use the default params for each sorter engine,
# but they could be overridden (see the example after this cell)
sorter_params = {}
study.run_sorters(sorter_list, sorter_params=sorter_params, mode='keep', verbose=True)
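To illustrate how non-default parameters could be passed, sorter_params maps a sorter name to a dict of that sorter's parameters. The names and values below are hypothetical examples, not recommendations:
# hypothetical per-sorter parameter overrides (not used in this run)
example_sorter_params = {
    'kilosort2': {'detect_threshold': 5},
    'tridesclous': {'freq_min': 400.},
}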
Get signal-to-noise ratio for all units¶
study = GroundTruthStudy(study_folder)
snr = study.get_units_snr()
snr.head(10)
fig, ax = plt.subplots()
ax.hist(snr['snr'].values, bins=20)
ax.set_xlabel('GT units SNR')
Run comparison with ground truth and retrieve result tables¶
# copying the sortings is needed to gather the results from each sorter
# into a centralized folder containing all results
study.copy_sortings()
# this runs all comparisons to the GT
study.run_comparisons(exhaustive_gt=True, match_score=0.1, overmerged_score=0.2)
# this retrieves the results
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
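The aggregated results come back as a dict-like collection of pandas DataFrames; the tables used below are 'run_times', 'perf_by_units' and 'count_units':
# inspect which aggregated tables are available
print(dataframes.keys())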
Run times¶
dataframes['run_times']
fig, ax = plt.subplots()
sns.set_palette(sns.color_palette("Set1"))
sns.barplot(data=dataframes['run_times'], x='sorter_name', y='run_time', ax=ax, order=sorter_list)
ax.set_ylabel('Run time (s)');
ax.set_xlabel(None);
ax.set_xticklabels(sorter_list, rotation=40, ha='right');
Agreement matrix¶
for (rec_name, sorter_name), comp in comparisons.items():
    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax)
    fig.suptitle(rec_name + ' ' + sorter_name)
Accuracy/precision/recall scores per sorter¶
sns.set_palette(sns.color_palette("Set1"))
df = pd.melt(dataframes['perf_by_units'], id_vars='sorter_name', var_name='metric', value_name='score',
value_vars=('accuracy','precision', 'recall'))
g = sns.catplot(data=df, x='sorter_name', y='score', hue='metric', kind='swarm', dodge=True,
order=sorter_list, legend_out=True, s=4)
g.fig.set_size_inches(15,8)
ax = g.axes[0, 0]
#ax.legend(['Accuracy', 'Precision', 'Recall'], labelspacing=0.2,
# bbox_to_anchor=(1, 0.5), loc=2, borderaxespad=0., frameon=False)
ax.set_xticklabels(sorter_list, rotation=30, ha='center')
ax.set_xlabel(None)
ax.set_ylabel('Score')
clear_axes(ax)
A wise man (a reviewer of our paper, to be honest) suggested to us that the swarm plot is not very informative for benchmarking sorters. This is quite true: we can instead look at the balance between false positive spikes and false negative spikes (precision vs recall). Here is another possible way to present the results.
sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=dataframes['perf_by_units'], x='precision', y='recall', col='sorter_name',
col_wrap=3, col_order=sorter_list, s=60)
Count units¶
- well detected
- false positive
- redundant
- overmerged
sns.set_palette(sns.color_palette("Set2"))
df = pd.melt(dataframes['count_units'], id_vars='sorter_name', var_name='metric', value_name='score',
value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
g = sns.catplot(x='sorter_name', y='score', hue='metric', data=df,
height=6, kind="bar", order=sorter_list)
ax = g.axes[0,0]
ax.set_xticklabels(sorter_list, rotation=30, ha='right')
ax.legend(bbox_to_anchor=(0.9, 0.95), borderaxespad=0., frameon=False, labelspacing=0.2)
# rename the legend entries to human-readable labels
for t, l in zip(ax.legend_.texts, ("Well detected", "False positive", "Redundant", "Overmerged")):
    t.set_text(l)
ax.set_xlabel(None)
ax.set_ylabel('Number of units')
clear_axes(ax)
Accuracy vs SNR¶
df = dataframes['perf_by_units']
# add snr to the by-unit table (avoid chained indexing when assigning)
df['snr'] = np.nan
for gt_id in snr.index:
    df.loc[df['gt_unit_id'] == gt_id, 'snr'] = snr.at[gt_id, 'snr']
df
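Equivalently, the SNR column can be attached in a single vectorized step, assuming snr is indexed by the GT unit ids (as it is iterated above):
# vectorized alternative to the loop above
df['snr'] = df['gt_unit_id'].map(snr['snr'])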
sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=dataframes['perf_by_units'], x='snr', y='accuracy', col='sorter_name',
col_wrap=3, col_order=sorter_list, s=60)
for i, ax in enumerate(g.axes):
    ax.set_title(sorter_list[i])
    ax.set_xlabel('')
g.axes[0].set_xlabel('SNR');
g.axes[0].set_ylabel('Accuracy');