Quick benchmarck with new API and new sorters (april 2021)

Quick benchmark with new spikeinetrface API with new sorters

In spring 2021 the spikeinterface is refactored deeply.

During this refactoring some sorters have been added.

Here quick benchmark with one simulated dataset with MEArec.

In [7]:
%matplotlib inline
In [8]:
from pathlib import Path
import os
import shutil
from pprint import pprint
import getpass


import numpy as np
import matplotlib.pyplot as plt

import MEArec as mr
import neo
import quantities as pq


import spikeinterface.extractors  as se
import spikeinterface.widgets  as sw
import spikeinterface.sorters  as ss

from spikeinterface.comparison import GroundTruthStudy
In [9]:
basedir = '/mnt/data/sam/DataSpikeSorting/'

basedir = Path(basedir)

workdir = basedir / 'mearec_bench_2021'

study_folder = workdir /'study_mearec_march_2021'

tmp_folder = workdir / 'tmp'
tmp_folder.mkdir(parents=True, exist_ok=True)

generate recording with mearec

In [ ]:
template_filename = workdir / 'templates_Neuronexus-32_100.h5'
probe = 'Neuronexus-32'
n_cell = 15
duration = 10 * 60.

recording_filename = workdir /  f'recordings_{n_cell}cells_{probe}_{duration:0.0f}s.h5'


fs = 30000.


#~ spgen = mr.SpikeTrainGenerator()
rec_params = mr.get_default_recordings_params()

rec_params['recordings']['fs'] = fs
rec_params['recordings']['sync_rate'] = None
rec_params['recordings']['sync_jitter'] = 5
rec_params['recordings']['noise_level'] = 5
rec_params['recordings']['filter'] = False
rec_params['recordings']['chunk_duration'] = 10.
rec_params['spiketrains']['duration'] = duration
rec_params['spiketrains']['n_exc'] = n_cell
rec_params['spiketrains']['n_inh'] = 0
rec_params['templates']['n_overlap_pairs'] = None
rec_params['templates']['min_dist'] = 0

recgen = mr.gen_recordings(params=rec_params, #spgen=spgen, 
            templates=template_filename, verbose=True,
            n_jobs=1, tmp_mode='memmap',
            tmp_folder=str(tmp_folder))

mr.save_recording_generator(recgen, filename=recording_filename)

set sorter path

In [3]:
user = getpass.getuser()

kilosort_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/KiloSort1'
ss.KilosortSorter.set_kilosort_path(kilosort_path)

kilosort2_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/Kilosort2'
ss.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

kilosort2_5_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/Kilosort2.5'
ss.Kilosort2_5Sorter.set_kilosort2_5_path(kilosort2_path)

kilosort3_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/Kilosort3'
ss.Kilosort3Sorter.set_kilosort3_path(kilosort3_path)

ironclust_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/ironclust/'
ss.IronClustSorter.set_ironclust_path(ironclust_path)
Setting KILOSORT_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/KiloSort1
Setting KILOSORT2_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/Kilosort2
Setting KILOSORT2_5_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/Kilosort2
Setting KILOSORT3_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/Kilosort3
Setting IRONCLUST_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/ironclust

create study

In [6]:
mearec_filename = workdir / 'recordings_15cells_Neuronexus-32_600s.h5'

if study_folder.is_dir():
    shutil.rmtree(study_folder)

rec  = se.MEArecRecordingExtractor(mearec_filename)
sorting_gt = se.MEArecSortingExtractor(mearec_filename)
print(rec)
print(sorting_gt)

gt_dict = {'rec0' : (rec, sorting_gt) }

study = GroundTruthStudy.create(study_folder, gt_dict)
MEArecRecordingExtractor: 32 channels - 1 segments - 30.0kHz
  file_path: /mnt/data/sam/DataSpikeSorting/mearec_bench_2021/recordings_15cells_Neuronexus-32_600s.h5
MEArecSortingExtractor: 15 units - 1 segments - 30.0kHz
  file_path: /mnt/data/sam/DataSpikeSorting/mearec_bench_2021/recordings_15cells_Neuronexus-32_600s.h5
write_binary_recording with n_jobs 1  chunk_size None

plot probe

In [14]:
study = GroundTruthStudy(study_folder)
rec = study.get_recording()
probe = rec.get_probe()
print(probe)
from probeinterface.plotting import plot_probe
plot_probe(probe)
Probe - 32ch
Out[14]:
(<matplotlib.collections.PolyCollection at 0x7f93854cc370>,
 <matplotlib.collections.PolyCollection at 0x7f947882e7c0>)

run sorters

In [ ]:
sorter_list = ['spykingcircus', 'kilosort2', 'kilosort3', 'tridesclous']
study = GroundTruthStudy(study_folder)
study.run_sorters(sorter_list, mode_if_folder_exists='overwrite', verbose=False)
study.copy_sortings()

collect results

In [4]:
study = GroundTruthStudy(study_folder)
study.copy_sortings()


study.run_comparisons(exhaustive_gt=True, delta_time=1.5)


comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
In [10]:
for (rec_name, sorter_name), comp in comparisons.items():
    print()
    print('*'*20)
    print(rec_name, sorter_name)
    print(comp.count_score)
********************
rec0 spykingcircus
              tp    fn    fp num_gt num_tested tested_id
gt_unit_id                                              
#0             0  2772     0   2772          0        -1
#1          2305     0  2127   2305       4432         0
#2             0  3009     0   3009          0        -1
#3             0  2503     0   2503          0        -1
#4          3135     0     4   3135       3139         2
#5             0  2081     0   2081          0        -1
#6          2192     0     2   2192       2194         5
#7          2723     0    55   2723       2778         3
#8             0  3453     0   3453          0        -1
#9             0  2334     0   2334          0        -1
#10         2280    15     8   2295       2288        11
#11         2588     8    12   2596       2600        10
#12         2721   333  1503   3054       4224         8
#13            0  3020     0   3020          0        -1
#14         3612     0  1070   3612       4682         6

********************
rec0 kilosort2
              tp   fn  fp num_gt num_tested tested_id
gt_unit_id                                           
#0          2765    7   6   2772       2771        29
#1          2299    6   0   2305       2299         8
#2          3008    1   0   3009       3008        19
#3          2502    1   2   2503       2504        25
#4          3117   18   0   3135       3117        10
#5          2076    5   1   2081       2077         7
#6          2188    4   0   2192       2188         3
#7          2717    6   0   2723       2717        26
#8          3447    6   0   3453       3447         4
#9          2323   11   5   2334       2328         6
#10         2112  183  54   2295       2166        31
#11         2592    4   0   2596       2592        11
#12         3051    3   0   3054       3051        14
#13         3019    1   0   3020       3019         1
#14         3603    9   0   3612       3603        22

********************
rec0 tridesclous
              tp  fn  fp num_gt num_tested tested_id
