Ground-truth comparison and ensemble sorting of a synthetic Neuropixels recording
Ground-truth comparison and ensemble sorting of a synthetic Neuropixels recording
This notebook reproduces figures 2 and 3 from the paper "SpikeInterface, a unified framework for spike sorting".
The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034.
The entire data archive (about 75GB) can be downloaded with the command `dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft`.
The data file required to run the code is:
- the raw data: sub-MEAREC-250neuron-Neuropixels_ecephys.nwb
This file should be in the same directory where the notebook is located (otherwise adjust paths below).
Author: Matthias Hennig, University of Edinburgh, 22 Aug 2020
Requirements
To run this notebook, you will need the following Python packages:
- numpy
- pandas
- matplotlib
- seaborn
- spikeinterface
- dandi
- matplotlib-venn
To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.
import os
# MATLAB-based sorters are located through these environment variables.
# Edit the paths below so they point at your local installations.
os.environ.update({
    "IRONCLUST_PATH": "./ironclust",
    "KILOSORT2_PATH": "./Kilosort2",
    "HDSORT_PATH": "./HDsort",
})
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3
import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy
%matplotlib inline
def clear_axes(ax):
    """Hide the top and right spines of *ax* in place; returns None."""
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
# Report the installed spikeinterface version, then the versions of the sorters.
for _report in (si.print_spikeinterface_version, ss.print_sorter_versions):
    _report()
Set up the ground-truth study and run all sorters
# Folders holding the study and the raw data. Both default to the notebook's
# directory; adjust them if the NWB file lives elsewhere.
study_path = Path('.')
data_path = Path('.')
study_folder = study_path / 'study_mearec_250cells_Neuropixels-384chans_duration600s_noise10uV_2020-02-28/'

# The original data: this NWB file contains both the ground-truth spikes and
# the raw recording, so one path feeds both extractors.
data_filename = data_path / 'sub-MEAREC-250neuron-Neuropixels_ecephys.nwb'
nwb_path = str(data_filename)
SX_gt = se.NwbSortingExtractor(nwb_path)
RX = se.NwbRecordingExtractor(nwb_path)

# Reopen an existing study folder, otherwise create a new study holding a
# single recording named 'rec0'.
if os.path.isdir(study_folder):
    study = GroundTruthStudy(study_folder)
else:
    gt_dict = {'rec0': (RX, SX_gt)}
    study = GroundTruthStudy.create(study_folder, gt_dict)

# (sorter key, display name, short label) per sorter. Keeping them in one
# table guarantees the three derived lists stay aligned.
_sorter_table = [
    ('herdingspikes', 'HerdingSpikes', 'HS'),
    ('kilosort2', 'Kilosort2', 'KS'),
    ('ironclust', 'Ironclust', 'IC'),
    ('spykingcircus', 'SpykingCircus', 'SC'),
    ('tridesclous', 'Tridesclous', 'TDC'),
    ('hdsort', 'HDSort', 'HDS'),
]
sorter_list = [row[0] for row in _sorter_table]
sorter_names = [row[1] for row in _sorter_table]
sorter_names_short = [row[2] for row in _sorter_table]

# Run all sorters sequentially. mode='keep' presumably reuses outputs that
# already exist (TODO confirm against spikecomparison docs).
study.run_sorters(sorter_list, mode='keep', engine='loop', verbose=True)
study.copy_sortings()
# SNR of each ground-truth unit, cached as snr.npy inside the study folder.
snr_file = study_folder / 'snr.npy'
if not os.path.isfile(snr_file):
    print('computing snr')
    # This is quite slow for an NWB file because the data is stored as
    # channels x time; writing a binary file in time x channels order first
    # is faster.
    snr = st.validation.compute_snrs(
        SX_gt, RX,
        apply_filter=False,
        verbose=False,
        memmap=True,
        max_spikes_per_unit_for_snr=500,
    )
    np.save(snr_file, snr)
else:
    snr = np.load(snr_file)
Run the ground-truth comparison and summarise the results
# Compare each sorter's output with the ground truth, then collect the
# results into the study's aggregated tables.
study.run_comparisons(match_score=0.1, exhaustive_gt=True)
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
# Unit counts per sorter; the bare expression is displayed as the cell output.
dataframes['count_units']
Figure 2 - ground-truth study results
# Activity levels across the probe, drawn on a wide, flat figure.
ax = plt.figure(figsize=(16, 2)).add_subplot(111)
w = sw.plot_activity_map(RX, trange=(0, 20), transpose=True, ax=ax,
                         background='w', frame=True)
