Ensemble sorting of a Neuropixels recording 2

This notebook reproduces supplemental figure S2 from the paper "SpikeInterface, a unified framework for spike sorting".

The recording was made by André Marques-Smith in the lab of Adam Kampff. Reference:

Marques-Smith, A., Neto, J.P., Lopes, G., Nogueira, J., Calcaterra, L., Frazão, J., Kim, D., Phillips, M., Dimitriadis, G., Kampff, A.R. (2018). Recording from the same neuron with high-density CMOS probes and patch-clamp: a ground-truth dataset and an experiment in collaboration. bioRxiv 370080; doi: https://doi.org/10.1101/370080

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

File required to run the code:

  • sub-c1_ecephys.nwb (the recording loaded as data_file below)

This file should be in the same directory as the notebook (otherwise adjust the paths below).

Author: Matthias Hennig, University of Edinburgh, 25 Aug 2020


To run this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • matplotlib-venn
  • seaborn
  • spikeinterface
  • dandi

To run the MATLAB-based sorters, you will also need a MATLAB license. For the other sorters, please refer to the SpikeInterface documentation on how to install sorters.
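Before running the cells below, it can help to check which of the listed packages are importable in the current environment. The following is a minimal sketch using only the standard library; the helper name `missing_packages` is hypothetical, and the list simply mirrors the top-level module names used in this notebook.

```python
from importlib.util import find_spec

def missing_packages(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]

# top-level modules used in this notebook (matplotlib_venn is imported in the first cell)
required = ["numpy", "pandas", "matplotlib", "matplotlib_venn",
            "seaborn", "spikeinterface", "dandi"]

missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages were found.")
```

This only checks that a module can be located; it does not verify versions or that the individual sorters are installed.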

In [1]:
import os

# MATLAB sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison

%matplotlib inline

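def clear_axes(ax):
    # assumed body (not shown in the source): hide the top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# print version information
# (assumed calls, inferred from the output below)
si.print_spikeinterface_version()
ss.print_sorter_versions()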
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0
In [2]:
# the recording file, downloaded from Dandi in NWB:N format 
data_file =  './sub-c1_ecephys.nwb'

# paths for spike sorter outputs
p = Path('./')

# select spike sorters to be used, note Kilosort2 requires an NVIDIA GPU to run
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'tridesclous', 'spykingcircus', 'hdsort']
sorter_params = {
#     'kilosort2': {'keep_good_only': True}, # removes good units!
    'mountainsort4': {'adjacency_radius': 50},
    'spyking_circus': {'adjacency_radius': 50},
    'herdingspikes': {'filter': True},
}
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust','Tridesclous', 'SpykingCircus', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'TDC', 'SC', 'HDS']

# create a recording extractor, this gives access to the raw data in the NWB:N file
recording = se.NwbRecordingExtractor(data_file)

# NWB:N files store the data in (channels:time) order, but for spike sorting the transposed format is much
# more efficient. Therefore here we can create a CacheRecordingExtractor that re-writes the data
# as a binary file in the desired order. This will take some time, but speeds up subsequent steps:
# recording = se.CacheRecordingExtractor(recording)

# print some info
print("Sampling rate: {}Hz".format(recording.get_sampling_frequency()))
print("Duration: {}s".format(recording.get_num_frames()/recording.get_sampling_frequency()))
print("Number of channels: {}".format(recording.get_num_channels()))
Sampling rate: 30000.0Hz
Duration: 270.01123333333334s
Number of channels: 384
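The comment in the cell above notes that the transposed storage order is more efficient for spike sorting. One way to see why: sorters typically read short time chunks across all channels, and the number of separate contiguous reads this takes depends on the layout. The sketch below is a toy illustration in plain Python, not part of SpikeInterface; the function names and the labels "channels_first" and "time_first" are made up for this example.

```python
def flat_index(channel, frame, n_channels, n_frames, order):
    """Position of one sample in a flat file for a given storage order."""
    if order == "channels_first":   # (channels, time): one channel after another
        return channel * n_frames + frame
    if order == "time_first":       # (time, channels): one frame after another
        return frame * n_channels + channel
    raise ValueError(order)

def count_contiguous_runs(n_channels, n_frames, start, stop, order):
    """Number of contiguous reads needed to fetch frames [start, stop) on all channels."""
    idx = sorted(flat_index(c, f, n_channels, n_frames, order)
                 for c in range(n_channels) for f in range(start, stop))
    # a new run starts wherever two consecutive positions are not adjacent
    return 1 + sum(1 for a, b in zip(idx, idx[1:]) if b != a + 1)

# toy size: 384 channels, 3000 frames (0.1 s at 30 kHz), reading a 10 ms chunk
n_channels, n_frames = 384, 3000
runs_channels_first = count_contiguous_runs(n_channels, n_frames, 600, 900, "channels_first")
runs_time_first = count_contiguous_runs(n_channels, n_frames, 600, 900, "time_first")
print(runs_channels_first, runs_time_first)
```

In the channels-first layout every channel contributes its own separate run, whereas in the time-first layout the whole chunk is a single contiguous block.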

Run spike sorters and perform comparison between all outputs

In [3]:
# now create the study environment and run all spike sorters
# note that this function will not re-run a spike sorter if the sorting is already present in
# the working folder

study_folder = p / 'study/'
working_folder = p / 'working/'
if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    study_folder.mkdir(parents=True)
rec_dict = {'rec': recording}

result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list=rec_dict, with_output=True,
                             sorter_params=sorter_params, working_folder=working_folder, engine='loop', 
                             mode='keep', verbose=True)

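# when done, load all sortings into a handy list
# (results are assumed to be keyed by (recording name, sorter name))
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec', s])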
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/neuropixels_ms/working/rec/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
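The "keep" behaviour described in the cell above (a sorter is not re-run if its output is already present in the working folder) follows a common caching pattern. Below is a minimal sketch of that pattern, not the spikesorters implementation; `run_or_reuse` and `fake_sorter` are hypothetical names, and the folder layout (recording name, then sorter name) is modelled on the path visible in the warning above.

```python
from pathlib import Path
import tempfile

def run_or_reuse(output_folder, run_fn):
    """Run `run_fn(output_folder)` only if `output_folder` does not exist yet."""
    output_folder = Path(output_folder)
    if output_folder.is_dir():
        return "reused"              # output already present: skip the expensive step
    output_folder.mkdir(parents=True)
    run_fn(output_folder)
    return "ran"

def fake_sorter(folder):
    """Stand-in for a spike sorter: just writes a marker file."""
    (folder / "result.txt").write_text("done")

with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp) / "rec" / "herdingspikes"
    first = run_or_reuse(folder, fake_sorter)    # folder is missing, so the sorter runs
    second = run_or_reuse(folder, fake_sorter)   # folder exists now, so it is reused
    print(first, second)
```

A consequence of this design is that a failed or interrupted run can leave a folder behind; deleting that folder forces the sorter to run again.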
In [4]:
# run a multi-sorting comparison, an all-to-all comparison
# results are saved and just loaded from a file if this exists

if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
loading multicomparison
Multicomaprison step 3: clean graph
Removed edge ('HDS_25017', 'KS_231', {'weight': 0.5272206303724928})
Removed edge ('IC_66', 'TDC_28', {'weight': 0.5032594524119948})
Removed edge ('KS_231', 'HS_9', {'weight': 0.5059360730593607})
Removed edge ('SC_119', 'TDC_28', {'weight': 0.5013054830287206})
Removed edge ('KS_71', 'HS_57', {'weight': 0.5238095238095238})
Removed 5 duplicate nodes
Multicomaprison step 4: extract agreement from graph
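The edge weights in the output above (for example 0.527 between HDS_25017 and KS_231) appear to be pairwise agreement scores between units from different sorters. In SpikeInterface the agreement score for two spike trains is the number of matched spikes divided by the number of spikes in either train. The sketch below illustrates that formula with a simple greedy matching in plain Python; it is not the library's matching code, and the tolerance of 12 frames (0.4 ms at 30 kHz) is an assumption.

```python
def count_matches(train_a, train_b, delta_frames):
    """Greedy one-to-one matching of two sorted spike trains (frame indices)."""
    i = j = matches = 0
    while i < len(train_a) and j < len(train_b):
        diff = train_a[i] - train_b[j]
        if abs(diff) <= delta_frames:
            matches += 1             # spikes coincide within the tolerance
            i += 1
            j += 1
        elif diff < 0:
            i += 1                   # spike in train_a has no partner, move on
        else:
            j += 1                   # spike in train_b has no partner, move on
    return matches

def agreement_score(train_a, train_b, delta_frames=12):
    """Matches divided by the size of the union of both spike trains."""
    n_match = count_matches(train_a, train_b, delta_frames)
    denom = len(train_a) + len(train_b) - n_match
    return n_match / denom if denom else 0.0

# made-up spike times in frames, for illustration only
unit_a = [100, 500, 900, 1500]
unit_b = [103, 498, 1200, 1502, 2000]
score = agreement_score(unit_a, unit_b)
print(score)
```

Here three spikes match out of six distinct events, so the two toy units would sit right at a score of 0.5, similar to the weights of the removed edges above.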
In [5]:
# plot an activity map
# the method uses a rather simple (and slow) threshold spike detection

ax = plt.subplot(111)
w = sw.plot_activity_map(recording, transpose=True, ax=ax, background='w', frame=True)
ax.annotate('100$\\mu m$',(100,-115), ha='center');
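For readers unfamiliar with threshold spike detection, the idea can be sketched on a single channel as follows. This is an illustrative stand-alone example, not the detection code used by spikewidgets; the function name, the threshold of 5 robust standard deviations and the dead time of 30 frames are all assumptions.

```python
import statistics

def detect_threshold_crossings(trace, n_mad=5.0, dead_frames=30):
    """Return frames where the trace crosses a negative threshold from above."""
    med = statistics.median(trace)
    mad = statistics.median(abs(x - med) for x in trace)
    noise = mad / 0.6745                     # robust estimate of the noise SD
    threshold = med - n_mad * noise
    spikes, last = [], -dead_frames
    for frame in range(1, len(trace)):
        crossed = trace[frame - 1] >= threshold > trace[frame]
        if crossed and frame - last >= dead_frames:
            spikes.append(frame)             # keep the first sample below threshold
            last = frame
    return spikes

# deterministic toy trace: small "noise" in [-5, 5] with two large negative deflections
trace = [((i * 7) % 11) - 5 for i in range(1000)]
trace[200], trace[201] = -60, -40            # one spike spanning two samples
trace[600] = -80                             # a second spike
detected = detect_threshold_crossings(trace)
print(detected)
```

An activity map can then be thought of as the number of detected events per channel, arranged according to the probe geometry.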
In [6]:
# raw data traces

ax = plt.subplot(111)
w = sw.plot_timeseries(recording, channel_ids=range(160,168), color='k', ax=ax, trange=(7,8))
ax.annotate('100ms',(7.051,-190), ha='center');
In [7]:
# number of units found by each sorter

ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings])
ax.set_xticks(range(len(sortings)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
Text(0, 0.5, 'Units detected')
In [8]:
# spikewidgets provides handy widgets to plot summary statistics of the comparison

# show the number of units agreed upon by k sorters, in aggregate
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie', ax=ax)

# show the number of units agreed upon by k sorters, per sorter
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True, ax=ax)
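The two plots above summarise how many units are agreed upon by k sorters, first in aggregate and then per sorter. The counting behind such summaries can be sketched with toy data as below. The mapping `unit_to_sorters` and both helper functions are hypothetical and only illustrate the bookkeeping; they do not use the multi-comparison object.

```python
from collections import Counter, defaultdict

# hypothetical toy data: each agreed-upon unit mapped to the sorters that found it
unit_to_sorters = {
    "u0": {"HS", "KS", "IC"},
    "u1": {"KS"},
    "u2": {"HS", "TDC"},
    "u3": {"HS", "KS", "IC", "TDC", "SC", "HDS"},
    "u4": {"SC"},
}

def agreement_histogram(unit_to_sorters):
    """Count units by the number k of sorters that agree on them (aggregate view)."""
    return Counter(len(sorters) for sorters in unit_to_sorters.values())

def agreement_by_sorter(unit_to_sorters):
    """For each sorter, count its units by the number k of agreeing sorters."""
    per_sorter = defaultdict(Counter)
    for sorters in unit_to_sorters.values():
        for name in sorters:
            per_sorter[name][len(sorters)] += 1
    return per_sorter

hist = agreement_histogram(unit_to_sorters)
per_sorter = agreement_by_sorter(unit_to_sorters)
print(dict(hist))
print({name: dict(counts) for name, counts in sorted(per_sorter.items())})
```

Units with k = 1 are found by a single sorter only, while units with high k are the ones most sorters agree on.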