gt_unit_id                                          
#0          2727  45  22   2772       2749        14
#1          2294  11   0   2305       2294         4
#2          3003   6   1   3009       3004         1
#3          2467  36  20   2503       2487         9
#4          3123  12   9   3135       3132        13
#5          2047  34   6   2081       2053        10
#6          2159  33  12   2192       2171         7
#7          2695  28   0   2723       2695         6
#8          3420  33   1   3453       3421         5
#9          2293  41  63   2334       2356        12
#10         2230  65  24   2295       2254         3
#11         2532  64  18   2596       2550         2
#12         3023  31  21   3054       3044         0
#13         2979  41  10   3020       2989         8
#14         3588  24  12   3612       3600        11

********************
rec0 kilosort3
              tp    fn  fp num_gt num_tested tested_id
gt_unit_id                                            
#0          2734    38  12   2772       2746         3
#1          2302     3   0   2305       2302        29
#2          3005     4   2   3009       3007        77
#3          2450    53  96   2503       2546        74
#4          2906   229  26   3135       2932         7
#5          2067    14  42   2081       2109         2
#6          1381   811  56   2192       1437        14
#7          2712    11   2   2723       2714        76
#8          3447     6   0   3453       3447         0
#9          2288    46   3   2334       2291         1
#10         1424   871  52   2295       1476        35
#11            0  2596   0   2596          0        -1
#12         3041    13   0   3054       3041        23
#13         1580  1440   0   3020       1580        11
#14         3573    39  97   3612       3670        32

Agreement matrix

In [11]:
for (rec_name, sorter_name), comp in comparisons.items():
    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax)
    fig.suptitle(rec_name+'   '+ sorter_name)

Accuracy vs SNR

In [ ]:
 

Compare old vs new spikeinterface API

Compare "old" vs "new " spikeinterface API

Author : Samuel Garcia 29 March 2021

In spring 2021, the spikeinterface team plan a "big refactoring" of the spikeinterface tool suite.

Main changes are:

  • use neo as much as possible for extractors
  • handle multi segment
  • improve performance (pre and post processing)
  • add A WaveformExtractor class

Here I will benchmark 2 aspects of the "new API":

  • filter with 10 workers on a multi core machine
  • extractor waveform 1 worker vs 10 workers

The becnhmark is done a 10 min spikeglx file with 384 channels.

The sorting is done with kilosort3.

My machine is Intel(R) Xeon(R) Silver 4210 CPU @ 2.20GHz 2 CPU with 20 core each.

In [5]:
from pathlib import Path
import shutil
import time
import matplotlib.pyplot as plt

base_folder = Path('/mnt/data/sam/DataSpikeSorting/eduarda_arthur') 
data_folder = base_folder / 'raw_awake'

Filter with OLD API

Here we :

  1. open the file
  2. lazy filter
  3. cache it
  4. dump to json

The "cache" step is in fact the "compute and save" step.

In [6]:
import spikeextractors as se
import spiketoolkit as st

print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)

# step 1: open
file_path = data_folder / 'raw_awake_01_g0_t0.imec0.ap.bin'
recording = se.SpikeGLXRecordingExtractor(file_path)

# step 2: lazy filter
rec_filtered = st.preprocessing.bandpass_filter(recording,  freq_min=300. freq_max=6000.)
print(rec_filtered)

save_folder = base_folder / 'raw_awake_filtered_old'
if save_folder.is_dir():
    shutil.rmtree(save_folder)
save_folder.mkdir()

save_file = save_folder / 'filetred_recording.dat'
dump_file = save_folder / 'filetred_recording.json'

# step 3: cache
t0 = time.perf_counter()
cached = se.CacheRecordingExtractor(rec_filtered, chunk_mb=50, n_jobs=10, 
    save_path=save_file)
t1 = time.perf_counter()
run_time_filter_old = t1-t0
print('Old spikeextractors cache', run_time_filter_old)

# step : dump
cached.dump_to_json(dump_file)
spikeextractors version 0.9.5
spiketoolkit version 0.7.4
<spiketoolkit.preprocessing.bandpass_filter.BandpassFilterRecording object at 0x7f648d3ee130>
Old spikeextractors cache 801.9439885600004

Filter with NEW API

Here we :

  1. open the file
  2. lazy filter
  3. save it

The "save" step is in fact the "compute and save" step.

In [7]:
 
import spikeinterface as si

import spikeinterface.extractors as se
import spikeinterface.toolkit as st
print('spikeinterface version', si.__version__)

# step 1: open
recording = se.SpikeGLXRecordingExtractor(data_folder)
print(recording)

# step 2: lazy filter
rec_filtered =st.bandpass_filter(recording,  freq_min=300., freq_max=6000.)
print(rec_filtered)


filter_path = base_folder / 'raw_awake_filtered'
if filter_path.is_dir():
    shutil.rmtree(filter_path)

# step 3 : compute and save with 10 workers
t0 = time.perf_counter()
cached = rec_filtered.save(folder=filter_path,
    format='binary', dtype='int16',
    n_jobs=10,  total_memory="50M", progress_bar=True)
t1 = time.perf_counter()
run_time_filter_new = t1 -t0
print('New spikeinterface filter + save binary', run_time_filter_new)
spikeinterface version 0.90.0
SpikeGLXRecordingExtractor: 385 channels - 1 segments - 30.0kHz
BandpassFilterRecording: 385 channels - 1 segments - 30.0kHz
write_binary_recording with n_jobs 10  chunk_size 3246
write_binary_recording: 100%|██████████| 5546/5546 [00:51<00:00, 108.39it/s]
New spikeinterface filter + save binary 54.79437772196252

Extract waveform with OLD API

Here we use get_unit_waveforms from toolkit.

We do the computation with 1 and then 10 jobs.

In [21]:
from spikeextractors.baseextractor import BaseExtractor
import spikeextractors as se
import spiketoolkit as st
print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)
spikeextractors version 0.9.5
spiketoolkit version 0.7.4
In [24]:
save_folder = base_folder / 'raw_awake_filtered_old'
dump_file = save_folder / 'filetred_recording.json'
recording = BaseExtractor.load_extractor_from_json(dump_file)

sorting_KS3 = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
waveform_folder = base_folder / 'waveforms_extractor_old_1_job'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
waveform_folder.mkdir()
sorting_KS3.set_tmp_folder(waveform_folder)

t0 = time.perf_counter()
wf, indexes, channels = st.postprocessing.get_unit_waveforms(recording, sorting_KS3,
            max_spikes_per_unit=500, return_idxs=True, chunk_mb=50, n_jobs=1,
            memmap=True)
t1 = time.perf_counter()
run_time_waveform_old_1jobs = t1 - t0
print('OLD API get_unit_waveforms 1 jobs', run_time_waveform_old_1jobs)
OLD API get_unit_waveforms 1 jobs 513.5964983040467
In [30]:
save_folder = base_folder / 'raw_awake_filtered_old'
dump_file = save_folder / 'filetred_recording.json'
recording = BaseExtractor.load_extractor_from_json(dump_file)