# 100 um scale bar below the map.
ax.plot([-1800, -1700], [-120] * 2, color='k', linestyle='-')
ax.annotate('100$\\mu m$', (-1750, -220), ha='center');
# Example data traces: channels 10-17 over trange=(1, 2).
plt.figure(figsize=(16, 6))
ax = plt.subplot(111)
w = sw.plot_timeseries(RX, channel_ids=range(10, 18), color='k', ax=ax, trange=(1, 2))
ax.axis('off')
# Move the bottom edge of the axes up so the traces occupy the upper part of the figure.
p = ax.get_position()
p.y0 = 0.58
ax.set_position(p)
ax.set_xticks(())
# Time scale bar: the 0.1 span is labelled 100ms, so the x axis is in seconds.
ax.plot((1.01, 1.11), (-400, -400), 'k-')
ax.annotate('100ms', (1.051, -750), ha='center');
# Extend the lower y limit so the scale-bar label fits. Read the current top
# limit with the getter; the original called set_ylim() with no arguments.
ax.set_ylim((-750, ax.get_ylim()[1]))
# Number of units reported by each sorter. Start a fresh figure so the bars
# do not land on the previous figure's axes when cells run back to back.
plt.figure()
ax = plt.subplot(111)
n = [len(study.get_sorting(s).get_unit_ids()) for s in sorter_list]
ax.bar(range(len(sorter_list)), n, color='tab:blue')
ax.set_xticks(range(len(sorter_names_short)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected')
clear_axes(ax)
# Per-unit accuracy, precision and recall for each sorter. Start a fresh
# figure so the panel does not reuse the previous figure's axes.
plt.figure()
ax = plt.subplot(111)
# Shrink the axes horizontally to leave room for the legend on the right.
p = ax.get_position()
p.x1 = 0.85
ax.set_position(p)
sns.set_palette(sns.color_palette("Set1"))
# Long format: one row per (unit, metric) pair.
df = pd.melt(dataframes['perf_by_units'], id_vars='sorter_name', var_name='Metric',
             value_name='Score', value_vars=('accuracy', 'precision', 'recall'))
sns.swarmplot(data=df, x='sorter_name', y='Score', hue='Metric', dodge=True,
              order=sorter_list, s=3, ax=ax)
# Tick labels follow `order=sorter_list`, so the short names line up.
ax.set_xticklabels(sorter_names_short, rotation=30, ha='center')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False,
          fontsize=8, markerscale=0.5)
ax.set_xlabel(None);
ax.set_ylabel('Score');
clear_axes(ax)
# Counts of well-detected / false-positive / redundant / overmerged units per
# sorter. Start a fresh figure so the panel does not reuse previous axes.
plt.figure()
ax = plt.subplot(111)
# Shrink the axes horizontally to leave room for the legend on the right.
p = ax.get_position()
p.x1 = 0.85
ax.set_position(p)
df = pd.melt(dataframes['count_units'], id_vars='sorter_name', var_name='Type',
             value_name='Units',
             value_vars=('num_well_detected', 'num_false_positive',
                         'num_redundant', 'num_overmerged'))
sns.set_palette(sns.color_palette("Set1"))
sns.barplot(x='sorter_name', y='Units', hue='Type', data=df,
            order=sorter_list, ax=ax)
ax.set_xticklabels(sorter_names_short, rotation=30, ha='right')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False,
          fontsize=8, markerscale=0.1)
# Replace the raw column names in the legend with readable labels (public API
# instead of the private `ax.legend_.texts`).
legend_labels = ("Well detected", "False positive", "Redundant", "Overmerged")
for text, label in zip(ax.get_legend().get_texts(), legend_labels):
    text.set_text(label)
ax.set_xlabel(None);
clear_axes(ax)
# precision vs. recall and accuracy vs. SNR
# Top row (this cell part): precision vs. recall per sorter, one panel each.
fig = plt.figure(figsize=(14, 4))
sns.set_palette(sns.color_palette("deep"))
axesA = []
for i, s in enumerate(sorter_list):
    ax = plt.subplot(2, len(sorter_list), i + 1)
    axesA.append(ax)
    by_sorter = dataframes['perf_by_units'].loc[dataframes['perf_by_units'].sorter_name == s]
    g = sns.scatterplot(data=by_sorter, x='precision', y='recall', s=30,
                        edgecolor=None, alpha=0.1, ax=ax)
    ax.set_title(sorter_names[i])
    ax.set_aspect('equal')
    clear_axes(ax)
    ax.set_xlabel('Precision')
    ax.set_ylabel('Recall')
