Ensemble sorting of a Neuropixels recording
This notebook reproduces figures 1 and 4 from the paper SpikeInterface, a unified framework for spike sorting.
The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034
The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft
(about 75GB).
Files required to run the code are:
- the raw data: sub-mouse412804_ecephys.nwb
- two manually curated sortings: sub-mouse412804_ses-20200824T155542.nwb and sub-mouse412804_ses-20200824T155543.nwb
These files should be in the same directory where the notebook is located (otherwise adjust paths below).
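If you prefer to fetch the archive from within Python rather than from a shell, a minimal sketch is shown below; it simply calls the dandi command-line tool listed under Requirements, so it assumes dandi is installed and on your PATH.
import subprocess
# call the dandi CLI to download the dandiset (about 75GB); assumes `dandi` is on the PATH
subprocess.run(["dandi", "download", "https://gui.dandiarchive.org/#/dandiset/000034/draft"], check=True)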
Author: Matthias Hennig, University of Edinburgh, 24 Aug 2020
Requirements
For this you will need the following Python packages:
- numpy
- pandas
- matplotlib
- seaborn
- spikeinterface
- dandi
- matplotlib-venn
To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.
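Before setting up the study, you can check which sorters your installation can actually execute; a quick sketch using the spikesorters helper functions:
import spikesorters as ss
# sorters that spikesorters knows about vs. those installed in this environment
print('available:', ss.available_sorters())
print('installed:', ss.installed_sorters())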
import os
# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3
import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison
%matplotlib inline
def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
# where to find the data set; adjust the file name if you use the full recording from DANDI (sub-mouse412804_ecephys.nwb)
data_file = Path('./') / 'mouse412804_probeC_15min.nwb'
# results are stored here
study_path = Path('./')
# this folder will contain all results
study_folder = study_path / 'study_15min/'
# this folder will be used as temporary space, and will hold the sortings etc.
working_folder = study_path / 'working_15min'
# sorters to use
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'tridesclous', 'spykingcircus', 'hdsort']
# pass the following parameters to the sorters
sorter_params = {
    # 'kilosort2': {'keep_good_only': True}, # uncomment this to test the native filter for false positives
    'spykingcircus': {'adjacency_radius': 50},
    'herdingspikes': {'filter': True, }}
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust','Tridesclous', 'SpykingCircus', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'TDC', 'SC', 'HDS']
# create an extractor object for the raw data
recording = se.NwbRecordingExtractor(str(data_file))
print("Number of frames: {}\nSampling rate: {}Hz\nNumber of channels: {}".format(
recording.get_num_frames(), recording.get_sampling_frequency(),
recording.get_num_channels()))
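It can also be useful to check the total duration and the probe geometry before sorting; a short sketch using the same extractor (get_channel_locations returns one x/y position per channel):
# duration in seconds and channel layout of the probe
duration = recording.get_num_frames() / recording.get_sampling_frequency()
print('Duration: {:.1f}s'.format(duration))
locations = recording.get_channel_locations()
print('First channel locations:\n', locations[:5])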
Run spike sorters and perform comparison between all outputs
# set up the study environment and run all sorters
# sorters are not re-run if outputs are found in working_folder
if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    os.mkdir(study_folder)
# run all sorters
result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list={'rec': recording}, with_output=True,
sorter_params=sorter_params, working_folder=working_folder, engine='loop',
mode='keep', verbose=True)
# store sortings in a list for quick access
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec',s])
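If you only want to try a single sorter outside the study loop, spikesorters also provides run_sorter; a hedged sketch (the output_folder below is just an example path):
# run only HerdingSpikes on the recording; output_folder is an example path
sorting_hs = ss.run_sorter('herdingspikes', recording,
                           output_folder='working_15min/herdingspikes_only',
                           verbose=True)
print('Units found:', len(sorting_hs.get_unit_ids()))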
# perform a multi-comparison, all to all sortings
# result is stored, and loaded from disk if the file is found
if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short,
                                       verbose=True)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
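The multi-comparison object can directly return a consensus sorting, i.e. the units on which a minimum number of sorters agree; this is what is used for the figures below. For example:
# units found by at least two sorters (the consensus used later in this notebook)
agreement = mcmp.get_agreement_sorting(minimum_agreement_count=2)
print('Consensus units (>=2 sorters):', len(agreement.get_unit_ids()))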
Figure 1 - comparison of sorter outputs
# activity levels on the probe
plt.figure(figsize=(16,2))
ax = plt.subplot(111)
w = sw.plot_activity_map(recording, trange=(0,20), transpose=True, ax=ax, background='w', frame=True)
ax.plot((50,150),(-30,-30),'k-')
ax.annotate('100$\\mu m$',(100,-90), ha='center');
# example data traces
plt.figure(figsize=(16,6))
ax = plt.subplot(111)
w = sw.plot_timeseries(recording, channel_ids=range(20,28), color='k', ax=ax, trange=(1,2))
ax.axis('off')
p = ax.get_position()
p.y0 = 0.55
ax.set_position(p)
ax.set_xticks(())
ax.plot((1.01,1.11),(-1790,-1790),'k-')
ax.annotate('100ms',(1.051,-2900), ha='center');
ax.set_ylim((-2900, ax.get_ylim()[1]))
# number of units detected by each sorter
plt.figure()
ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings], color='tab:blue')
ax.set_xticks(range(len(sorter_names)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected')
clear_axes(ax)
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie')
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True)
# numbers for figure above
print('number of units detected:')
for i,s in enumerate(sortings):
    print("{}: {}".format(sorter_names[i],len(s.get_unit_ids())))
sg_names, sg_units = mcmp.compute_subgraphs()
v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True)
df = pd.DataFrame(np.vstack((v,c,np.round(100*c/np.sum(c),2))).T,
columns=('in # sorters','# units','percentage'))
print('\nall sorters, all units:')
print(df)
df = pd.DataFrame()
for i, name in enumerate(sorter_names_short):
    v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True)
    df.insert(2*i,name,c)
    df.insert(2*i+1,name+'%',np.round(100*c/np.sum(c),1))
print('\nper sorter:')
print(df)
Supplemental Figure - example unit templates
# show unit templates and spike trains for two units/all sorters
sorting = mcmp.get_agreement_sorting(minimum_agreement_count=6)
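# helpers: look up the sorting object / spike train belonging to a (sorter_name, unit_id) pair from the multicomparison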
get_sorting = lambda u: [mcmp.sorting_list[i] for i,n in enumerate(mcmp.name_list) if n==u[0]][0]
get_spikes = lambda u: [mcmp.sorting_list[i].get_unit_spike_train(u[1]) for i,n in enumerate(mcmp.name_list) if n==u[0]][0]
# one well matched and one not so well matched unit, all sorters
show_units = [2,17]
for i,unit in enumerate(show_units):
    fig = plt.figure(figsize=(16, 2))
    ax = plt.subplot(111)
    ax.set_title('Average agreement: {:.2f}'.format(sorting.get_unit_property(sorting.get_unit_ids()[unit],'avg_agreement')))
    units = sorting.get_unit_property(sorting.get_unit_ids()[unit], 'sorter_unit_ids')
    cols = plt.cm.Accent(np.arange(len(units))/len(units))
    for j,u in enumerate(dict(sorted(units.items())).items()):
        s = get_sorting(u).get_units_spike_train((u[1],))[0]
        s = s[s<20*get_sorting(u).get_sampling_frequency()]
        ax.plot(s/get_sorting(u).get_sampling_frequency(), np.ones(len(s))*j, '|', color=cols[j], label=u[0])
    ax.set_frame_on(False)
    ax.set_xticks(())
    ax.set_yticks(())
    ax.plot((0,1),(-1,-1),'k')
    ax.annotate('1s',(0.5,-1.75), ha='center')
    ax.set_ylim((-2,len(units)+1))

    fig = plt.figure(figsize=(16, 2))
    units = sorting.get_unit_property(sorting.get_unit_ids()[unit], 'sorter_unit_ids')
    print(units)
    print('Agreement: {}'.format(sorting.get_unit_property(sorting.get_unit_ids()[unit],'avg_agreement')))
    cols = plt.cm.Accent(np.arange(len(units))/len(units))
    for j,u in enumerate(dict(sorted(units.items())).items()):
        ax = plt.subplot(1, len(sorter_list), j+1)
        w = sw.plot_unit_templates(recording, get_sorting(u), unit_ids=(u[1],), max_spikes_per_unit=10,
                                   channel_locs=True, radius=75, show_all_channels=False, color=[cols[j]],
                                   lw=1.5, ax=ax, plot_channels=False, set_title=False, axis_equal=True)
        # was 100 spikes in original plot
        ax.set_title(u[0])
Figure 4 - comparison between ensemble sortings and curated data
# perform a comparison with curated sortings (KS2)
curated1 = se.NwbSortingExtractor('sub-mouse412804_ses-20200824T155542.nwb', sampling_frequency=30000)
curated2 = se.NwbSortingExtractor('sub-mouse412804_ses-20200824T155543.nwb', sampling_frequency=30000)
comparison_curated = sc.compare_two_sorters(curated1, curated2)
comparison_curated_ks = sc.compare_multiple_sorters((curated1, curated2, sortings[sorter_list.index('kilosort2')]))
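As a quick sanity check, the pairwise comparison object can report how many units of one curated sorting have a match in the other, using the same mapped-sorting accessors employed for the counts below:
# number of curated1 units with a match in curated2 (unmatched units map to -1)
mapped = comparison_curated.get_mapped_sorting1().get_mapped_unit_ids()
n_matched = np.sum([u != -1 for u in mapped])
print('Curated 1 units matched in Curated 2: {} / {}'.format(n_matched, len(curated1.get_unit_ids())))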
# consensus sortings (units where at least 2 sorters agree)
sorting = mcmp.get_agreement_sorting(minimum_agreement_count=2)
consensus_sortings = []
units_dict = defaultdict(list)
units = [sorting.get_unit_property(u,'sorter_unit_ids') for u in sorting.get_unit_ids()]
for au in units:
    for u in au.items():
        units_dict[u[0]].append(u[1])
for i,s in enumerate(sorter_names_short):
    consensus_sortings.append(se.SubSortingExtractor(sortings[i], unit_ids=units_dict[s]))
# orphan units (units found by only one sorter)
sorting = mcmp.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
unmatched_sortings = []
units_dict = defaultdict(list)
units = [sorting.get_unit_property(u,'sorter_unit_ids') for u in sorting.get_unit_ids()]
for au in units:
    for u in au.items():
        units_dict[u[0]].append(u[1])
for i,s in enumerate(sorter_names_short):
    unmatched_sortings.append(se.SubSortingExtractor(sortings[i], unit_ids=units_dict[s]))
consensus_curated_comparisons = []
for s in consensus_sortings:
    consensus_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    consensus_curated_comparisons.append(sc.compare_two_sorters(s, curated2))
unmatched_curated_comparisons = []
for s in unmatched_sortings:
    unmatched_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    unmatched_curated_comparisons.append(sc.compare_two_sorters(s, curated2))
all_curated_comparisons = []
for s in sortings:
    all_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    all_curated_comparisons.append(sc.compare_two_sorters(s, curated2))
# count various types of units
count_mapped = lambda x : np.sum([u!=-1 for u in x.get_mapped_unit_ids()])
count_not_mapped = lambda x : np.sum([u==-1 for u in x.get_mapped_unit_ids()])
count_units = lambda x : len(x.get_unit_ids())
n_consensus_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in consensus_curated_comparisons]).reshape((len(sorter_list),2))
n_consensus_curated_unmapped = np.array([count_not_mapped(c.get_mapped_sorting1()) for c in consensus_curated_comparisons]).reshape((len(sorter_list),2))
n_unmatched_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in unmatched_curated_comparisons]).reshape((len(sorter_list),2))
n_all_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_all_curated_unmapped = np.array([count_not_mapped(c.get_mapped_sorting1()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_curated_all_unmapped = np.array([count_not_mapped(c.get_mapped_sorting2()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_all = np.array([count_units(s) for s in sortings])
n_consensus = np.array([count_units(s) for s in consensus_sortings])
n_unmatched = np.array([count_units(s) for s in unmatched_sortings])
n_curated1 = len(curated1.get_unit_ids())
n_curated2 = len(curated2.get_unit_ids())
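A compact overview of the counts computed above (a small addition for convenience, using only the variables defined in this notebook):
# unit counts per sorter: all, consensus (>=2 sorters) and non-consensus (single sorter)
df_counts = pd.DataFrame({'all': n_all, 'consensus': n_consensus, 'non-consensus': n_unmatched},
                         index=sorter_names_short)
print(df_counts)
print('curated 1: {} units, curated 2: {} units'.format(n_curated1, n_curated2))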
# overlap between the two manually curated sortings and the Kilosort2 sorting they were derived from
i = {}
for k in ['{0:03b}'.format(v) for v in range(1,2**3)]:
    i[k] = 0
i['111'] = len(comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids())
s = comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=2, minimum_agreement_count_only=True)
units = [s.get_unit_property(u,'sorter_unit_ids').keys() for u in s.get_unit_ids()]
for u in units:
    if 'sorting1' in u and 'sorting2' in u:
        i['110'] += 1
    if 'sorting1' in u and 'sorting3' in u:
        i['101'] += 1
    if 'sorting2' in u and 'sorting3' in u:
        i['011'] += 1
s = comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
units = [s.get_unit_property(u,'sorter_unit_ids').keys() for u in s.get_unit_ids()]
for u in units:
    if 'sorting1' in u:
        i['100'] += 1
    if 'sorting2' in u:
        i['010'] += 1
    if 'sorting3' in u:
        i['001'] += 1
colors = plt.cm.RdYlBu(np.linspace(0,1,3))
venn3(subsets = i,set_labels=('Curated 1', 'Curated 2', 'Kilosort2'),
set_colors=colors, alpha=0.6, normalize_to=100)
# overlaps between ensemble sortings (per sorter) and manually curated sortings
def plot_mcmp_results(data, labels, ax, ylim=None, yticks=None, legend=False):
    angles = (np.linspace(0, 2*np.pi, len(sorter_list), endpoint=False)).tolist()
    angles += angles[:1]
    for i,v in enumerate(data):
        v = v.tolist() + v[:1].tolist()
        ax.bar(np.array(angles)+i*2*np.pi/len(sorter_list)/len(data)/2-2*np.pi/len(sorter_list)/len(data)/4,
               v, label=labels[i],
               alpha=0.8, width=np.pi/len(sorter_list)/2)
    ax.set_thetagrids(np.degrees(angles), sorter_names_short)
    if legend:
        ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.25)
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    if ylim is not None:
        ax.set_ylim(ylim)
    if yticks is not None:
        ax.set_yticks(yticks)
plt.figure(figsize=(14,3))
sns.set_palette(sns.color_palette("Set1"))
ax = plt.subplot(131, projection='polar')
plot_mcmp_results((n_all_curated_mapped[:,0]/n_all*100,
n_all_curated_mapped[:,1]/n_all*100),
('Curated 1','Curated 2'), ax, yticks=np.arange(20,101,20))
ax.set_title('Percent all units\nwith match in curated sets',pad=20);
plt.ylim((0,100))
ax = plt.subplot(132, projection='polar')
plot_mcmp_results((n_consensus_curated_mapped[:,0]/n_consensus*100,
n_consensus_curated_mapped[:,1]/n_consensus*100),
('Curated 1','Curated 2'), ax, yticks=np.arange(20,101,20))
ax.set_title('Percent consensus units\nwith match in curated sets',pad=20);
plt.ylim((0,100))
ax = plt.subplot(133, projection='polar')
plot_mcmp_results((n_unmatched_curated_mapped[:,0]/n_unmatched*100,
n_unmatched_curated_mapped[:,1]/n_unmatched*100),
('Curated 1','Curated 2'), ax, ylim=(0,30), yticks=np.arange(10,21,10), legend=True)
ax.set_title('Percent non-consensus units\nwith match in curated sets',pad=20);
# numbers for figure above
df = pd.DataFrame(np.vstack((n_all_curated_mapped[:,0]/n_all*100, n_all_curated_mapped[:,1]/n_all*100,
n_all_curated_mapped[:,0], n_all_curated_mapped[:,1])).T,
columns = ('C1 %', 'C2 %', 'C1', 'C2'), index=sorter_names_short)
print('Percent all units with match in curated sets')
print(df)
df = pd.DataFrame(np.vstack((n_consensus_curated_mapped[:,0]/n_consensus*100, n_consensus_curated_mapped[:,1]/n_consensus*100,
n_consensus_curated_mapped[:,0],n_consensus_curated_mapped[:,1])).T,
columns = ('C1 %', 'C2 %', 'C1', 'C2'), index=sorter_names_short)
print('\nPercent consensus units with match in curated sets')
print(df)
df = pd.DataFrame(np.vstack((n_unmatched_curated_mapped[:,0]/n_unmatched*100,
n_unmatched_curated_mapped[:,1]/n_unmatched*100,
n_unmatched_curated_mapped[:,0],n_unmatched_curated_mapped[:,1])).T,
columns = ('C1 %', 'C2 %', 'C1', 'C2'), index=sorter_names_short)
print('\nPercent non-consensus units with match in curated sets')
print(df)