sorting_KS3_bis = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
waveform_folder = base_folder / 'waveforms_extractor_old_10_jobs_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
waveform_folder.mkdir()
sorting_KS3_bis.set_tmp_folder(waveform_folder)

t0 = time.perf_counter()
wf, indexes, channels = st.postprocessing.get_unit_waveforms(recording, sorting_KS3_bis,
            max_spikes_per_unit=500, return_idxs=True, chunk_mb=500, n_jobs=10,
            memmap=True, verbose=True)
t1 = time.perf_counter()
run_time_waveform_old_10jobs = t1 - t0
print('OLD API get_unit_waveforms 10 jobs', run_time_waveform_old_10jobs)
Number of chunks: 553 - Number of jobs: 10
Impossible to delete temp file: /mnt/data/sam/DataSpikeSorting/eduarda_arthur/waveforms_extractor_old_10_jobs Error [Errno 16] Device or resource busy: '.nfs0000000004ce04d3000007b8'
OLD API get_unit_waveforms 10 jobs 823.8002076600096

Extract waveform with NEW API

The spikeinterface 0.9 API introduce more flexible object WaveformExtractor to do the same (extract snipet).

Here some code example and benchmark speed.

In [39]:
import spikeinterface.extractors as se
from spikeinterface import WaveformExtractor, load_extractor
print('spikeinterface version', si.__version__)

filter_path = base_folder / 'raw_awake_filtered'
filered_recording = load_extractor(filter_path)

sorting_KS3 = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
print(sorting_KS3)
spikeinterface version 0.90.0
KiloSortSortingExtractor: 184 units - 1 segments - 30.0kHz
In [41]:
# 1 worker
waveform_folder = base_folder / 'waveforms_extractor_1_job_new_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
we = WaveformExtractor.create(filered_recording, sorting_KS3, waveform_folder)

t0 = time.perf_counter()
we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we.run(n_jobs=1, total_memory="50M", progress_bar=True)
t1 = time.perf_counter()
run_time_waveform_new_1jobs = t1 - t0
print('New WaveformExtractor 1 jobs',run_time_waveform_new_1jobs)
100%|##########| 278/278 [01:42<00:00,  2.72it/s]
New WaveformExtractor 1 jobs 115.03656197001692
In [42]:
# 1 worker
waveform_folder = base_folder / 'waveforms_extractor_10_job_new_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
we = WaveformExtractor.create(filered_recording, sorting_KS3, waveform_folder)

t0 = time.perf_counter()
we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we.run(n_jobs=10, total_memory="500M", progress_bar=True)
t1 = time.perf_counter()
run_time_waveform_new_10jobs = t1 - t0
print('New WaveformExtractor 10 jobs', run_time_waveform_new_10jobs)
100%|██████████| 278/278 [00:31<00:00,  8.87it/s]
New WaveformExtractor 10 jobs 48.819815920025576

Conclusion

For filter with 10 workers the speedup is x14.

For waveform extactor with 1 workers the speedup is x4

For waveform extactor with 10 workers the speedup is x16

In [11]:
speedup_filter = run_time_filter_old / run_time_filter_new
print('speedup filter', speedup_filter)
speedup filter 14.635515939778026
In [43]:
speedup_waveform_1jobs = run_time_waveform_old_1jobs / run_time_waveform_new_1jobs
print('speedup waveforms 1 jobs', speedup_waveform_1jobs)

speedup_waveform_10jobs = run_time_waveform_old_10jobs / run_time_waveform_new_10jobs
print('speedup waveformd 10jobs', speedup_waveform_10jobs)
speedup waveforms 1 jobs 4.464637064152789
speedup waveformd 10jobs 16.874299751754943
In [ ]:
 

Ensemble sorting of a 3Brain Biocam recording from a retina

Ensemble sorting of a 3Brain Biocam recording from a mouse retina

This notebook reproduces supplemental figure S3 from the paper SpikeInterface, a unified framework for spike sorting.

The recording was made by Gerrit Hilgen in the lab of Evelyne Sernagor, University of Newcastle.

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

File required to run the code:

The file should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 24 Aug 2020

Requirements

For this need you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0
In [2]:
# the recording file, downloaded from Dandi in NWB:N format 
data_file =  './sub-P29-16-05-14-retina02-left_ecephys.nwb'

# paths for spike sorter outputs
p = Path('./')

# select spike sorters to be used, note Kilosort2 requires a NVIDIA GPU to run
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'hdsort']
sorter_params = {'spyking_circus': {'adjacency_radius': 50}, 'herdingspikes': {'filter': True, }}
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust',  'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'HDS']

# create a recording extractor, this gives access to the raw data in the NWB:N file
recording = se.NwbRecordingExtractor(data_file)

# this recording has some bad artefcats, and these are easily blanked out using spiketoolkit
# note that no spike sorter can handle these artefacts well, so this step is essential
recording = st.preprocessing.blank_saturation(recording)
# (aside: the warning below can be ignored, artefacts are quite sparse but definetely there, the check
# if very conservative)

# NWB:N files store the data in (channels:time) order, but for spike sorting the transposed format is much
# more efficient. Therefore here we create a CacheRecordingExtractor that re-writes the data
# as a binary file in the desired order. This will take some time, but speeds up subsequent steps:
recording = se.CacheRecordingExtractor(recording)

# print some info
print("Sampling rate: {}Hz".format(recording.get_sampling_frequency()))
print("Duration: {}s".format(recording.get_num_frames()/recording.get_sampling_frequency()))
print("Number of channels: {}".format(recording.get_num_channels()))
Warning, narrow signal range suggests artefact-free data.
Sampling rate: 23199.090358491783Hz
Duration: 135.7303218075276s
Number of channels: 1024

Run spike sorters and perform comparison between all outputs

In [3]:
# now create the study environment and run all spike sorters
# note that this function will not re-run a spike sorter if the sorting is already present in
# the working folder

study_folder = p / 'study/'
working_folder = p / 'working/'
if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    os.mkdir(study_folder)
rec_dict = {'rec': recording}

result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list=rec_dict, with_output=True,
                             sorter_params=sorter_params, working_folder=working_folder, engine='loop', 
                             mode='keep', verbose=True)

# when done, load all sortings into a handly list
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec',s])
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/retina_short/working/rec/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
In [4]:
# run a multi-sorting comparison, an all-to-all comparison
# results are saved and just loaded from a file if this exists

if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                       verbose=True)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
loading multicomparison
Multicomaprison step 3: clean graph
Removed edge ('KS_244', 'HS_18', {'weight': 0.6169491525423729})
Removed edge ('IC_181', 'HDS_137003', {'weight': 0.9284436493738819})
Removed edge ('HDS_450004', 'KS_1160', {'weight': 0.5208791208791209})
Removed 3 duplicate nodes
Multicomaprison step 4: extract agreement from graph
In [5]:
# plot an activity map
# as the channel count is high, here we use the spikes detected by HS2
# they are easily retrieved from the sorting results

sx = result_dict['rec','herdingspikes']
n,v = np.histogram(sx._rf['ch'], bins=np.arange(1025))