# Share the y limits with the first panel and label only that panel.
# `get_shared_y_axes().join` is deprecated/removed in recent matplotlib, so
# use Axes.sharey. Shared axes also share tick formatters, so hide this
# panel's labels with tick_params; set_yticklabels([]) would blank the first
# panel too.
for ax in axesA[1:]:
    ax.sharey(axesA[0])
    ax.tick_params(axis='y', labelleft=False)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.autoscale()
############### B
# Bottom row: accuracy vs. ground-truth SNR per sorter.
df = dataframes['perf_by_units']
# Add the SNR of each ground-truth unit to the by-unit table. `snr` has one
# entry per unit of SX_gt and is indexed by position elsewhere
# (snr[gt_index[i]]), so pair values with unit ids by position. The original
# used snr[k] with the unit id k, which is only right for ids 0..N-1.
if 'snr' not in df.columns:
    snr_d = dict(zip(SX_gt.get_unit_ids(), snr))
    df['snr'] = df['gt_unit_id'].map(snr_d)
axesB = []
for i, s in enumerate(sorter_list):
    ax = plt.subplot(2, len(sorter_list), len(sorter_list) + i + 1)
    axesB.append(ax)
    g = sns.scatterplot(data=df.loc[df.sorter_name == s],
                        x='snr', y='accuracy', s=30, alpha=0.2, ax=ax)
    clear_axes(ax)
    ax.set_xlabel('Ground truth SNR')
    ax.set_ylabel('Accuracy')
# Share y limits with the first panel (Axes.sharey replaces the deprecated
# get_shared_y_axes().join). Hide labels per axes via tick_params because
# shared axes share tick formatters.
for ax in axesB[1:]:
    ax.sharey(axesB[0])
    ax.tick_params(axis='y', labelleft=False)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.autoscale()
Figure 3 - comparison of sorter outputs and ensemble sorting
# All-to-all comparison of the sorter outputs, loaded from the
# multicomparison/ folder when a saved copy exists there.
sortings = [study.get_sorting(s) for s in sorter_list]
cmp_folder = study_folder / 'multicomparison/'
cmp_folder.mkdir(exist_ok=True)
if (cmp_folder / 'multicomparison.gpickle').is_file():
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(cmp_folder)
else:
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings,
                                       name_list=sorter_names_short,
                                       verbose=False, match_score=0.5)
    print('saving multicomparison')
    mcmp.dump(cmp_folder)
# Independent copy of the multicomparison graph.
mcmp_graph = mcmp.graph.copy()
# Units with no agreement (agreement count of exactly 1), and units on which
# at least 2 sorters agree.
not_in_agreement = mcmp.get_agreement_sorting(minimum_agreement_count=1,
                                              minimum_agreement_count_only=True)
in_agreement = mcmp.get_agreement_sorting(minimum_agreement_count=2)
# Score both subsets against the ground truth (no-agreement first).
cmp_no_agr, cmp_agr = (
    sc.compare_sorter_to_ground_truth(SX_gt, sorting)
    for sorting in (not_in_agreement, in_agreement)
)
# Now collect results for each sorter. Each unit of the multicomparison is
# binned by the number of sorters that found it (1..ns) and labelled 'TP'
# when every contributing sorter unit has a ground-truth match, 'FP'
# otherwise. SNRs of the matched ground-truth units are gathered for TP units.
ns = len(sorter_names_short)
results = {'TP': {}, 'FP': {}, 'SNR': {}}
for s in sorter_names_short:
    results['TP'][s] = {k: 0 for k in range(1, ns + 1)}
    results['FP'][s] = {k: 0 for k in range(1, ns + 1)}
    # One fresh list per key. `[[]] * n` would make every key share a single
    # list object, so an append to one bin would show up in all of them.
    results['SNR'][s] = {k: [] for k in range(1, ns + 1)}

# map short sorter names to the sorter keys used by the study
dict_names = dict(zip(sorter_names_short, sorter_list))