ax = plt.subplot(111)
ax.imshow(n.reshape((32,32)))
ax.set_xticks(())
ax.set_yticks(())
ax.plot((10,10+100/42),(-1,-1),'k-')
ax.annotate('100$\\mu m$',(10+100/42/2,-2), ha='center');
ax.axis('off')
ax.set_aspect('equal')
In [6]:
# next we plot some raw data traces
# for better visualisation, we bandpass filter the traces before showing them
# to this end, we use a lazy bandpass filter from spiketoolkit
recording_bandpass = st.preprocessing.bandpass_filter(recording)
plt.figure(figsize=(12,3))
ax = plt.subplot(111)
w = sw.plot_timeseries(recording_bandpass, channel_ids=range(300,308), color='k', ax=ax, trange=(2,3))
ax.plot((2.01,2.11),(-50,-50),'k-');
ax.annotate('100ms',(2.051,-120), ha='center');
ax.axis('off');
Warning: dumping a CacheRecordingExtractor. The path to the tmp binary file will be lost in further sessions. To prevent this, use the 'CacheRecordingExtractor.move_to('path-to-file)' function
In [7]:
# number of units found by each sorter
ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings])
ax.set_xticks(range(len(sorter_names)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
Out[7]:
Text(0, 0.5, 'Units detected')
In [8]:
# spikewidgets provides handy widgets to plot summary statistics of the comparison

# show the number of units agreed upon by k sorters, in aggregate
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie', ax=ax)

# show the number of units agreed upon by k sorters, per sorter
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True, ax=ax)

Ensemble sorting of a Neuropixels recording 2

Ensemble sorting of a Neuropixel recording (2)

This notebook reproduces supplemental figure S2 from the paper SpikeInterface, a unified framework for spike sorting.

The recording was made by André Marques-Smith in the lab of Adam Kampff. Reference:

Marques-Smith, A., Neto, J.P., Lopes, G., Nogueira, J., Calcaterra, L., Frazão, J., Kim, D., Phillips, M., Dimitriadis, G., Kampff, A.R. (2018). Recording from the same neuron with high-density CMOS probes and patch-clamp: a ground-truth dataset and an experiment in collaboration. bioRxiv 370080; doi: https://doi.org/10.1101/370080

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

File required to run the code:

This file should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 25 Aug 2020

Requirements

For this need you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0
In [2]:
# the recording file, downloaded from Dandi in NWB:N format 
data_file =  './sub-c1_ecephys.nwb'

# paths for spike sorter outputs
p = Path('./')

# select spike sorters to be used, note Kilosort2 requires a NVIDIA GPU to run
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'tridesclous', 'spykingcircus', 'hdsort']
sorter_params = {
#     'kilosort2': {'keep_good_only': True}, # removes good units!
    'mountainsort4': {'adjacency_radius': 50},
    'spyking_circus': {'adjacency_radius': 50},
    'herdingspikes': {'filter': True, 
                     }
}
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust','Tridesclous', 'SpykingCircus', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'TDC', 'SC', 'HDS']

# create a recording extractor, this gives access to the raw data in the NWB:N file
recording = se.NwbRecordingExtractor(data_file)

# NWB:N files store the data in (channels:time) order, but for spike sorting the transposed format is much
# more efficient. Therefore here we can create a CacheRecordingExtractor that re-writes the data
# as a binary file in the desired order. This will take some time, but speeds up subsequent steps:
# recording = se.CacheRecordingExtractor(recording)

# print some info
print("Sampling rate: {}Hz".format(recording.get_sampling_frequency()))
print("Duration: {}s".format(recording.get_num_frames()/recording.get_sampling_frequency()))
print("Number of channels: {}".format(recording.get_num_channels()))
Sampling rate: 30000.0Hz
Duration: 270.01123333333334s
Number of channels: 384

Run spike sorters and perform comparison between all outputs

In [3]:
# now create the study environment and run all spike sorters
# note that this function will not re-run a spike sorter if the sorting is already present in
# the working folder

study_folder = p / 'study/'
working_folder = p / 'working/'
if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    os.mkdir(study_folder)
rec_dict = {'rec': recording}

result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list=rec_dict, with_output=True,
                             sorter_params=sorter_params, working_folder=working_folder, engine='loop', 
                             mode='keep', verbose=True)

# when done, load all sortings into a handly list
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec',s])
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/neuropixels_ms/working/rec/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
In [4]:
# run a multi-sorting comparison, an all-to-all comparison
# results are saved and just loaded from a file if this exists

if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                       verbose=True)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
loading multicomparison
Multicomaprison step 3: clean graph
Removed edge ('HDS_25017', 'KS_231', {'weight': 0.5272206303724928})
Removed edge ('IC_66', 'TDC_28', {'weight': 0.5032594524119948})
Removed edge ('KS_231', 'HS_9', {'weight': 0.5059360730593607})
Removed edge ('SC_119', 'TDC_28', {'weight': 0.5013054830287206})
Removed edge ('KS_71', 'HS_57', {'weight': 0.5238095238095238})
Removed 5 duplicate nodes
Multicomaprison step 4: extract agreement from graph
In [5]:
# plot an activity map
# the method uses a rather simpe (and slow) threshold spike detection

plt.figure(figsize=(16,2))
ax = plt.subplot(111)
w = sw.plot_activity_map(recording, transpose=True, ax=ax, background='w', frame=True)
ax.plot((50,150),(-40,-40),'k-')
ax.annotate('100$\\mu m$',(100,-115), ha='center');
In [6]:
# raw data traces

plt.figure(figsize=(12,3))
ax = plt.subplot(111)
w = sw.plot_timeseries(recording, channel_ids=range(160,168), color='k', ax=ax, trange=(7,8))
ax.axis('off')
ax.plot((7.01,7.11),(20,20),'k-')
ax.annotate('100ms',(7.051,-190), ha='center');
In [7]:
# number of units found by each sorter

ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings])
ax.set_xticks(range(len(sorter_names)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
Out[7]:
Text(0, 0.5, 'Units detected')
In [8]:
# spikewidgets provides handy widgets to plot summary statistics of the comparison

# show the number of units agreed upon by k sorters, in aggregate
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie', ax=ax)

# show the number of units agreed upon by k sorters, per sorter
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True, ax=ax)

Ground truth comparison and ensemble sorting of a synthetic Neuropixels recording

Ground-truth comparison and ensemble sorting of a synthetic Neuropixels recording

This notebook reproduces figures 2 and 3 from the paper SpikeInterface, a unified framework for spike sorting.

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034.

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

The data file required to run the code is:

This file should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 22 Aug 2020

Requirements

For this need you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi
  • matplotlib-venn

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0

Set up ground truth study an run all sorters

In [2]:
study_path = Path('.')
data_path = Path('.')
study_folder = study_path / 'study_mearec_250cells_Neuropixels-384chans_duration600s_noise10uV_2020-02-28/'

# the original data
# this NWB file contains both the ground truth spikes and the raw data
data_filename = data_path / 'sub-MEAREC-250neuron-Neuropixels_ecephys.nwb'
SX_gt = se.NwbSortingExtractor(str(data_filename))
RX = se.NwbRecordingExtractor(str(data_filename))

if not os.path.isdir(study_folder):
    gt_dict = {'rec0' : (RX, SX_gt) }
    study = GroundTruthStudy.create(study_folder, gt_dict)
else:
    study = GroundTruthStudy(study_folder)

sorter_list = ['herdingspikes', 'kilosort2', 'ironclust',
               'spykingcircus', 'tridesclous', 'hdsort']
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust',
               'SpykingCircus', 'Tridesclous', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'SC', 'TDC', 'HDS']

study.run_sorters(sorter_list, mode='keep', engine='loop', verbose=True)

study.copy_sortings()

# compute or load SNR for the ground truth units
snr_file = study_folder / 'snr.npy'
if os.path.isfile(snr_file):
    snr = np.load(snr_file)
else:
    print('computing snr')
    # note this is quite slow for a NWB file as the data is arranged as channels:time
    # it is faster to first write out a binary file in time:channels order
    snr = st.validation.compute_snrs(SX_gt, RX, apply_filter=False, verbose=False, 
                                     memmap=True, max_spikes_per_unit_for_snr=500)
    np.save(snr_file, snr)
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/paper/MEArec/study_mearec_250cells_Neuropixels-384chans_duration600s_noise10uV_2020-02-28/sorter_folders/rec0/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback

Run the ground truth comparison and summarise the results

In [3]:
study.run_comparisons(exhaustive_gt=True, match_score=0.1)
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
In [4]:
# comparison summary
dataframes['count_units']
Out[4]:
rec_name sorter_name num_gt num_sorter num_well_detected num_redundant num_overmerged num_false_positive num_bad
0 rec0 ironclust 250 283 202 8 2 41 49
1 rec0 spykingcircus 250 343 175 43 1 56 99
2 rec0 tridesclous 250 189 135 3 5 9 12
3 rec0 hdsort 250 457 196 4 3 210 214
4 rec0 kilosort2 250 415 245 21 2 147 168
5 rec0 herdingspikes 250 233 128 6 3 39 45

Figure 1 - ground truth study results

In [5]:
# activity levels on the probe

plt.figure(figsize=(16,2))
ax = plt.subplot(111)
w = sw.plot_activity_map(RX, trange=(0,20), transpose=True, ax=ax, background='w', frame=True)
ax.plot((-1800,-1700), (-120,-120), 'k-')
ax.annotate('100$\\mu m$',(-1750,-220), ha='center');
In [6]:
# example data traces

plt.figure(figsize=(16,6))
ax = plt.subplot(111)
w = sw.plot_timeseries(RX, channel_ids=range(10,18), color='k', ax=ax, trange=(1,2))
ax.axis('off')
p = ax.get_position()
p.y0 = 0.58
ax.set_position(p)
ax.set_xticks(())
ax.plot((1.01,1.11),(-400,-400),'k-')
ax.annotate('100ms',(1.051,-750), ha='center');
ax.set_ylim((-750,ax.set_ylim()[1]))
Out[6]:
(-750, 4435.259765625)
In [7]:
ax = plt.subplot(111)
n = []
for s in sorter_list:
    n.append(len(study.get_sorting(s).get_unit_ids()))