# iterate over all units gathered from subgraphs
for u, unit_info in mcmp._new_units.items():
    sorter_units = list(unit_info['sorter_unit_ids'].items())
    found_in_gt = []
    gt_index = []
    # check whether each contributing sorter unit has a ground-truth match
    for sorter_short, sorter_unit in sorter_units:
        best_match_12 = study.comparisons['rec0', dict_names[sorter_short]].best_match_12
        found_in_gt.append(sorter_unit in best_match_12.values)
        if found_in_gt[-1]:
            # positional index of the matched ground-truth unit (used to index snr)
            gt_index.append(np.where(best_match_12 == sorter_unit)[0][0])
    if len(set(gt_index)) > 1:
        print('different gt units: ', u, gt_index)

    n_found = len(found_in_gt)
    # Alternative analysis ("equal matches"): test np.sum(found_in_gt) > 0 here
    # and bin by np.sum(found_in_gt) below instead of n_found.
    if all(found_in_gt):
        key = 'TP'
    else:
        key = 'FP'
        if n_found > 1:
            print('FP unit found by >1 sorter: ', u)
    for i, (sorter_short, _) in enumerate(sorter_units):
        results[key][sorter_short][n_found] += 1
        # For TP units every sorter unit was matched, so gt_index lines up with i.
        if key == 'TP':
            results['SNR'][sorter_short][n_found].append(snr[gt_index[i]])
# Agreement summaries from spikewidgets: overall pie chart and per-sorter breakdown.
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie')
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True)

# Top row: per sorter, FP and matched units stacked by the number of sorters
# that found them. Bottom row: ground-truth SNR of matched units, same bins.
fig = plt.figure(figsize=(14, 4))
axes = []
for i, s in enumerate(results['TP'].keys()):
    ax = plt.subplot(2, len(sorter_list), i + 1)
    fp_counts = list(results['FP'][s].values())
    ax.bar(list(results['FP'][s].keys()), fp_counts,
           alpha=0.5, width=0.6, color='r', label='false positive')
    ax.bar(list(results['TP'][s].keys()), list(results['TP'][s].values()),
           bottom=fp_counts, alpha=0.5, width=0.6, color='b', label='matched')
    ax.set_xticks(range(1, len(sorter_list) + 1))
    ax.set_xticklabels(range(1, len(sorter_list) + 1))
    ax.set_title(s)
    clear_axes(ax)
    axes.append(ax)
    if i == 0:
        ax.set_ylabel('Number of units')
    else:
        # Share y limits with the first top-row panel (Axes.sharey replaces the
        # deprecated get_shared_y_axes().join). Shared axes share tick
        # formatters, so hide only this panel's labels via tick_params.
        ax.sharey(axes[0])
        ax.tick_params(axis='y', labelleft=False)

    ax = plt.subplot(2, len(sorter_list), len(sorter_list) + i + 1)
    d = results['SNR'][s]
    # one column per agreement bin; shorter bins are padded with NaN by pandas
    sns.boxenplot(data=pd.DataFrame([pd.Series(d[k]) for k in d.keys()]).T,
                  color='b', ax=ax)
    ax.set_xticks(range(0, len(sorter_list)))
    ax.set_xticklabels(range(1, len(sorter_list) + 1))
    clear_axes(ax)
    axes.append(ax)
    if i == 0:
        ax.set_ylabel('Ground truth SNR')
        ax.set_xlabel('Found by # sorters')
    else:
        # axes[1] is the first sorter's bottom-row panel
        ax.sharey(axes[1])
        ax.tick_params(axis='y', labelleft=False)
# Numbers for the figure above.
sg_names, sg_units = mcmp.compute_subgraphs()

# Agreement level of each subgraph = number of distinct sorters in it.
sorters_per_subgraph = [len(np.unique(names)) for names in sg_names]
v, c = np.unique(sorters_per_subgraph, return_counts=True)
share = np.round(100 * c / np.sum(c), 2)
df = pd.DataFrame(np.vstack((v, c, share)).T,
                  columns=('in # sorters', '# units', 'percentage'))
print('all sorters, all units:')
print(df)

# The same breakdown restricted to the subgraphs containing each sorter:
# a count column and a percentage column per sorter.
per_sorter_columns = {}
for name in sorter_names_short:
    v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn],
                     return_counts=True)
    cl = np.zeros(len(sorter_list), dtype=int)
    cl[v.astype(int) - 1] = c
    per_sorter_columns[name] = cl
    per_sorter_columns[name + '%'] = np.round(100 * cl / np.sum(cl), 1)
df = pd.DataFrame(per_sorter_columns)
print('\nper sorter:')
print(df)

# False-positive counts per agreement bin for each sorter.
for s, fp_by_bin in results['FP'].items():
    print(s, list(fp_by_bin.values()))