ax.bar(range(len(sorter_list)), n, color='tab:blue')
ax.set_xticks(range(len(sorter_names_short)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
clear_axes(ax)
In [8]:
ax = plt.subplot(111)
p = ax.get_position()
p.x1=0.85
ax.set_position(p)
sns.set_palette(sns.color_palette("Set1"))
df = pd.melt(dataframes['perf_by_units'], id_vars='sorter_name', var_name='Metric', value_name='Score', 
        value_vars=('accuracy','precision', 'recall'))
sns.swarmplot(data=df, x='sorter_name', y='Score', hue='Metric', dodge=True,
                order=sorter_list,  s=3, ax=ax)
ax.set_xticklabels(sorter_names_short, rotation=30, ha='center')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5)
ax.set_xlabel(None);
ax.set_ylabel('Score');
clear_axes(ax)
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/matplotlib/axes/_axes.py:4204: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  (isinstance(c, collections.Iterable) and
In [9]:
ax = plt.subplot(111)
p = ax.get_position()
p.x1=0.85
ax.set_position(p)
df = pd.melt(dataframes['count_units'], id_vars='sorter_name', var_name='Type', value_name='Units', 
        value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
sns.set_palette(sns.color_palette("Set1"))
sns.barplot(x='sorter_name', y='Units', hue='Type', data=df,
                order=sorter_list, ax=ax)
ax.set_xticklabels(sorter_names_short, rotation=30, ha='right')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.1)
for t, l in zip(ax.legend_.texts, ("Well detected", "False positive", "Redundant", "Overmerged")): t.set_text(l)
ax.set_xlabel(None);
clear_axes(ax)
In [10]:
# precision vs. recall and accuracy vs. SNR
fig = plt.figure(figsize=(14, 4))

sns.set_palette(sns.color_palette("deep"))

axesA = []
for i,s in enumerate(sorter_list):
    ax = plt.subplot(2,len(sorter_list),i+1)
    axesA.append(ax)

    g = sns.scatterplot(data=dataframes['perf_by_units'].loc[dataframes['perf_by_units'].sorter_name==s], 
                    x='precision', y='recall', s=30, edgecolor=None, alpha=0.1)
    ax.set_title(sorter_names[i])
    ax.set_aspect('equal')
    clear_axes(ax)
    ax.set_xlabel('Precision')
    ax.set_ylabel('Recall')

for ax in axesA[1:]:
    axesA[0].get_shared_y_axes().join(axesA[0], ax)
    ax.set_yticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.autoscale()
    
############### B

df = dataframes['perf_by_units']

# add snr to the by-unit table
if 'snr' not in df.keys():
    snr_d = {k:snr[k] for i,k in enumerate(SX_gt.get_unit_ids())}
    df['snr'] = df['gt_unit_id'].map(snr_d)
    

axesB = []
for i,s in enumerate(sorter_list):
    ax = plt.subplot(2,len(sorter_list),len(sorter_list)+i+1)
    axesB.append(ax)
    
    g = sns.scatterplot(data=dataframes['perf_by_units'].loc[dataframes['perf_by_units'].sorter_name==s], 
                        x='snr', y='accuracy', s=30, alpha=0.2)
    clear_axes(ax)
    ax.set_xlabel('Ground truth SNR')
    ax.set_ylabel('Accuracy')
    
for ax in axesB[1:]:
    axesB[0].get_shared_y_axes().join(axesB[0], ax)
    ax.set_yticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.autoscale()
    

Figure 3 - comparison of sorter outputs and ensembe sorting

In [11]:
# perform an all-to-all multicomparison or load it from
# disk if file exists

sortings = []
for s in sorter_list:
    sortings.append(study.get_sorting(s))

cmp_folder = study_folder / 'multicomparison/'
if not os.path.isdir(cmp_folder):
    os.mkdir(cmp_folder)
if not os.path.isfile(cmp_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                   verbose=False, match_score=0.5)
    print('saving multicomparison')
    mcmp.dump(cmp_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(cmp_folder)
    
mcmp_graph = mcmp.graph.copy()
loading multicomparison
In [12]:
# get sorting extractors with unit with no agreement (minimum agreement = 1) and one
# with at least 2 sorters in agreement
not_in_agreement = mcmp.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
in_agreement = mcmp.get_agreement_sorting(minimum_agreement_count=2)

# score these against ground truth
cmp_no_agr = sc.compare_sorter_to_ground_truth(SX_gt, not_in_agreement)
cmp_agr = sc.compare_sorter_to_ground_truth(SX_gt, in_agreement)

# now collect results for each sorter:

# create dict to collect results
results = {'TP':{}, 'FP':{}, 'SNR':{}}
ns = len(sorter_names_short)
for s in sorter_names_short:
    results['TP'][s] = dict(zip(range(1,ns+1), [0]*(ns+1)))
    results['FP'][s] = dict(zip(range(1,ns+1), [0]*(ns+1)))
    results['SNR'][s] = dict(zip(range(1,ns+1), [[]]*(ns+1)))
    
# sorter names
dict_names = dict(zip(sorter_names_short, sorter_list))

# iterate over all units gathered from subgraphs
for u in  mcmp._new_units.keys():
    found_in_gt = []
    gt_index = []
    # check if units have a match in ground truth, store boolen
    for u2 in mcmp._new_units[u]['sorter_unit_ids'].items():
        found_in_gt.append(u2[1] in study.comparisons['rec0',dict_names[u2[0]]].best_match_12.values)
        if found_in_gt[-1]:
            gt_index.append(np.where(study.comparisons['rec0',dict_names[u2[0]]].best_match_12==u2[1])[0][0])
    if len(set(gt_index))>1:
        print('different gt units: ',u, gt_index)
    if np.sum(found_in_gt)==len(found_in_gt):
#     if np.sum(found_in_gt)>0:#==len(found_in_gt):  # use this if interested in equal matches
        key = 'TP'
    else:
        key = 'FP'
        if len(found_in_gt)>1:
            print('FP unit found by >1 sorter: ',u)
        
    for i,u2 in enumerate(mcmp._new_units[u]['sorter_unit_ids'].items()):
#         results[key][u2[0]][np.sum(found_in_gt)] += 1 # use this if interested in equal matches
        results[key][u2[0]][len(found_in_gt)] += 1
        if key is 'TP':
            # something odd with nested oython dicts requires this:
            d = results['SNR'][u2[0]][len(found_in_gt)].copy()
            d.append(snr[gt_index[i]])
            results['SNR'][u2[0]][len(found_in_gt)] = d
            # this fails, I wonder why:
            # results['SNR'][u2[0]][len(found_in_gt)].append(snr[gt_index[i]])
different gt units:  20 [213, 213, 213, 146, 213, 213]
different gt units:  35 [189, 189, 185, 189, 189]
different gt units:  42 [224, 224, 224, 76, 224, 224]
different gt units:  46 [108, 108, 108, 90, 108, 108]
FP unit found by >1 sorter:  92
different gt units:  102 [153, 197, 197, 197, 197, 197]
different gt units:  103 [175, 175, 175, 11, 175, 175]
FP unit found by >1 sorter:  111
different gt units:  114 [157, 157, 107, 157, 157, 157]
different gt units:  149 [78, 99, 99, 99, 99]
FP unit found by >1 sorter:  150
FP unit found by >1 sorter:  158
different gt units:  162 [179, 196, 196, 196, 196, 196]
different gt units:  284 [185, 185, 185, 177, 185]
different gt units:  316 [90, 90, 90, 57, 90]
FP unit found by >1 sorter:  418
FP unit found by >1 sorter:  549
different gt units:  673 [129, 182]
In [13]:
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie')
In [14]:
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True)
In [15]:
fig = plt.figure(figsize=(14,4))

axes = []
for i,s in enumerate(results['TP'].keys()):
    ax = plt.subplot(2,len(sorter_list), i+1)
    ax.bar(results['FP'][s].keys(), list(results['FP'][s].values()), alpha=0.5, width = 0.6, color='r', label='false positive')
    ax.bar(results['TP'][s].keys(), list(results['TP'][s].values()), bottom=list(results['FP'][s].values()), alpha=0.5, width = 0.6, color='b', label='matched')
    ax.set_xticks(range(1,len(sorter_list)+1))
    ax.set_xticklabels(range(1,len(sorter_list)+1))
    ax.set_title(s)
    clear_axes(ax)
    axes.append(ax)
    if i == 0:
        ax.set_ylabel('Number of units')
    else:
        ax.get_shared_y_axes().join(axes[0], ax)
        ax.set_yticklabels([])
    
    ax = plt.subplot(2,len(sorter_list), len(sorter_list)+i+1)
    d  = results['SNR'][s]
    sns.boxenplot(data=pd.DataFrame([pd.Series(d[k]) for k in d.keys()]).T, color='b', ax=ax)
    ax.set_xticks(range(0,len(sorter_list)))
    ax.set_xticklabels(range(1,len(sorter_list)+1))
    clear_axes(ax)
    axes.append(ax)
    if i == 0:
        ax.set_ylabel('Ground truth SNR')
        ax.set_xlabel('Found by # sorters')
    else:
        ax.get_shared_y_axes().join(axes[1], ax)
        ax.set_yticklabels([])
    
In [16]:
# numbers for figure above

sg_names, sg_units = mcmp.compute_subgraphs()
v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True)
df = pd.DataFrame(np.vstack((v,c,np.round(100*c/np.sum(c),2))).T,
             columns=('in # sorters','# units','percentage'))
print('all sorters, all units:')
print(df)
df = pd.DataFrame()
for i, name in enumerate(sorter_names_short):
    v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True)
    cl = np.zeros(len(sorter_list), dtype=int)
    cl[v.astype(int)-1] = c
    df.insert(2*i,name,cl)
    df.insert(2*i+1,name+'%',np.round(100*cl/np.sum(cl),1))
print('\nper sorter:')
print(df)

for i,s in enumerate(results['TP'].keys()):
    print(s, list(results['FP'][s].values()))
all sorters, all units:
   in # sorters  # units  percentage
0           1.0    659.0       72.34
1           2.0     18.0        1.98
2           3.0     27.0        2.96
3           4.0     30.0        3.29
4           5.0     38.0        4.17
5           6.0    139.0       15.26

per sorter:
    HS   HS%   KS   KS%   IC   IC%   SC   SC%  TDC  TDC%  HDS  HDS%
0   68  29.2  168  40.5   50  17.7  129  37.6   18   9.5  226  49.5
1    2   0.9   13   3.1    6   2.1    5   1.5    2   1.1    8   1.8
2    2   0.9   27   6.5   21   7.4    7   2.0    0   0.0   24   5.3
3    2   0.9   30   7.2   30  10.6   28   8.2    2   1.1   28   6.1
4   20   8.6   38   9.2   37  13.1   35  10.2   28  14.8   32   7.0
5  139  59.7  139  33.5  139  49.1  139  40.5  139  73.5  139  30.4
HS [42, 1, 2, 0, 1, 0]
KS [164, 1, 2, 0, 1, 0]
IC [45, 2, 2, 0, 0, 0]
SC [97, 1, 0, 0, 1, 0]
TDC [11, 0, 0, 0, 1, 0]
HDS [214, 1, 0, 0, 1, 0]

Ensemble sorting of a Neuropixels recording

Ensemble sorting of a Neuropixels recording

This notebook reproduces figures 1 and 4 from the paper SpikeInterface, a unified framework for spike sorting.

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

Files required to run the code are:

These files should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 24 Aug 2020

Requirements

For this need you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi
  • matplotlib-venn

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0
In [2]:
# where to find the data set
data_file = Path('./') / 'mouse412804_probeC_15min.nwb'
# results are stored here
study_path = Path('./')
# this folder will contain all results
study_folder = study_path / 'study_15min/'
# this folder will be used as temorary space, and hold the sortings etc.
working_folder = study_path / 'working_15min'

# sorters to use
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'tridesclous', 'spykingcircus', 'hdsort']
# pass the following parameters to the sorters
sorter_params = {
#     'kilosort2': {'keep_good_only': True}, # uncomment this to test the native filter for false positives
    'spyking_circus': {'adjacency_radius': 50},
    'herdingspikes': {'filter': True, }}
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust','Tridesclous', 'SpykingCircus', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'TDC', 'SC', 'HDS']
In [3]:
# create an extractor object for the raw data 
recording = se.NwbRecordingExtractor(str(data_file))

print("Number of frames: {}\nSampling rate: {}Hz\nNumber of channels: {}".format(
    recording.get_num_frames(), recording.get_sampling_frequency(), 
    recording.get_num_channels()))
Number of frames: 27000000
Sampling rate: 30000.0Hz
Number of channels: 248

Run spike sorters and perform comparison between all outputs

In [4]:
# set up the study environment and run all sorters
# sorters are not re-run if outputs are found in working_folder 

if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    os.mkdir(study_folder)

# run all sorters
result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list={'rec': recording}, with_output=True,
                             sorter_params=sorter_params, working_folder=working_folder, engine='loop', 
                             mode='keep', verbose=True)

# store sortings in a list for quick access
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec',s])
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/Neuropixels_Allen/working_15min/rec/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
In [5]:
# perform a multi-comparison, all to all sortings
# result is stored, and loaded from disk if the file is found

if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                       verbose=True)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
loading multicomparison
Multicomaprison step 3: clean graph
Removed edge ('IC_137', 'HS_49', {'weight': 0.552943287867911})
Removed edge ('KS_185', 'HDS_18005', {'weight': 0.6143539400371452})
Removed edge ('KS_295', 'HS_159', {'weight': 0.6111111111111112})
Removed edge ('KS_195', 'TDC_8', {'weight': 0.6909090909090909})
Removed 4 duplicate nodes
Multicomaprison step 4: extract agreement from graph

Figure 1 - comparison of sorter outputs

In [6]:
# activity levels on the probe

plt.figure(figsize=(16,2))
ax = plt.subplot(111)
w = sw.plot_activity_map(recording, trange=(0,20), transpose=True, ax=ax, background='w', frame=True)
ax.plot((50,150),(-30,-30),'k-')
ax.annotate('100$\\mu m$',(100,-90), ha='center');
In [7]:
# example data traces

plt.figure(figsize=(16,6))
ax = plt.subplot(111)
w = sw.plot_timeseries(recording, channel_ids=range(20,28), color='k', ax=ax, trange=(1,2))
ax.axis('off')
p = ax.get_position()
p.y0 = 0.55
ax.set_position(p)
ax.set_xticks(())
ax.plot((1.01,1.11),(-1790,-1790),'k-')
ax.annotate('100ms',(1.051,-2900), ha='center');
ax.set_ylim((-2900,ax.set_ylim()[1]))
Out[7]:
(-2900, 19056.0)
In [8]:
ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings], color='tab:blue')
ax.set_xticks(range(len(sorter_names)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
clear_axes(ax)
In [9]:
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie')
In [10]:
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True)
In [11]:
# numbers for figure above

print('number of units detected:')
for i,s in enumerate(sortings):
    print("{}: {}".format(sorter_names[i],len(s.get_unit_ids())))

sg_names, sg_units = mcmp.compute_subgraphs()
v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True)
df = pd.DataFrame(np.vstack((v,c,np.round(100*c/np.sum(c),2))).T,
             columns=('in # sorters','# units','percentage'))
print('\nall sorters, all units:')
print(df)
df = pd.DataFrame()
for i, name in enumerate(sorter_names_short):
    v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True)
    df.insert(2*i,name,c)
    df.insert(2*i+1,name+'%',np.round(100*c/np.sum(c),1))
print('\nper sorter:')
print(df)
number of units detected:
HerdingSpikes: 210
Kilosort2: 448
Ironclust: 234
Tridesclous: 187
SpykingCircus: 628
HDSort: 317

all sorters, all units:
   in # sorters  # units  percentage
0           1.0   1093.0       80.60
1           2.0     89.0        6.56
2           3.0     52.0        3.83
3           4.0     49.0        3.61
4           5.0     40.0        2.95
5           6.0     33.0        2.43

per sorter:
    HS   HS%   KS   KS%  IC   IC%  TDC  TDC%   SC   SC%  HDS  HDS%
0  139  66.2  226  50.7  66  28.3   82  43.9  396  63.1  184  58.0
1    4   1.9   57  12.8  15   6.4    8   4.3   67  10.7   27   8.5
2    6   2.9   44   9.9  34  14.6    5   2.7   47   7.5   20   6.3
3    8   3.8   46  10.3  45  19.3   21  11.2   46   7.3   30   9.5
4   20   9.5   40   9.0  40  17.2   38  20.3   39   6.2   23   7.3
5   33  15.7   33   7.4  33  14.2   33  17.6   33   5.3   33  10.4

Supplemental Figure - example unit templates

In [12]:
# show unit emplates and spike trains for two units/all sorters

sorting = mcmp.get_agreement_sorting(minimum_agreement_count=6)

get_sorting = lambda u: [mcmp.sorting_list[i] for i,n in enumerate(mcmp.name_list) if n==u[0]][0]
get_spikes = lambda u: [mcmp.sorting_list[i].get_unit_spike_train(u[1]) for i,n in enumerate(mcmp.name_list) if n==u[0]][0]


# one well matched and one not so well matched unit, all sorters
show_units = [2,17]

for i,unit in enumerate(show_units):
    fig = plt.figure(figsize=(16, 2))
    ax = plt.subplot(111)
    ax.set_title('Average agreement: {:.2f}'.format(sorting.get_unit_property(sorting.get_unit_ids()[unit],'avg_agreement')))
    units = sorting.get_unit_property(sorting.get_unit_ids()[unit], 'sorter_unit_ids')
    cols = plt.cm.Accent(np.arange(len(units))/len(units))
    for j,u in enumerate(dict(sorted(units.items())).items()):
        s = get_sorting(u).get_units_spike_train((u[1],))[0]
        s = s[s<20*get_sorting(u).get_sampling_frequency()]
        ax.plot(s/get_sorting(u).get_sampling_frequency(), np.ones(len(s))*j, '|', color=cols[j], label=u[0])
    ax.set_frame_on(False)
    ax.set_xticks(())
    ax.set_yticks(())
    ax.plot((0,1),(-1,-1),'k')
    ax.annotate('1s',(0.5,-1.75), ha='center')
    ax.set_ylim((-2,len(units)+1))
    
    fig = plt.figure(figsize=(16, 2))
    units = sorting.get_unit_property(sorting.get_unit_ids()[unit], 'sorter_unit_ids')
    print(units)
    print('Agreement: {}'.format(sorting.get_unit_property(sorting.get_unit_ids()[unit],'avg_agreement')))
    cols = plt.cm.Accent(np.arange(len(units))/len(units))
    for j,u in enumerate(dict(sorted(units.items())).items()):
        ax = plt.subplot(1, len(sorter_list), j+1)
        w = sw.plot_unit_templates(recording, get_sorting(u), unit_ids=(u[1],), max_spikes_per_unit=10, 
                               channel_locs=True, radius=75, show_all_channels=False, color=[cols[j]], 
                               lw=1.5, ax=ax, plot_channels=False, set_title=False, axis_equal=True) 
        # was 100 spikes in original plot
        ax.set_title(u[0])
{'SC': 563, 'TDC': 3, 'HDS': 44005, 'HS': 5, 'KS': 116, 'IC': 88}
Agreement: 0.9661710270732794
{'HS': 49, 'TDC': 15, 'KS': 189, 'HDS': 18005, 'IC': 137, 'SC': 272}
Agreement: 0.6920718057618236

Figure 4 - comparsion between ensembe sortings and curated data

In [13]:
# perform a comparison with curated sortings (KS2)

curated1 = se.NwbSortingExtractor('sub-mouse412804_ses-20200824T155542.nwb', sampling_frequency=30000)
curated2 = se.NwbSortingExtractor('sub-mouse412804_ses-20200824T155543.nwb', sampling_frequency=30000)

comparison_curated = sc.compare_two_sorters(curated1, curated2)
comparison_curated_ks = sc.compare_multiple_sorters((curated1, curated2, sortings[sorter_list.index('kilosort2')]))

# consensus sortings (units where at least 2 sorters agree)
sorting = mcmp.get_agreement_sorting(minimum_agreement_count=2)
consensus_sortings = []
units_dict = defaultdict(list)
units = [sorting.get_unit_property(u,'sorter_unit_ids') for u in sorting.get_unit_ids()]
for au in units:
    for u in au.items():
        units_dict[u[0]].append(u[1])
for i,s in enumerate(sorter_names_short):
    consensus_sortings.append(se.SubSortingExtractor(sortings[i], unit_ids=units_dict[s]))

# orphan units (units found by only one sorter)
sorting = mcmp.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
unmatched_sortings = []
units_dict = defaultdict(list)
units = [sorting.get_unit_property(u,'sorter_unit_ids') for u in sorting.get_unit_ids()]
for au in units:
    for u in au.items():
        units_dict[u[0]].append(u[1])
for i,s in enumerate(sorter_names_short):
    unmatched_sortings.append(se.SubSortingExtractor(sortings[i], unit_ids=units_dict[s]))
    
consensus_curated_comparisons = []
for s in consensus_sortings:
    consensus_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    consensus_curated_comparisons.append(sc.compare_two_sorters(s, curated2))    
    
unmatched_curated_comparisons = []
for s in unmatched_sortings:
    unmatched_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    unmatched_curated_comparisons.append(sc.compare_two_sorters(s, curated2))

all_curated_comparisons = []
for s in sortings:
    all_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    all_curated_comparisons.append(sc.compare_two_sorters(s, curated2))    \
    
# count various types of units

count_mapped = lambda x : np.sum([u!=-1 for u in x.get_mapped_unit_ids()])
count_not_mapped = lambda x : np.sum([u==-1 for u in x.get_mapped_unit_ids()])
count_units = lambda x : len(x.get_unit_ids())

n_consensus_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in consensus_curated_comparisons]).reshape((len(sorter_list),2))
n_consensus_curated_unmapped = np.array([count_not_mapped(c.get_mapped_sorting1()) for c in consensus_curated_comparisons]).reshape((len(sorter_list),2))
n_unmatched_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in unmatched_curated_comparisons]).reshape((len(sorter_list),2))
n_all_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_all_curated_unmapped = np.array([count_not_mapped(c.get_mapped_sorting1()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_curated_all_unmapped = np.array([count_not_mapped(c.get_mapped_sorting2()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_all = np.array([count_units(s) for s in sortings])
n_consensus = np.array([count_units(s) for s in consensus_sortings])
n_unmatched = np.array([count_units(s) for s in unmatched_sortings])
n_curated1 = len(curated1.get_unit_ids())
n_curated2 = len(curated2.get_unit_ids())
In [14]:
# overlap between two manually curated data and the Kilosort2 sorting they were derived from

i = {}
for k in ['{0:03b}'.format(v) for v in range(1,2**3)]:
    i[k] = 0
i['111'] = len(comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids())
s = comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=2, minimum_agreement_count_only=True)
units = [s.get_unit_property(u,'sorter_unit_ids').keys() for u in s.get_unit_ids()]
for u in units:
    if 'sorting1' in u and 'sorting2' in u:
        i['110'] += 1
    if 'sorting1' in u and 'sorting3' in u:
        i['101'] += 1
    if 'sorting2' in u and 'sorting3' in u:
        i['011'] += 1   
s = comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
units = [s.get_unit_property(u,'sorter_unit_ids').keys() for u in s.get_unit_ids()]
for u in units:
    if 'sorting1' in u:
        i['100'] += 1
    if 'sorting2' in u:
        i['010'] += 1
    if 'sorting3' in u:
        i['001'] += 1   
colors = plt.cm.RdYlBu(np.linspace(0,1,3))
venn3(subsets = i,set_labels=('Curated 1', 'Curated 2', 'Kilosort2'), 
      set_colors=colors, alpha=0.6, normalize_to=100)
Out[14]:
<matplotlib_venn._common.VennDiagram at 0x7fadfef520b8>
In [15]:
# overlaps betweem ensemble sortings (per sorter) and manually curated sortings

def plot_mcmp_results(data, labels, ax, ylim=None, yticks=None, legend=