Ensemble sorting of a 3Brain Biocam recording from a mouse retina

This notebook reproduces supplemental figure S3 from the paper "SpikeInterface, a unified framework for spike sorting".

The recording was made by Gerrit Hilgen in the lab of Evelyne Sernagor, University of Newcastle.

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

File required to run the code:

The file should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 24 Aug 2020

Requirements

For this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi
  • matplotlib-venn

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0
In [2]:
# the recording file, downloaded from Dandi in NWB:N format 
data_file = './sub-P29-16-05-14-retina02-left_ecephys.nwb'

# paths for spike sorter outputs
p = Path('./')

# select spike sorters to be used; note that Kilosort2 requires an NVIDIA GPU to run
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'hdsort']
sorter_params = {'spyking_circus': {'adjacency_radius': 50}, 'herdingspikes': {'filter': True, }}
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust',  'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'HDS']

# create a recording extractor, this gives access to the raw data in the NWB:N file
recording = se.NwbRecordingExtractor(data_file)

# this recording has some bad artefacts, and these are easily blanked out using spiketoolkit
# note that no spike sorter can handle these artefacts well, so this step is essential
recording = st.preprocessing.blank_saturation(recording)
# (aside: the warning below can be ignored; the artefacts are sparse but definitely present,
# and the check is very conservative)

# NWB:N files store the data in (channels:time) order, but for spike sorting the transposed format is much
# more efficient. Therefore here we create a CacheRecordingExtractor that re-writes the data
# as a binary file in the desired order. This will take some time, but speeds up subsequent steps:
recording = se.CacheRecordingExtractor(recording)

# print some info
print("Sampling rate: {}Hz".format(recording.get_sampling_frequency()))
print("Duration: {}s".format(recording.get_num_frames()/recording.get_sampling_frequency()))
print("Number of channels: {}".format(recording.get_num_channels()))
Warning, narrow signal range suggests artefact-free data.
Sampling rate: 23199.090358491783Hz
Duration: 135.7303218075276s
Number of channels: 1024
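The blanking step above removes saturated samples before sorting. As a minimal sketch of the idea only, assuming a fixed, user-supplied threshold (spiketoolkit's blank_saturation may estimate its own threshold and differs in detail), samples at or beyond the threshold can be replaced by a fill value. The helper name blank_saturated_samples is hypothetical.

```python
import numpy as np


def blank_saturated_samples(traces, threshold, fill_value=0.0):
    """Hypothetical helper: replace samples whose absolute value reaches `threshold`.

    traces: array of shape (num_samples, num_channels).
    Returns the blanked copy and the fraction of samples that were replaced.
    """
    traces = np.asarray(traces, dtype=float)
    mask = np.abs(traces) >= threshold          # True where the signal is saturated
    blanked = np.where(mask, fill_value, traces)
    return blanked, mask.mean()


# toy data: Gaussian noise (SD 10) with a short saturated artefact on all channels
rng = np.random.default_rng(0)
traces = rng.normal(0.0, 10.0, size=(1000, 4))
traces[500:505, :] = 2000.0                     # simulated saturation
clean, frac = blank_saturated_samples(traces, threshold=1000.0)
print(f"blanked {frac:.2%} of samples")
```

A sparse artefact touches only a tiny fraction of samples, which is consistent with the output above reporting a narrow signal range despite the artefacts being present.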

Run spike sorters and perform comparison between all outputs

In [3]:
# now create the study environment and run all spike sorters
# note that this function will not re-run a spike sorter if the sorting is already present in
# the working folder

study_folder = p / 'study/'
working_folder = p / 'working/'
if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    os.mkdir(study_folder)
rec_dict = {'rec': recording}

result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list=rec_dict, with_output=True,
                             sorter_params=sorter_params, working_folder=working_folder, engine='loop', 
                             mode='keep', verbose=True)

# when done, load all sortings into a handy list
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec',s])
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/retina_short/working/rec/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
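With mode='keep', run_sorters does not re-run a sorter whose output is already in the working folder. The sketch below mimics that decision, assuming outputs live in working_folder / rec_name / sorter_name (the layout visible in the warning path above, e.g. working/rec/hdsort). The function sorters_to_run is hypothetical and not part of spikesorters; in this sketch any other mode simply reruns everything.

```python
import tempfile
from pathlib import Path


def sorters_to_run(sorter_list, working_folder, rec_name='rec', mode='keep'):
    """Hypothetical helper: return the sorters that still need to run.

    Assumes each sorter writes to working_folder / rec_name / sorter_name and
    that mode='keep' reuses an existing output folder instead of recomputing.
    """
    if mode != 'keep':
        return list(sorter_list)
    working_folder = Path(working_folder)
    return [s for s in sorter_list if not (working_folder / rec_name / s).is_dir()]


with tempfile.TemporaryDirectory() as tmp:
    # pretend Kilosort2 has already finished on this recording
    (Path(tmp) / 'rec' / 'kilosort2').mkdir(parents=True)
    todo = sorters_to_run(['herdingspikes', 'kilosort2'], tmp)
    print(todo)
```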
In [4]:
# run a multi-sorting (all-to-all) comparison
# the result is saved, and loaded from file if it already exists

if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                       verbose=True)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
loading multicomparison
Multicomaprison step 3: clean graph
Removed edge ('KS_244', 'HS_18', {'weight': 0.6169491525423729})
Removed edge ('IC_181', 'HDS_137003', {'weight': 0.9284436493738819})
Removed edge ('HDS_450004', 'KS_1160', {'weight': 0.5208791208791209})
Removed 3 duplicate nodes
Multicomaprison step 4: extract agreement from graph
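Judging by the log above, the multi-sorting comparison builds a graph whose nodes are sorter units (e.g. KS_244) and whose edges carry an agreement weight, cleans it, and then extracts agreement from it. A simplified sketch of the last step: keep edges whose weight reaches a match score, find connected components, and count the distinct sorters in each. This skips the cleaning step (the "Removed edge" and "duplicate nodes" lines), and agreement_components is a hypothetical helper, not the spikecomparison implementation.

```python
from collections import defaultdict


def agreement_components(edges, nodes, match_score=0.5):
    """Group sorter units into connected components of an agreement graph.

    edges: iterable of (node_a, node_b, agreement_score)
    nodes: all node names, formatted as '<SORTER>_<unit_id>'
    Only edges with score >= match_score are kept.
    Returns a list of (sorted component nodes, number of distinct sorters).
    """
    adjacency = defaultdict(set)
    for a, b, w in edges:
        if w >= match_score:
            adjacency[a].add(b)
            adjacency[b].add(a)
    seen, components = set(), []
    for start in nodes:
        if start in seen:
            continue
        # depth-first search from an unvisited node
        stack, component = [start], []
        seen.add(start)
        while stack:
            node = stack.pop()
            component.append(node)
            for nb in adjacency[node]:
                if nb not in seen:
                    seen.add(nb)
                    stack.append(nb)
        sorters = {n.rsplit('_', 1)[0] for n in component}
        components.append((sorted(component), len(sorters)))
    return components


nodes = ['HS_1', 'KS_7', 'IC_3', 'HDS_9', 'KS_8']
edges = [('HS_1', 'KS_7', 0.92), ('KS_7', 'IC_3', 0.81), ('HDS_9', 'KS_8', 0.3)]
result = agreement_components(edges, nodes)
for comp, k in result:
    print(k, comp)
```

The number of distinct sorters per component is the "agreed upon by k sorters" quantity plotted by the spikewidgets agreement widgets below.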
In [5]:
# plot an activity map
# as the channel count is high, here we use the spikes detected by HS2
# they are easily retrieved from the sorting results

sx = result_dict['rec','herdingspikes']
n,v = np.histogram(sx._rf['ch'], bins=np.arange(1025))

ax = plt.subplot(111)
ax.imshow(n.reshape((32,32)))
ax.set_xticks(())
ax.set_yticks(())
ax.plot((10,10+100/42),(-1,-1),'k-')
ax.annotate('100$\\mu m$',(10+100/42/2,-2), ha='center');
ax.axis('off')
ax.set_aspect('equal')
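The activity map above counts detected spikes per channel with np.histogram over integer bin edges and reshapes the 1024 counts into a 32 x 32 image. A small sketch on hypothetical channel indices shows that this is equivalent to np.bincount, and works out the scale-bar length, assuming the 42 µm electrode pitch implied by 100/42 in the cell above.

```python
import numpy as np

# hypothetical channel indices of 5000 detected spikes on a 32 x 32 = 1024-channel array
rng = np.random.default_rng(1)
spike_channels = rng.integers(0, 1024, size=5000)

# np.histogram with integer edges 0..1024 gives one bin per channel
counts_hist, _ = np.histogram(spike_channels, bins=np.arange(1025))
# for non-negative integer indices, np.bincount yields the same per-channel counts
counts = np.bincount(spike_channels, minlength=1024)

# row-major reshape into the 32 x 32 electrode grid, as in the cell above
activity_map = counts.reshape((32, 32))

# scale bar length in map pixels for 100 um, assuming a 42 um pitch
scale_bar_pixels = 100 / 42
print(activity_map.shape, round(scale_bar_pixels, 2))
```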
In [6]:
# next we plot some raw data traces
# for better visualisation, we bandpass filter the traces before showing them
# to this end, we use a lazy bandpass filter from spiketoolkit
recording_bandpass = st.preprocessing.bandpass_filter(recording)
plt.figure(figsize=(12,3))
ax = plt.subplot(111)
w = sw.plot_timeseries(recording_bandpass, channel_ids=range(300,308), color='k', ax=ax, trange=(2,3))
ax.plot((2.01,2.11),(-50,-50),'k-');
ax.annotate('100ms',(2.051,-120), ha='center');
ax.axis('off');
Warning: dumping a CacheRecordingExtractor. The path to the tmp binary file will be lost in further sessions. To prevent this, use the 'CacheRecordingExtractor.move_to('path-to-file)' function
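The traces above are bandpass filtered lazily before plotting. To illustrate what a bandpass does, here is a sketch of an ideal (brick-wall) FFT bandpass in NumPy. This is a swapped-in, simplified technique: spiketoolkit's bandpass_filter is a different implementation, and the 300 to 6000 Hz band edges used here are assumptions, not necessarily its defaults. The helper fft_bandpass is hypothetical.

```python
import numpy as np


def fft_bandpass(traces, fs, freq_min=300.0, freq_max=6000.0):
    """Hypothetical ideal band-pass filter applied along the time axis.

    traces: array of shape (num_samples, num_channels); fs in Hz.
    Frequency components outside [freq_min, freq_max] are set to zero.
    """
    spectrum = np.fft.rfft(traces, axis=0)
    freqs = np.fft.rfftfreq(traces.shape[0], d=1.0 / fs)
    keep = (freqs >= freq_min) & (freqs <= freq_max)
    spectrum[~keep] = 0.0
    return np.fft.irfft(spectrum, n=traces.shape[0], axis=0)


fs = 30000.0
t = np.arange(30000) / fs                       # 1 s of data
slow = np.sin(2 * np.pi * 50 * t)               # 50 Hz component, below the band
fast = 0.5 * np.sin(2 * np.pi * 1000 * t)       # 1 kHz component, inside the band
filtered = fft_bandpass(np.column_stack([slow + fast]), fs)
```

After filtering, only the 1 kHz component remains, which is why slow fluctuations disappear from the plotted traces. A brick-wall filter can ring around sharp transients, so practical implementations usually taper the band edges.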
In [7]:
# number of units found by each sorter
ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings])
ax.set_xticks(range(len(sorter_names)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
Out[7]:
Text(0, 0.5, 'Units detected')
In [8]:
# spikewidgets provides handy widgets to plot summary statistics of the comparison

# show the number of units agreed upon by k sorters, in aggregate
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie', ax=ax)

# show the number of units agreed upon by k sorters, per sorter
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True, ax=ax)

Ensemble sorting of a Neuropixels recording (2)

This notebook reproduces supplemental figure S2 from the paper "SpikeInterface, a unified framework for spike sorting".

The recording was made by André Marques-Smith in the lab of Adam Kampff. Reference:

Marques-Smith, A., Neto, J.P., Lopes, G., Nogueira, J., Calcaterra, L., Frazão, J., Kim, D., Phillips, M., Dimitriadis, G., Kampff, A.R. (2018). Recording from the same neuron with high-density CMOS probes and patch-clamp: a ground-truth dataset and an experiment in collaboration. bioRxiv 370080; doi: https://doi.org/10.1101/370080

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

File required to run the code:

This file should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 25 Aug 2020

Requirements

For this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi
  • matplotlib-venn

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0
In [2]:
# the recording file, downloaded from Dandi in NWB:N format 
data_file = './sub-c1_ecephys.nwb'

# paths for spike sorter outputs
p = Path('./')

# select spike sorters to be used; note that Kilosort2 requires an NVIDIA GPU to run
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'tridesclous', 'spykingcircus', 'hdsort']
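sorter_params = {
#     'kilosort2': {'keep_good_only': True}, # removes good units!
    'mountainsort4': {'adjacency_radius': 50},
    # note: sorter_list uses the name 'spykingcircus', so this 'spyking_circus' entry is probably not applied
    'spyking_circus': {'adjacency_radius': 50},
    'herdingspikes': {'filter': True},
}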
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust','Tridesclous', 'SpykingCircus', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'TDC', 'SC', 'HDS']

# create a recording extractor, this gives access to the raw data in the NWB:N file
recording = se.NwbRecordingExtractor(data_file)

# NWB:N files store the data in (channels:time) order, but for spike sorting the transposed format is much
# more efficient. Therefore here we can create a CacheRecordingExtractor that re-writes the data
# as a binary file in the desired order. This will take some time, but speeds up subsequent steps:
# recording = se.CacheRecordingExtractor(recording)

# print some info
print("Sampling rate: {}Hz".format(recording.get_sampling_frequency()))
print("Duration: {}s".format(recording.get_num_frames()/recording.get_sampling_frequency()))
print("Number of channels: {}".format(recording.get_num_channels()))
Sampling rate: 30000.0Hz
Duration: 270.01123333333334s
Number of channels: 384
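The comment above explains why rewriting the data in time:channels order speeds up sorting: a short time window for all channels then sits in one contiguous block of the file. A minimal NumPy sketch of that reordering, writing a toy channels:time array once as a time-major raw binary file and reading it back with a memory map. This illustrates the layout only; it is not how CacheRecordingExtractor is implemented.

```python
import tempfile
from pathlib import Path

import numpy as np

# toy recording stored channel-major, shape (num_channels, num_samples)
num_channels, num_samples = 4, 10
channel_major = np.arange(num_channels * num_samples, dtype=np.int16).reshape(num_channels, num_samples)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'traces.raw'
    # write once in time-major order, shape (num_samples, num_channels)
    np.ascontiguousarray(channel_major.T).tofile(path)
    traces = np.memmap(path, dtype=np.int16, mode='r', shape=(num_samples, num_channels))
    # a short time window for all channels is now a single contiguous read
    window = np.array(traces[2:5, :])
    del traces  # release the memory map before the temporary folder is removed
```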

Run spike sorters and perform comparison between all outputs

In [3]:
# now create the study environment and run all spike sorters
# note that this function will not re-run a spike sorter if the sorting is already present in
# the working folder

study_folder = p / 'study/'
working_folder = p / 'working/'
if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    os.mkdir(study_folder)
rec_dict = {'rec': recording}

result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list=rec_dict, with_output=True,
                             sorter_params=sorter_params, working_folder=working_folder, engine='loop', 
                             mode='keep', verbose=True)

# when done, load all sortings into a handy list
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec',s])
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/neuropixels_ms/working/rec/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
In [4]:
# run a multi-sorting (all-to-all) comparison
# the result is saved, and loaded from file if it already exists

if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                       verbose=True)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
loading multicomparison
Multicomaprison step 3: clean graph
Removed edge ('HDS_25017', 'KS_231', {'weight': 0.5272206303724928})
Removed edge ('IC_66', 'TDC_28', {'weight': 0.5032594524119948})
Removed edge ('KS_231', 'HS_9', {'weight': 0.5059360730593607})
Removed edge ('SC_119', 'TDC_28', {'weight': 0.5013054830287206})
Removed edge ('KS_71', 'HS_57', {'weight': 0.5238095238095238})
Removed 5 duplicate nodes
Multicomaprison step 4: extract agreement from graph
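The cell above follows a load-or-compute pattern so that the slow comparison runs only once. A generic sketch of that pattern using the standard pickle module; mcmp.dump and load_multicomparison use their own on-disk format (a .gpickle graph file, judging by the filename checked above), and load_or_compute is a hypothetical helper name.

```python
import pickle
import tempfile
from pathlib import Path


def load_or_compute(cache_file, compute):
    """Hypothetical helper: return the cached result if present, otherwise compute and save it."""
    cache_file = Path(cache_file)
    if cache_file.is_file():
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    result = compute()
    with open(cache_file, 'wb') as f:
        pickle.dump(result, f)
    return result


calls = []


def expensive_comparison():
    calls.append(1)  # record how often the expensive step actually runs
    return {'agreement': [3, 1, 1]}


with tempfile.TemporaryDirectory() as tmp:
    first = load_or_compute(Path(tmp) / 'multicomparison.pkl', expensive_comparison)
    second = load_or_compute(Path(tmp) / 'multicomparison.pkl', expensive_comparison)
```

One caveat of this pattern: the cache is not invalidated when inputs change, so delete the saved file after re-running any sorter.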
In [5]:
# plot an activity map
# the method uses a rather simple (and slow) threshold spike detection

plt.figure(figsize=(16,2))
ax = plt.subplot(111)
w = sw.plot_activity_map(recording, transpose=True, ax=ax, background='w', frame=True)
ax.plot((50,150),(-40,-40),'k-')
ax.annotate('100$\\mu m$',(100,-115), ha='center');
In [6]:
# raw data traces

plt.figure(figsize=(12,3))
ax = plt.subplot(111)
w = sw.plot_timeseries(recording, channel_ids=range(160,168), color='k', ax=ax, trange=(7,8))
ax.axis('off')
ax.plot((7.01,7.11),(20,20),'k-')
ax.annotate('100ms',(7.051,-190), ha='center');
In [7]:
# number of units found by each sorter

ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings])
ax.set_xticks(range(len(sorter_names)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
Out[7]:
Text(0, 0.5, 'Units detected')
In [8]:
# spikewidgets provides handy widgets to plot summary statistics of the comparison

# show the number of units agreed upon by k sorters, in aggregate
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie', ax=ax)

# show the number of units agreed upon by k sorters, per sorter
plt.figure()
ax = plt.subplot(111)
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True, ax=ax)

Ground-truth comparison and ensemble sorting of a synthetic Neuropixels recording

This notebook reproduces figures 2 and 3 from the paper "SpikeInterface, a unified framework for spike sorting".

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034.

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

The data file required to run the code is:

This file should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 22 Aug 2020

Requirements

For this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi
  • matplotlib-venn

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0

Set up the ground-truth study and run all sorters

In [2]:
study_path = Path('.')
data_path = Path('.')
study_folder = study_path / 'study_mearec_250cells_Neuropixels-384chans_duration600s_noise10uV_2020-02-28/'

# the original data
# this NWB file contains both the ground truth spikes and the raw data
data_filename = data_path / 'sub-MEAREC-250neuron-Neuropixels_ecephys.nwb'
SX_gt = se.NwbSortingExtractor(str(data_filename))
RX = se.NwbRecordingExtractor(str(data_filename))

if not os.path.isdir(study_folder):
    gt_dict = {'rec0' : (RX, SX_gt) }
    study = GroundTruthStudy.create(study_folder, gt_dict)
else:
    study = GroundTruthStudy(study_folder)

sorter_list = ['herdingspikes', 'kilosort2', 'ironclust',
               'spykingcircus', 'tridesclous', 'hdsort']
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust',
               'SpykingCircus', 'Tridesclous', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'SC', 'TDC', 'HDS']

study.run_sorters(sorter_list, mode='keep', engine='loop', verbose=True)

study.copy_sortings()

# compute or load SNR for the ground truth units
snr_file = study_folder / 'snr.npy'
if os.path.isfile(snr_file):
    snr = np.load(snr_file)
else:
    print('computing snr')
    # note: this is quite slow for an NWB file, as the data are arranged as channels:time;
    # it is faster to first write out a binary file in time:channels order
    snr = st.validation.compute_snrs(SX_gt, RX, apply_filter=False, verbose=False, 
                                     memmap=True, max_spikes_per_unit_for_snr=500)
    np.save(snr_file, snr)
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/paper/MEArec/study_mearec_250cells_Neuropixels-384chans_duration600s_noise10uV_2020-02-28/sorter_folders/rec0/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
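The SNR used below is a per-unit quantity computed from the ground-truth spikes and the raw data. A sketch of one common definition, assuming SNR = peak absolute amplitude of the mean waveform divided by a robust noise estimate (median absolute deviation / 0.6745). spiketoolkit's compute_snrs may differ in details such as the channel choice, window lengths and the spike cap (max_spikes_per_unit_for_snr=500 above); unit_snr and its window sizes are hypothetical.

```python
import numpy as np


def unit_snr(traces, spike_frames, channel, before=20, after=40):
    """Hypothetical helper: peak-amplitude SNR of one unit on one channel.

    traces: (num_samples, num_channels) array; spike_frames: sample indices.
    The template is the mean waveform around the spikes; the noise level is
    the median absolute deviation of the channel divided by 0.6745, a robust
    estimate of the standard deviation for Gaussian noise.
    """
    x = traces[:, channel]
    waveforms = np.array([x[f - before:f + after] for f in spike_frames
                          if f - before >= 0 and f + after <= len(x)])
    template = waveforms.mean(axis=0)
    noise = np.median(np.abs(x - np.median(x))) / 0.6745
    return np.abs(template).max() / noise


rng = np.random.default_rng(2)
traces = rng.normal(0.0, 10.0, size=(60000, 2))
spike_frames = np.arange(1000, 59000, 1000)
for f in spike_frames:
    traces[f, 0] -= 100.0            # add a negative 100-unit peak on channel 0
print(round(unit_snr(traces, spike_frames, channel=0), 1))
```

With noise SD 10 and a 100-unit peak the result is close to 10; the channel without spikes gives a value well below 1.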

Run the ground truth comparison and summarise the results

In [3]:
study.run_comparisons(exhaustive_gt=True, match_score=0.1)
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
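The ground-truth comparison matches sorted spikes to ground-truth spikes within a small time tolerance and scores each unit by accuracy = TP/(TP+FN+FP), precision = TP/(TP+FP) and recall = TP/(TP+FN). As we understand spikecomparison, match_score is the minimum agreement score for pairing a ground-truth unit with a sorted unit. The sketch below uses a simple greedy matcher on sorted spike times for a single unit pair; the tolerance of 12 samples (0.4 ms at 30 kHz) is an assumption, and count_matches and unit_metrics are hypothetical helpers.

```python
import numpy as np


def count_matches(gt_frames, sorted_frames, delta):
    """Greedy one-to-one matching of spike times within +/- delta samples."""
    gt = np.sort(np.asarray(gt_frames))
    st = np.sort(np.asarray(sorted_frames))
    i = j = matches = 0
    while i < len(gt) and j < len(st):
        d = st[j] - gt[i]
        if abs(d) <= delta:      # close enough: count a match, consume both spikes
            matches += 1
            i += 1
            j += 1
        elif d < 0:              # sorted spike too early: it has no partner
            j += 1
        else:                    # ground-truth spike too early: it was missed
            i += 1
    return matches


def unit_metrics(gt_frames, sorted_frames, delta=12):
    tp = count_matches(gt_frames, sorted_frames, delta)
    fn = len(gt_frames) - tp
    fp = len(sorted_frames) - tp
    return {
        'accuracy': tp / (tp + fn + fp),
        'precision': tp / (tp + fp),
        'recall': tp / (tp + fn),
    }


gt = [100, 500, 900, 1300]
found = [102, 498, 1500]   # two hits, two missed spikes, one false positive
m = unit_metrics(gt, found)
print(m)
```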
In [4]:
# comparison summary
dataframes['count_units']
Out[4]:
   rec_name    sorter_name  num_gt  num_sorter  num_well_detected  num_redundant  num_overmerged  num_false_positive  num_bad
0      rec0      ironclust     250         283                202              8               2                  41       49
1      rec0  spykingcircus     250         343                175             43               1                  56       99
2      rec0    tridesclous     250         189                135              3               5                   9       12
3      rec0         hdsort     250         457                196              4               3                 210      214
4      rec0      kilosort2     250         415                245             21               2                 147      168
5      rec0  herdingspikes     250         233                128              6               3                  39       45
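The counts in this table are internally consistent: for every sorter, num_bad equals num_redundant + num_false_positive, so overmerged units are not counted as bad. A quick pandas check using the values copied from the table:

```python
import pandas as pd

# values copied from the count_units table above
counts = pd.DataFrame({
    'sorter_name': ['ironclust', 'spykingcircus', 'tridesclous', 'hdsort', 'kilosort2', 'herdingspikes'],
    'num_redundant': [8, 43, 3, 4, 21, 6],
    'num_false_positive': [41, 56, 9, 210, 147, 39],
    'num_bad': [49, 99, 12, 214, 168, 45],
})
consistent = (counts['num_redundant'] + counts['num_false_positive'] == counts['num_bad']).all()
print(consistent)
```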

Figure 2 - ground truth study results

In [5]:
# activity levels on the probe

plt.figure(figsize=(16,2))
ax = plt.subplot(111)
w = sw.plot_activity_map(RX, trange=(0,20), transpose=True, ax=ax, background='w', frame=True)
ax.plot((-1800,-1700), (-120,-120), 'k-')
ax.annotate('100$\\mu m$',(-1750,-220), ha='center');
In [6]:
# example data traces

plt.figure(figsize=(16,6))
ax = plt.subplot(111)
w = sw.plot_timeseries(RX, channel_ids=range(10,18), color='k', ax=ax, trange=(1,2))
ax.axis('off')
p = ax.get_position()
p.y0 = 0.58
ax.set_position(p)
ax.set_xticks(())
ax.plot((1.01,1.11),(-400,-400),'k-')
ax.annotate('100ms',(1.051,-750), ha='center');
ax.set_ylim((-750, ax.get_ylim()[1]))
Out[6]:
(-750, 4435.259765625)
In [7]:
ax = plt.subplot(111)
n = []
for s in sorter_list:
    n.append(len(study.get_sorting(s).get_unit_ids()))
ax.bar(range(len(sorter_list)), n, color='tab:blue')
ax.set_xticks(range(len(sorter_names_short)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
clear_axes(ax)
In [8]:
ax = plt.subplot(111)
p = ax.get_position()
p.x1=0.85
ax.set_position(p)
sns.set_palette(sns.color_palette("Set1"))
df = pd.melt(dataframes['perf_by_units'], id_vars='sorter_name', var_name='Metric', value_name='Score', 
        value_vars=('accuracy','precision', 'recall'))
sns.swarmplot(data=df, x='sorter_name', y='Score', hue='Metric', dodge=True,
                order=sorter_list,  s=3, ax=ax)
ax.set_xticklabels(sorter_names_short, rotation=30, ha='center')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5)
ax.set_xlabel(None);
ax.set_ylabel('Score');
clear_axes(ax)
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/matplotlib/axes/_axes.py:4204: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  (isinstance(c, collections.Iterable) and
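seaborn's hue argument expects long-form data, which is why the cell above melts the wide per-unit table: each row of the result holds one (sorter, metric, score) triple. A tiny stand-in for dataframes['perf_by_units'] with made-up values shows the reshaping:

```python
import pandas as pd

# made-up stand-in for dataframes['perf_by_units']: one row per ground-truth unit
perf = pd.DataFrame({
    'sorter_name': ['kilosort2', 'kilosort2', 'ironclust'],
    'accuracy': [0.95, 0.80, 0.90],
    'precision': [0.97, 0.85, 0.92],
    'recall': [0.98, 0.93, 0.97],
})
# wide (one column per metric) to long (one row per unit and metric)
long_df = pd.melt(perf, id_vars='sorter_name', var_name='Metric', value_name='Score',
                  value_vars=['accuracy', 'precision', 'recall'])
print(long_df)
```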
In [9]:
ax = plt.subplot(111)
p = ax.get_position()
p.x1=0.85
ax.set_position(p)
df = pd.melt(dataframes['count_units'], id_vars='sorter_name', var_name='Type', value_name='Units', 
        value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
sns.set_palette(sns.color_palette("Set1"))
sns.barplot(x='sorter_name', y='Units', hue='Type', data=df,
                order=sorter_list, ax=ax)
ax.set_xticklabels(sorter_names_short, rotation=30, ha='right')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.1)
for t, l in zip(ax.legend_.texts, ("Well detected", "False positive", "Redundant", "Overmerged")): t.set_text(l)
ax.set_xlabel(None);
clear_axes(ax)
In [10]:
# precision vs. recall and accuracy vs. SNR
fig = plt.figure(figsize=(14, 4))

sns.set_palette(sns.color_palette("deep"))

axesA = []
for i,s in enumerate(sorter_list):
    ax = plt.subplot(2,len(sorter_list),i+1)
    axesA.append(ax)

    g = sns.scatterplot(data=dataframes['perf_by_units'].loc[dataframes['perf_by_units'].sorter_name==s], 
                    x='precision', y='recall', s=30, edgecolor=None, alpha=0.1)
    ax.set_title(sorter_names[i])
    ax.set_aspect('equal')
    clear_axes(ax)
    ax.set_xlabel('Precision')
    ax.set_ylabel('Recall')

for ax in axesA[1:]:
    axesA[0].get_shared_y_axes().join(axesA[0], ax)
    ax.set_yticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.autoscale()
    
############### B

df = dataframes['perf_by_units']

# add snr to the by-unit table
if 'snr' not in df.keys():
    # snr is ordered like SX_gt.get_unit_ids(), so pair values with unit ids by position
    snr_d = dict(zip(SX_gt.get_unit_ids(), snr))
    df['snr'] = df['gt_unit_id'].map(snr_d)
    

axesB = []
for i,s in enumerate(sorter_list):
    ax = plt.subplot(2,len(sorter_list),len(sorter_list)+i+1)
    axesB.append(ax)
    
    g = sns.scatterplot(data=dataframes['perf_by_units'].loc[dataframes['perf_by_units'].sorter_name==s], 
                        x='snr', y='accuracy', s=30, alpha=0.2)
    clear_axes(ax)
    ax.set_xlabel('Ground truth SNR')
    ax.set_ylabel('Accuracy')
    
for ax in axesB[1:]:
    axesB[0].get_shared_y_axes().join(axesB[0], ax)
    ax.set_yticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.autoscale()
    

Figure 3 - comparison of sorter outputs and ensemble sorting

In [11]:
# perform an all-to-all multicomparison or load it from
# disk if file exists

sortings = []
for s in sorter_list:
    sortings.append(study.get_sorting(s))

cmp_folder = study_folder / 'multicomparison/'
if not os.path.isdir(cmp_folder):
    os.mkdir(cmp_folder)
if not os.path.isfile(cmp_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                   verbose=False, match_score=0.5)
    print('saving multicomparison')
    mcmp.dump(cmp_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(cmp_folder)
    
mcmp_graph = mcmp.graph.copy()
loading multicomparison
In [12]:
# get one sorting extractor with the units found by a single sorter only (minimum agreement = 1)
# and one with the units on which at least 2 sorters agree
not_in_agreement = mcmp.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
in_agreement = mcmp.get_agreement_sorting(minimum_agreement_count=2)

# score these against ground truth
cmp_no_agr = sc.compare_sorter_to_ground_truth(SX_gt, not_in_agreement)
cmp_agr = sc.compare_sorter_to_ground_truth(SX_gt, in_agreement)

# now collect results for each sorter:

# create dict to collect results
results = {'TP':{}, 'FP':{}, 'SNR':{}}
ns = len(sorter_names_short)
for s in sorter_names_short:
    results['TP'][s] = {k: 0 for k in range(1, ns+1)}
    results['FP'][s] = {k: 0 for k in range(1, ns+1)}
    # give each entry its own list: [[]]*(ns+1) would make all entries share one list object
    results['SNR'][s] = {k: [] for k in range(1, ns+1)}
    
# sorter names
dict_names = dict(zip(sorter_names_short, sorter_list))

# iterate over all units gathered from subgraphs
for u in mcmp._new_units.keys():
    found_in_gt = []
    gt_index = []
    # check if units have a match in ground truth, store boolean
    for u2 in mcmp._new_units[u]['sorter_unit_ids'].items():
        found_in_gt.append(u2[1] in study.comparisons['rec0',dict_names[u2[0]]].best_match_12.values)
        if found_in_gt[-1]:
            gt_index.append(np.where(study.comparisons['rec0',dict_names[u2[0]]].best_match_12==u2[1])[0][0])
    if len(set(gt_index))>1:
        print('different gt units: ',u, gt_index)
    if np.sum(found_in_gt)==len(found_in_gt):
#     if np.sum(found_in_gt)>0:#==len(found_in_gt):  # use this if interested in equal matches
        key = 'TP'
    else:
        key = 'FP'
        if len(found_in_gt)>1:
            print('FP unit found by >1 sorter: ',u)
        
    for i,u2 in enumerate(mcmp._new_units[u]['sorter_unit_ids'].items()):
#         results[key][u2[0]][np.sum(found_in_gt)] += 1 # use this if interested in equal matches
        results[key][u2[0]][len(found_in_gt)] += 1
        if key == 'TP':
            # each SNR entry has its own list, so appending in place is safe
            results['SNR'][u2[0]][len(found_in_gt)].append(snr[gt_index[i]])
different gt units:  20 [213, 213, 213, 146, 213, 213]
different gt units:  35 [189, 189, 185, 189, 189]
different gt units:  42 [224, 224, 224, 76, 224, 224]
different gt units:  46 [108, 108, 108, 90, 108, 108]
FP unit found by >1 sorter:  92
different gt units:  102 [153, 197, 197, 197, 197, 197]
different gt units:  103 [175, 175, 175, 11, 175, 175]
FP unit found by >1 sorter:  111
different gt units:  114 [157, 157, 107, 157, 157, 157]
different gt units:  149 [78, 99, 99, 99, 99]
FP unit found by >1 sorter:  150
FP unit found by >1 sorter:  158
different gt units:  162 [179, 196, 196, 196, 196, 196]
different gt units:  284 [185, 185, 185, 177, 185]
different gt units:  316 [90, 90, 90, 57, 90]
FP unit found by >1 sorter:  418
FP unit found by >1 sorter:  549
different gt units:  673 [129, 182]
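The cell above selects agreement units by how many sorters found them, and labels a unit a true positive only if every constituent sorter unit has a ground-truth match (np.sum(found_in_gt)==len(found_in_gt)). A sketch of both rules on hypothetical agreement units; select_units mirrors what the parameter names minimum_agreement_count and minimum_agreement_count_only suggest and is not the spikecomparison implementation.

```python
# hypothetical agreement units: each maps sorter name -> unit id in that sorter
agreement_units = {
    0: {'HS': 1, 'KS': 7, 'IC': 3},
    1: {'HDS': 9},
    2: {'KS': 8, 'SC': 4},
}


def select_units(units, minimum_agreement_count, only=False):
    """Keep units found by >= minimum_agreement_count sorters (exactly that many if only=True)."""
    selected = {}
    for u, sorter_units in units.items():
        n = len(sorter_units)
        keep = n == minimum_agreement_count if only else n >= minimum_agreement_count
        if keep:
            selected[u] = sorter_units
    return selected


not_in_agreement = select_units(agreement_units, 1, only=True)  # found by a single sorter
in_agreement = select_units(agreement_units, 2)                  # found by at least two sorters

# a unit is TP only if every constituent sorter unit matches a ground-truth unit
matched = {('HS', 1), ('KS', 7), ('IC', 3), ('KS', 8)}  # hypothetical matched (sorter, unit) pairs
labels = {u: 'TP' if all((s, uid) in matched for s, uid in d.items()) else 'FP'
          for u, d in agreement_units.items()}
print(labels)
```

Under this strict rule, unit 2 counts as a false positive even though one of its two constituent units matches ground truth; the commented-out line in the cell above shows the looser alternative.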
In [13]:
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie')
In [14]:
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True)
In [15]:
fig = plt.figure(figsize=(14,4))

axes = []
for i,s in enumerate(results['TP'].keys()):
    ax = plt.subplot(2,len(sorter_list), i+1)
    ax.bar(results['FP'][s].keys(), list(results['FP'][s].values()), alpha=0.5, width = 0.6, color='r', label='false positive')
    ax.bar(results['TP'][s].keys(), list(results['TP'][s].values()), bottom=list(results['FP'][s].values()), alpha=0.5, width = 0.6, color='b', label='matched')
    ax.set_xticks(range(1,len(sorter_list)+1))
    ax.set_xticklabels(range(1,len(sorter_list)+1))
    ax.set_title(s)
    clear_axes(ax)
    axes.append(ax)
    if i == 0:
        ax.set_ylabel('Number of units')
    else:
        ax.get_shared_y_axes().join(axes[0], ax)
        ax.set_yticklabels([])
    
    ax = plt.subplot(2,len(sorter_list), len(sorter_list)+i+1)
    d  = results['SNR'][s]
    sns.boxenplot(data=pd.DataFrame([pd.Series(d[k]) for k in d.keys()]).T, color='b', ax=ax)
    ax.set_xticks(range(0,len(sorter_list)))
    ax.set_xticklabels(range(1,len(sorter_list)+1))
    clear_axes(ax)
    axes.append(ax)
    if i == 0:
        ax.set_ylabel('Ground truth SNR')
        ax.set_xlabel('Found by # sorters')
    else:
        ax.get_shared_y_axes().join(axes[1], ax)
        ax.set_yticklabels([])
    
In [16]:
# numbers for figure above

sg_names, sg_units = mcmp.compute_subgraphs()
v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True)
df = pd.DataFrame(np.vstack((v,c,np.round(100*c/np.sum(c),2))).T,
             columns=('in # sorters','# units','percentage'))
print('all sorters, all units:')
print(df)
df = pd.DataFrame()
for i, name in enumerate(sorter_names_short):
    v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True)
    cl = np.zeros(len(sorter_list), dtype=int)
    cl[v.astype(int)-1] = c
    df.insert(2*i,name,cl)
    df.insert(2*i+1,name+'%',np.round(100*cl/np.sum(cl),1))
print('\nper sorter:')
print(df)

for i,s in enumerate(results['TP'].keys()):
    print(s, list(results['FP'][s].values()))
all sorters, all units:
   in # sorters  # units  percentage
0           1.0    659.0       72.34
1           2.0     18.0        1.98
2           3.0     27.0        2.96
3           4.0     30.0        3.29
4           5.0     38.0        4.17
5           6.0    139.0       15.26

per sorter:
    HS   HS%   KS   KS%   IC   IC%   SC   SC%  TDC  TDC%  HDS  HDS%
0   68  29.2  168  40.5   50  17.7  129  37.6   18   9.5  226  49.5
1    2   0.9   13   3.1    6   2.1    5   1.5    2   1.1    8   1.8
2    2   0.9   27   6.5   21   7.4    7   2.0    0   0.0   24   5.3
3    2   0.9   30   7.2   30  10.6   28   8.2    2   1.1   28   6.1
4   20   8.6   38   9.2   37  13.1   35  10.2   28  14.8   32   7.0
5  139  59.7  139  33.5  139  49.1  139  40.5  139  73.5  139  30.4
HS [42, 1, 2, 0, 1, 0]
KS [164, 1, 2, 0, 1, 0]
IC [45, 2, 2, 0, 0, 0]
SC [97, 1, 0, 0, 1, 0]
TDC [11, 0, 0, 0, 1, 0]
HDS [214, 1, 0, 0, 1, 0]

Ensemble sorting of a Neuropixels recording

This notebook reproduces figures 1 and 4 from the paper SpikeInterface, a unified framework for spike sorting.

The data set for this notebook is available on the Dandi Archive: https://gui.dandiarchive.org/#/dandiset/000034

The entire data archive can be downloaded with the command dandi download https://gui.dandiarchive.org/#/dandiset/000034/draft (about 75GB).

Files required to run the code are:

These files should be in the same directory where the notebook is located (otherwise adjust paths below).

Author: Matthias Hennig, University of Edinburgh, 24 Aug 2020

Requirements

For this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface
  • dandi
  • matplotlib-venn

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os

# Matlab sorter paths:
# change these to match your environment
os.environ["IRONCLUST_PATH"] = "./ironclust"
os.environ["KILOSORT2_PATH"] = "./Kilosort2"
os.environ["HDSORT_PATH"] = "./HDsort"

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from collections import defaultdict
from matplotlib_venn import venn3

import spikeinterface as si
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from spikecomparison import GroundTruthStudy, MultiSortingComparison

%matplotlib inline

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
# print version information
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

hdsort: version = '1.0.2'

herdingspikes: 0.3.7
ironclust: 5.9.8
kilosort2: git-48bf2b81d8ad
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.7
tridesclous: 1.6.0
In [2]:
# where to find the data set
data_file = Path('./') / 'mouse412804_probeC_15min.nwb'
# results are stored here
study_path = Path('./')
# this folder will contain all results
study_folder = study_path / 'study_15min/'
# this folder will be used as temporary space, and hold the sortings etc.
working_folder = study_path / 'working_15min'

# sorters to use
sorter_list = ['herdingspikes', 'kilosort2', 'ironclust', 'tridesclous', 'spykingcircus', 'hdsort']
# pass the following parameters to the sorters
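sorter_params = {
#     'kilosort2': {'keep_good_only': True}, # uncomment this to test the native filter for false positives
    # keys must match the names in sorter_list ('spykingcircus', not 'spyking_circus'),
    # otherwise the parameters are not passed to that sorter
    'spykingcircus': {'adjacency_radius': 50},
    'herdingspikes': {'filter': True, }}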
sorter_names = ['HerdingSpikes', 'Kilosort2', 'Ironclust','Tridesclous', 'SpykingCircus', 'HDSort']
sorter_names_short = ['HS', 'KS', 'IC', 'TDC', 'SC', 'HDS']
In [3]:
# create an extractor object for the raw data 
recording = se.NwbRecordingExtractor(str(data_file))

print("Number of frames: {}\nSampling rate: {}Hz\nNumber of channels: {}".format(
    recording.get_num_frames(), recording.get_sampling_frequency(), 
    recording.get_num_channels()))
Number of frames: 27000000
Sampling rate: 30000.0Hz
Number of channels: 248

Run spike sorters and perform comparison between all outputs

In [4]:
# set up the study environment and run all sorters
# sorters are not re-run if outputs are found in working_folder 

if not study_folder.is_dir():
    print('Setting up study folder:', study_folder)
    os.mkdir(study_folder)

# run all sorters
result_dict = ss.run_sorters(sorter_list=sorter_list, recording_dict_or_list={'rec': recording}, with_output=True,
                             sorter_params=sorter_params, working_folder=working_folder, engine='loop', 
                             mode='keep', verbose=True)

# store sortings in a list for quick access
sortings = []
for s in sorter_list:
    sortings.append(result_dict['rec',s])
/disk/scratch/mhennig/miniconda3/envs/spikesorting/lib/python3.7/site-packages/spikeextractors/extractors/matsortingextractor/matsortingextractor.py:65: ResourceWarning: unclosed file <_io.BufferedReader name='/disk/scratch/mhennig/spikeinterface/Neuropixels_Allen/working_15min/rec/hdsort/hdsort_output/hdsort_output_results.mat'>
  raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.")
ResourceWarning: Enable tracemalloc to get the object allocation traceback
In [5]:
# perform a multi-comparison, all to all sortings
# result is stored, and loaded from disk if the file is found

if not os.path.isfile(study_folder / 'multicomparison.gpickle'):
    mcmp = sc.compare_multiple_sorters(sorting_list=sortings, name_list=sorter_names_short, 
                                       verbose=True)
    print('saving multicomparison')
    mcmp.dump(study_folder)
else:
    print('loading multicomparison')
    mcmp = sc.MultiSortingComparison.load_multicomparison(study_folder)
loading multicomparison
Multicomaprison step 3: clean graph
Removed edge ('IC_137', 'HS_49', {'weight': 0.552943287867911})
Removed edge ('KS_185', 'HDS_18005', {'weight': 0.6143539400371452})
Removed edge ('KS_295', 'HS_159', {'weight': 0.6111111111111112})
Removed edge ('KS_195', 'TDC_8', {'weight': 0.6909090909090909})
Removed 4 duplicate nodes
Multicomaprison step 4: extract agreement from graph

Figure 1 - comparison of sorter outputs

In [6]:
# activity levels on the probe

plt.figure(figsize=(16,2))
ax = plt.subplot(111)
w = sw.plot_activity_map(recording, trange=(0,20), transpose=True, ax=ax, background='w', frame=True)
ax.plot((50,150),(-30,-30),'k-')
ax.annotate('100$\\mu m$',(100,-90), ha='center');
In [7]:
# example data traces

plt.figure(figsize=(16,6))
ax = plt.subplot(111)
w = sw.plot_timeseries(recording, channel_ids=range(20,28), color='k', ax=ax, trange=(1,2))
ax.axis('off')
p = ax.get_position()
p.y0 = 0.55
ax.set_position(p)
ax.set_xticks(())
ax.plot((1.01,1.11),(-1790,-1790),'k-')
ax.annotate('100ms',(1.051,-2900), ha='center');
ax.set_ylim((-2900,ax.get_ylim()[1]))
Out[7]:
(-2900, 19056.0)
In [8]:
ax = plt.subplot(111)
ax.bar(range(len(sortings)), [len(s.get_unit_ids()) for s in sortings], color='tab:blue')
ax.set_xticks(range(len(sorter_names)))
ax.set_xticklabels(sorter_names_short, rotation=60, ha='center')
ax.set_ylabel('Units detected') 
clear_axes(ax)
In [9]:
w = sw.plot_multicomp_agreement(mcmp, plot_type='pie')
In [10]:
w = sw.plot_multicomp_agreement_by_sorter(mcmp, show_legend=True)
In [11]:
# numbers for figure above

print('number of units detected:')
for i,s in enumerate(sortings):
    print("{}: {}".format(sorter_names[i],len(s.get_unit_ids())))

sg_names, sg_units = mcmp.compute_subgraphs()
v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True)
df = pd.DataFrame(np.vstack((v,c,np.round(100*c/np.sum(c),2))).T,
             columns=('in # sorters','# units','percentage'))
print('\nall sorters, all units:')
print(df)
df = pd.DataFrame()
for i, name in enumerate(sorter_names_short):
    v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True)
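    # pad with zeros so that agreement levels without any units still get a row
    cl = np.zeros(len(sorter_list), dtype=int)
    cl[v.astype(int)-1] = c
    df.insert(2*i,name,cl)
    df.insert(2*i+1,name+'%',np.round(100*cl/np.sum(cl),1))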
print('\nper sorter:')
print(df)
number of units detected:
HerdingSpikes: 210
Kilosort2: 448
Ironclust: 234
Tridesclous: 187
SpykingCircus: 628
HDSort: 317

all sorters, all units:
   in # sorters  # units  percentage
0           1.0   1093.0       80.60
1           2.0     89.0        6.56
2           3.0     52.0        3.83
3           4.0     49.0        3.61
4           5.0     40.0        2.95
5           6.0     33.0        2.43

per sorter:
    HS   HS%   KS   KS%  IC   IC%  TDC  TDC%   SC   SC%  HDS  HDS%
0  139  66.2  226  50.7  66  28.3   82  43.9  396  63.1  184  58.0
1    4   1.9   57  12.8  15   6.4    8   4.3   67  10.7   27   8.5
2    6   2.9   44   9.9  34  14.6    5   2.7   47   7.5   20   6.3
3    8   3.8   46  10.3  45  19.3   21  11.2   46   7.3   30   9.5
4   20   9.5   40   9.0  40  17.2   38  20.3   39   6.2   23   7.3
5   33  15.7   33   7.4  33  14.2   33  17.6   33   5.3   33  10.4
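The "in # sorters" counts above come from grouping matched units across sorters and counting how many distinct sorters end up in each group. The hypothetical helper below is a rough sketch of that idea, not the spikecomparison implementation. The library additionally cleans the agreement graph, as the "Removed edge" and "Removed 4 duplicate nodes" log lines of In [5] show. Here, matched units are joined into connected components with a small union-find, and the match pairs are made-up illustration data.

```python
from collections import Counter


def agreement_levels(all_units, matches):
    """Histogram of 'found by N sorters' over connected components of matched units.

    all_units: list of (sorter_name, unit_id) nodes, including unmatched units
    matches:   list of ((sorter, unit), (sorter, unit)) pairs judged to match
    """
    parent = {u: u for u in all_units}

    def find(u):
        # walk up to the component root, halving the path as we go
        while parent[u] != u:
            parent[u] = parent[parent[u]]
            u = parent[u]
        return u

    for a, b in matches:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb  # merge the two components

    sorters_per_component = {}
    for u in all_units:
        sorters_per_component.setdefault(find(u), set()).add(u[0])  # keep sorter names only

    # number of distinct sorters in a component -> number of such components
    return Counter(len(s) for s in sorters_per_component.values())


# hypothetical example: one unit found by 3 sorters, one by 2, one by a single sorter
units = [('HS', 1), ('HS', 2), ('KS', 1), ('KS', 2), ('IC', 1), ('IC', 2)]
pairs = [(('HS', 1), ('KS', 1)), (('KS', 1), ('IC', 1)), (('KS', 2), ('IC', 2))]
levels = agreement_levels(units, pairs)
print(sorted(levels.items()))  # [(1, 1), (2, 1), (3, 1)]
```

Counting distinct sorter names, rather than nodes, mirrors the `len(np.unique(s))` used on the subgraph names in the cell above.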

Supplemental Figure - example unit templates

In [12]:
# show unit templates and spike trains for two units/all sorters

sorting = mcmp.get_agreement_sorting(minimum_agreement_count=6)

get_sorting = lambda u: [mcmp.sorting_list[i] for i,n in enumerate(mcmp.name_list) if n==u[0]][0]
get_spikes = lambda u: [mcmp.sorting_list[i].get_unit_spike_train(u[1]) for i,n in enumerate(mcmp.name_list) if n==u[0]][0]


# one well matched and one not so well matched unit, all sorters
show_units = [2,17]

for i,unit in enumerate(show_units):
    fig = plt.figure(figsize=(16, 2))
    ax = plt.subplot(111)
    ax.set_title('Average agreement: {:.2f}'.format(sorting.get_unit_property(sorting.get_unit_ids()[unit],'avg_agreement')))
    units = sorting.get_unit_property(sorting.get_unit_ids()[unit], 'sorter_unit_ids')
    cols = plt.cm.Accent(np.arange(len(units))/len(units))
    for j,u in enumerate(dict(sorted(units.items())).items()):
        s = get_sorting(u).get_units_spike_train((u[1],))[0]
        s = s[s<20*get_sorting(u).get_sampling_frequency()]
        ax.plot(s/get_sorting(u).get_sampling_frequency(), np.ones(len(s))*j, '|', color=cols[j], label=u[0])
    ax.set_frame_on(False)
    ax.set_xticks(())
    ax.set_yticks(())
    ax.plot((0,1),(-1,-1),'k')
    ax.annotate('1s',(0.5,-1.75), ha='center')
    ax.set_ylim((-2,len(units)+1))
    
    fig = plt.figure(figsize=(16, 2))
    units = sorting.get_unit_property(sorting.get_unit_ids()[unit], 'sorter_unit_ids')
    print(units)
    print('Agreement: {}'.format(sorting.get_unit_property(sorting.get_unit_ids()[unit],'avg_agreement')))
    cols = plt.cm.Accent(np.arange(len(units))/len(units))
    for j,u in enumerate(dict(sorted(units.items())).items()):
        ax = plt.subplot(1, len(sorter_list), j+1)
        w = sw.plot_unit_templates(recording, get_sorting(u), unit_ids=(u[1],), max_spikes_per_unit=10, 
                               channel_locs=True, radius=75, show_all_channels=False, color=[cols[j]], 
                               lw=1.5, ax=ax, plot_channels=False, set_title=False, axis_equal=True) 
        # was 100 spikes in original plot
        ax.set_title(u[0])
{'SC': 563, 'TDC': 3, 'HDS': 44005, 'HS': 5, 'KS': 116, 'IC': 88}
Agreement: 0.9661710270732794
{'HS': 49, 'TDC': 15, 'KS': 189, 'HDS': 18005, 'IC': 137, 'SC': 272}
Agreement: 0.6920718057618236
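The agreement values printed above are averages of pairwise agreement scores between the matched units. The following is a minimal sketch that assumes the score is defined as n_match / (n1 + n2 - n_match). Here n1 and n2 are the spike counts of the two trains, and n_match counts spikes that coincide within a small tolerance. The greedy one-to-one matching, the helper names, and the 12-frame tolerance (0.4 ms at 30 kHz) are illustrative assumptions, not the library code.

```python
import numpy as np


def count_matches(st1, st2, delta_frames):
    """Greedy one-to-one matching of two sorted spike trains given in frames.

    Two spikes match if they are at most delta_frames apart; each spike is used once.
    """
    i = j = n_match = 0
    while i < len(st1) and j < len(st2):
        if abs(st1[i] - st2[j]) <= delta_frames:
            n_match += 1
            i += 1
            j += 1
        elif st1[i] < st2[j]:
            i += 1  # st1[i] is too early to match st2[j] or any later spike
        else:
            j += 1  # st2[j] is too early to match st1[i] or any later spike
    return n_match


def agreement_score(st1, st2, delta_frames):
    """Matched spikes divided by the size of the union of both spike trains."""
    n_match = count_matches(st1, st2, delta_frames)
    return n_match / (len(st1) + len(st2) - n_match)


# hypothetical spike trains in frames; 12 frames = 0.4 ms at 30 kHz (assumed tolerance)
a = np.array([100, 200, 300, 400])
b = np.array([102, 205, 390, 500])
print(count_matches(a, b, 12), agreement_score(a, b, 12))  # 3 0.6
```

For sorted trains, matching the earliest compatible pair first is enough to obtain the largest one-to-one matching within a fixed tolerance window, which is why a single pass with two pointers suffices here.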

Figure 4 - comparison between ensemble sortings and curated data

In [13]:
# perform a comparison with curated sortings (KS2)

curated1 = se.NwbSortingExtractor('sub-mouse412804_ses-20200824T155542.nwb', sampling_frequency=30000)
curated2 = se.NwbSortingExtractor('sub-mouse412804_ses-20200824T155543.nwb', sampling_frequency=30000)

comparison_curated = sc.compare_two_sorters(curated1, curated2)
comparison_curated_ks = sc.compare_multiple_sorters((curated1, curated2, sortings[sorter_list.index('kilosort2')]))

# consensus sortings (units where at least 2 sorters agree)
sorting = mcmp.get_agreement_sorting(minimum_agreement_count=2)
consensus_sortings = []
units_dict = defaultdict(list)
units = [sorting.get_unit_property(u,'sorter_unit_ids') for u in sorting.get_unit_ids()]
for au in units:
    for u in au.items():
        units_dict[u[0]].append(u[1])
for i,s in enumerate(sorter_names_short):
    consensus_sortings.append(se.SubSortingExtractor(sortings[i], unit_ids=units_dict[s]))

# orphan units (units found by only one sorter)
sorting = mcmp.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
unmatched_sortings = []
units_dict = defaultdict(list)
units = [sorting.get_unit_property(u,'sorter_unit_ids') for u in sorting.get_unit_ids()]
for au in units:
    for u in au.items():
        units_dict[u[0]].append(u[1])
for i,s in enumerate(sorter_names_short):
    unmatched_sortings.append(se.SubSortingExtractor(sortings[i], unit_ids=units_dict[s]))
    
consensus_curated_comparisons = []
for s in consensus_sortings:
    consensus_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    consensus_curated_comparisons.append(sc.compare_two_sorters(s, curated2))    
    
unmatched_curated_comparisons = []
for s in unmatched_sortings:
    unmatched_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    unmatched_curated_comparisons.append(sc.compare_two_sorters(s, curated2))

all_curated_comparisons = []
for s in sortings:
    all_curated_comparisons.append(sc.compare_two_sorters(s, curated1))
    all_curated_comparisons.append(sc.compare_two_sorters(s, curated2))
    
# count various types of units

count_mapped = lambda x : np.sum([u!=-1 for u in x.get_mapped_unit_ids()])
count_not_mapped = lambda x : np.sum([u==-1 for u in x.get_mapped_unit_ids()])
count_units = lambda x : len(x.get_unit_ids())

n_consensus_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in consensus_curated_comparisons]).reshape((len(sorter_list),2))
n_consensus_curated_unmapped = np.array([count_not_mapped(c.get_mapped_sorting1()) for c in consensus_curated_comparisons]).reshape((len(sorter_list),2))
n_unmatched_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in unmatched_curated_comparisons]).reshape((len(sorter_list),2))
n_all_curated_mapped = np.array([count_mapped(c.get_mapped_sorting1()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_all_curated_unmapped = np.array([count_not_mapped(c.get_mapped_sorting1()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_curated_all_unmapped = np.array([count_not_mapped(c.get_mapped_sorting2()) for c in all_curated_comparisons]).reshape((len(sorter_list),2))
n_all = np.array([count_units(s) for s in sortings])
n_consensus = np.array([count_units(s) for s in consensus_sortings])
n_unmatched = np.array([count_units(s) for s in unmatched_sortings])
n_curated1 = len(curated1.get_unit_ids())
n_curated2 = len(curated2.get_unit_ids())
In [14]:
# overlap between the two manually curated datasets and the Kilosort2 sorting they were derived from

i = {}
for k in ['{0:03b}'.format(v) for v in range(1,2**3)]:
    i[k] = 0
i['111'] = len(comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids())
s = comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=2, minimum_agreement_count_only=True)
units = [s.get_unit_property(u,'sorter_unit_ids').keys() for u in s.get_unit_ids()]
for u in units:
    if 'sorting1' in u and 'sorting2' in u:
        i['110'] += 1
    if 'sorting1' in u and 'sorting3' in u:
        i['101'] += 1
    if 'sorting2' in u and 'sorting3' in u:
        i['011'] += 1   
s = comparison_curated_ks.get_agreement_sorting(minimum_agreement_count=1, minimum_agreement_count_only=True)
units = [s.get_unit_property(u,'sorter_unit_ids').keys() for u in s.get_unit_ids()]
for u in units:
    if 'sorting1' in u:
        i['100'] += 1
    if 'sorting2' in u:
        i['010'] += 1
    if 'sorting3' in u:
        i['001'] += 1   
colors = plt.cm.RdYlBu(np.linspace(0,1,3))
venn3(subsets = i,set_labels=('Curated 1', 'Curated 2', 'Kilosort2'), 
      set_colors=colors, alpha=0.6, normalize_to=100)
Out[14]:
<matplotlib_venn._common.VennDiagram at 0x7fadfef520b8>
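The dictionary `i` passed to `venn3` uses three-character keys such as '110', with one digit per set. The cell above fills it by querying agreement sortings. The sketch below shows the same bookkeeping from another angle. It assumes we already know, for each matched unit, the set of sortings it was found in. The helper name and the example groups are hypothetical.

```python
from collections import Counter


def venn3_subsets(groups, names=('sorting1', 'sorting2', 'sorting3')):
    """Count unit groups by their exact membership pattern across three sortings.

    Each group is the set of sorting names in which one matched unit was found.
    Keys follow the matplotlib_venn convention: '100' = first set only,
    '110' = first and second but not third, '111' = all three, and so on.
    """
    patterns = Counter(''.join('1' if n in g else '0' for n in names) for g in groups)
    return {f'{v:03b}': patterns.get(f'{v:03b}', 0) for v in range(1, 8)}


# hypothetical groups: 5 units in all three sets, 2 in the two curated sets only,
# and 4 found only in the third sorting
groups = ([{'sorting1', 'sorting2', 'sorting3'}] * 5
          + [{'sorting1', 'sorting2'}] * 2
          + [{'sorting3'}] * 4)
subsets = venn3_subsets(groups)
print(subsets)  # {'001': 4, '010': 0, '011': 0, '100': 0, '101': 0, '110': 2, '111': 5}
```

The resulting dictionary has the same shape as `i` above, so it could be passed directly as `venn3(subsets=subsets, ...)`.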
In [15]:
# overlaps between ensemble sortings (per sorter) and manually curated sortings

def plot_mcmp_results(data, labels, ax, ylim=None, yticks=None, legend=False):
    angles = (np.linspace(0, 2*np.pi, len(sorter_list), endpoint=False)).tolist()
    angles += angles[:1]
    for i,v in enumerate(data):
        v = v.tolist() + v[:1].tolist()
        ax.bar(np.array(angles)+i*2*np.pi/len(sorter_list)/len(data)/2-2*np.pi/len(sorter_list)/len(data)/4, 
               v, label=labels[i], 
               alpha=0.8, width=np.pi/len(sorter_list)/2)
        
    ax.set_thetagrids(np.degrees(angles), sorter_names_short)
    if legend:
        ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.25)
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    if ylim is not None:
        ax.set_ylim(ylim)
    if yticks is not None:
        ax.set_yticks(yticks)

plt.figure(figsize=(14,3))
sns.set_palette(sns.color_palette("Set1"))
ax = plt.subplot(131, projection='polar')
plot_mcmp_results((n_all_curated_mapped[:,0]/n_all*100, 
                   n_all_curated_mapped[:,1]/n_all*100), 
                  ('Curated 1','Curated 2'), ax, yticks=np.arange(20,101,20))
ax.set_title('Percent all units\nwith match in curated sets',pad=20);
plt.ylim((0,100))

ax = plt.subplot(132, projection='polar')
plot_mcmp_results((n_consensus_curated_mapped[:,0]/n_consensus*100, 
                   n_consensus_curated_mapped[:,1]/n_consensus*100), 
                  ('Curated 1','Curated 2'), ax, yticks=np.arange(20,101,20))
ax.set_title('Percent consensus units\nwith match in curated sets',pad=20);
plt.ylim((0,100))

ax = plt.subplot(133, projection='polar')
plot_mcmp_results((n_unmatched_curated_mapped[:,0]/n_unmatched*100, 
                   n_unmatched_curated_mapped[:,1]/n_unmatched*100), 
                  ('Curated 1','Curated 2'), ax, ylim=(0,30), yticks=np.arange(10,21,10), legend=True)
ax.set_title('Percent non-consensus units\nwith match in curated sets',pad=20);
In [16]:
# numbers for figure above

df = pd.DataFrame(np.vstack((n_all_curated_mapped[:,0]/n_all*100, n_all_curated_mapped[:,1]/n_all*100,
                           n_all_curated_mapped[:,0], n_all_curated_mapped[:,1])).T,
                  columns = ('C1 %', 'C2 %', 'C1', 'C2'), index=sorter_names_short)
print('Percent all units with match in curated sets')
print(df)

df = pd.DataFrame(np.vstack((n_consensus_curated_mapped[:,0]/n_consensus*100, n_consensus_curated_mapped[:,1]/n_consensus*100,
                           n_consensus_curated_mapped[:,0],n_consensus_curated_mapped[:,1])).T,
                  columns = ('C1 %', 'C2 %', 'C1', 'C2'), index=sorter_names_short)
print('\nPercent consensus units with match in curated sets')
print(df)

df = pd.DataFrame(np.vstack((n_unmatched_curated_mapped[:,0]/n_unmatched*100,
                             n_unmatched_curated_mapped[:,1]/n_unmatched*100,
                           n_unmatched_curated_mapped[:,0],n_unmatched_curated_mapped[:,1])).T,
                  columns = ('C1 %', 'C2 %', 'C1', 'C2'), index=sorter_names_short)
print('\nPercent non-consensus units with match in curated sets')
print(df)
Percent all units with match in curated sets
          C1 %       C2 %     C1     C2
HS   25.714286  26.190476   54.0   55.0
KS   50.669643  56.250000  227.0  252.0
IC   59.829060  61.111111  140.0  143.0
TDC  42.245989  45.454545   79.0   85.0
SC   27.547771  27.547771  173.0  173.0
HDS  30.283912  29.968454   96.0   95.0

Percent consensus units with match in curated sets
          C1 %       C2 %     C1     C2
HS   76.056338  76.056338   54.0   54.0
KS   84.545455  89.545455  186.0  197.0
IC   82.634731  83.832335  138.0  140.0
TDC  75.238095  80.000000   79.0   84.0
SC   72.844828  73.275862  169.0  170.0
HDS  70.676692  70.676692   94.0   94.0

Percent non-consensus units with match in curated sets
          C1 %       C2 %    C1    C2
HS    0.000000   0.719424   0.0   1.0
KS   18.584071  24.336283  42.0  55.0
IC    3.030303   4.545455   2.0   3.0
TDC   0.000000   1.219512   0.0   1.0
SC    1.010101   0.757576   4.0   3.0
HDS   1.630435   1.086957   3.0   2.0

Marques-Smith neuropixel 384ch paired recording

Sorter comparison with paired (neuropixel - patch) recordings

Author: Samuel Garcia

André Marques-Smith made an open dataset with simultaneous patch-clamp and Neuropixels probe extracellular recordings from the same cortical neuron in anaesthetized rats.

This is very useful for testing spike sorting engines.

The original dataset contains 42 recordings.

Here we select only a subset of 6 files, keeping only those where the SNR in the extracellular trace is large enough for the unit to be detected. One file (c24) was removed because the juxtacellular signal itself is ambiguous.

The patch recording serves as the "ground truth", and the 384-channel Neuropixels recording is sorted by 5 sorters to compare their results.

Please have a look at the paper:

https://www.biorxiv.org/content/10.1101/370080v2

The repository that explains everything:

https://github.com/kampff-lab/sc.io/tree/master/Paired%20Recordings

The dataset is available here:

http://crcns.org/data-sets/methods/spe-1

or here:

https://drive.google.com/drive/folders/13GCOuWN4QMW6vQmlNIolUrxPy-4Wv1BC

Notes:

  • I do not use the spike indexes provided by André because, for some files, small errors in double-peak detection can occur.
  • These results are also available on SpikeForest, but SpikeForest keeps only 32 channels to reduce computation. Here the computation is done on all 384 channels. Let's see whether we get the same results.
In [7]:
# import everything
import os

kilosort2_path = '/home/samuel/Documents/Spikeinterface/Kilosort2'
os.environ["KILOSORT2_PATH"] = kilosort2_path

kilosort_path = '/home/samuel/Documents/Spikeinterface/KiloSort/'
os.environ["KILOSORT_PATH"] = kilosort_path

ironclust_path = '/home/samuel/Documents/Spikeinterface/ironclust'
os.environ["IRONCLUST_PATH"] = ironclust_path

from pathlib import Path
import scipy.signal
import scipy.io

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw

from spikeinterface.comparison import GroundTruthStudy
In [2]:
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

herdingspikes: 0.3.7+git.45665a2b6438
ironclust: 5.9.8
kilosort: git-cd040da1963d
kilosort2: git-e243c934339e
spykingcircus: 0.9.7
tridesclous: 1.6.1.dev

Paths and recording list

In [4]:
p = '/media/samuel/dataspikesorting/DataSpikeSortingHD2/andre_paired_neuropixel/'
p = Path(p)

recordings_folder = p / 'recordings'
study_folder = p / 'study_paired_neuropixel'


rec_names = [
    'c14',
    'c26', 
    'c28', 
    'c37',
    'c45', 
    'c46'
]

Function to detect peaks in the patch recording

The files provided by André contain small errors in peak detection, so here we compute the peaks again.

In [6]:
def detect_peak_on_patch_sig(patch_sig, sample_rate):
    # filter because some traces have drift
    sos = scipy.signal.iirfilter(5, 200./sample_rate*2, analog=False, btype = 'highpass', ftype = 'butter', output = 'sos')
    patch_sig_f = scipy.signal.sosfiltfilt(sos, patch_sig, axis=0)
    
    med = np.median(patch_sig_f)
    mad = np.median(np.abs(patch_sig_f-med))*1.4826
    thresh = med - 12 * mad
    
    # minimum distance of 1 ms between detected peaks
    d = int(sample_rate * 0.001)
    spike_indexes, prop = scipy.signal.find_peaks(-patch_sig_f, height=-thresh, distance=d)

    #~ fig, ax = plt.subplots()
    #~ ax.plot(patch_sig_f)
    #~ ax.axhline(thresh)
    #~ ax.plot(spike_indexes, patch_sig_f[spike_indexes], ls='None', marker='o')
    #~ plt.show()
    
    return spike_indexes
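The detection above places its threshold at 12 times a MAD-based noise estimate below the median. The minimal sketch below, on synthetic data, shows why a median-based estimate is used rather than the standard deviation. The noise, spike positions, and amplitudes are invented, and the high-pass filtering step is skipped. The spikes inflate `np.std` but barely move the MAD estimate, so the threshold stays tied to the noise level.

```python
import numpy as np
import scipy.signal


def mad_std(x):
    """Robust noise estimate: median absolute deviation scaled by 1.4826.

    For Gaussian noise this approximates the standard deviation, but unlike
    np.std it is hardly affected by a handful of large spikes.
    """
    med = np.median(x)
    return np.median(np.abs(x - med)) * 1.4826


rng = np.random.default_rng(0)
sig = rng.normal(0.0, 1.0, 100_000)              # synthetic unit-variance noise
true_spikes = np.arange(1_000, 100_000, 1_000)   # 99 hypothetical spike positions
sig[true_spikes] -= 50.0                         # large negative deflections

thresh = np.median(sig) - 12 * mad_std(sig)      # same 12 x MAD rule as above
# peaks of the inverted signal that cross the threshold, at least 50 samples apart
peaks, _ = scipy.signal.find_peaks(-sig, height=-thresh, distance=50)

print(np.std(sig), mad_std(sig), len(peaks))
```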

Create the study

In [ ]:
# the file chanMap.mat contains the geometry
d = scipy.io.loadmat(str(p / 'chanMap.mat'))
locations = np.zeros((384, 2))
locations[:, 0] = d['xcoords'][:, 0]
locations[:, 1] = d['ycoords'][:, 0]

#~ fig, ax = plt.subplots()
#~ ax.scatter(locations[:, 0], locations[:, 1])
#~ plt.show()


gt_dict = {}
for rec_name in rec_names:
    print(rec_name)
    sample_rate = 30000.

    # neuropixel sigs
    raw_bin_filename = recordings_folder /  rec_name / (rec_name + '_npx_raw.bin')
    mea_sigs = np.memmap(raw_bin_filename, dtype='int16', mode='r').reshape(-1, 384)

    # patch recording
    filename = recordings_folder /  rec_name / (rec_name + '_patch_ch1.bin')
    #~ patch_sig = np.memmap(str(filename), dtype='float64', mode='r')
    patch_sig = np.fromfile(str(filename), dtype='float64')
    
    # spike indexes in the patch clock reference
    sr = 50023. # not the exact sampling rate, but precision does not matter here
    gt_spike_indexes_patch = detect_peak_on_patch_sig(patch_sig, sr)

    # time stretch factor between the two recordings (Neuropixels and patch)
    time_factor = mea_sigs.shape[0] / patch_sig.shape[0]
    print('time_factor', time_factor)
    
    # spike indexes in the Neuropixels clock reference
    gt_spike_indexes = (gt_spike_indexes_patch * time_factor).astype('int64')


    # recording
    rec = se.BinDatRecordingExtractor(raw_bin_filename, sample_rate, 384, 'int16', offset=0, time_axis=0)
    rec.set_channel_locations(locations)

    # gt sorting
    sorting_gt = se.NumpySortingExtractor()
    sorting_gt.set_times_labels(gt_spike_indexes, np.zeros(gt_spike_indexes.size, dtype='int64'))
    sorting_gt.set_sampling_frequency(sample_rate)

    gt_dict[rec_name] = (rec, sorting_gt)


study = GroundTruthStudy.create(study_folder, gt_dict)
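The ground-truth spikes are moved from the patch clock to the Neuropixels clock by multiplying by the ratio of the two sample counts. This only works if both recordings start together and span the same time. The sketch below isolates that step under those assumptions. The helper name and the 10 s example are hypothetical. Unlike the cell above, which truncates with `astype('int64')`, the sketch rounds to the nearest sample.

```python
import numpy as np


def rescale_spike_indexes(spike_indexes, n_samples_src, n_samples_dst):
    """Map sample indexes from one clock to another recording of the same duration.

    Assumes both recordings start together and cover the same time span, so the
    ratio of sample counts equals the ratio of their effective sampling rates.
    """
    factor = n_samples_dst / n_samples_src
    # round to the nearest sample (the notebook cell truncates instead)
    return np.round(np.asarray(spike_indexes) * factor).astype('int64')


# hypothetical 10 s recordings: patch at 50 kHz, Neuropixels at 30 kHz
patch_idx = np.array([50_000, 250_000, 499_999])
npx_idx = rescale_spike_indexes(patch_idx, n_samples_src=500_000, n_samples_dst=300_000)
print(npx_idx.tolist())  # [30000, 150000, 299999]
```

Using sample counts rather than the nominal sampling rates is what makes the exact value of `sr` unimportant: any small clock mismatch between the two devices is absorbed into the ratio.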

Get signal to noise ratio for all units

In [ ]:
study = GroundTruthStudy(study_folder)
snr = study.concat_all_snr()
snr
In [ ]:
fig, ax = plt.subplots()
ax.hist(snr['snr'].values, bins=np.arange(0, 40, 5))
ax.set_xlabel('GT units SNR')

Run all sorters

In [9]:
sorter_list = ['herdingspikes', 'ironclust', 'kilosort2',  
                'spykingcircus', 'tridesclous']
In [ ]:
study = GroundTruthStudy(study_folder)

study.run_sorters(sorter_list, mode='keep', verbose=False)

Run comparison with ground truth and retrieve result tables

In [ ]:
study = GroundTruthStudy(study_folder)
In [ ]:
# copy the sortings from each sorter's output folder
# into a centralized folder with all results
study.copy_sortings()
In [ ]:
# run all comparisons to ground truth
# exhaustive_gt=False because this is a paired recording with a single GT unit
study.run_comparisons(exhaustive_gt=False, match_score=0.1, overmerged_score=0.2)
In [ ]:
# retrieve the results
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()

Run times

In [ ]:
dataframes['run_times'].set_index(['rec_name', 'sorter_name']).unstack('sorter_name')
In [ ]:
sns.set_palette(sns.color_palette("Set1"))

g = sns.catplot(data=dataframes['run_times'], x='sorter_name', y='run_time',
                hue="rec_name",
                order=sorter_list,
                kind='bar',
                legend=False)
g.fig.set_size_inches(12,5)

# catplot creates its own figure: use its (single) axes instead of an undefined `fig`
ax = g.fig.axes[0]
#ax.set_yscale('log')
ax.set_ylabel('Run time (s)')
ax.set_xlabel('')
ax.set_xticklabels(sorter_list, rotation=40, ha='right');

Accuracy, precision and recall scores per sorter

In [ ]:
sns.set_palette(sns.color_palette("Set1"))

df = pd.melt(dataframes['perf_by_units'], id_vars=['rec_name', 'sorter_name'],
            var_name='metric', value_name='score', 
            value_vars=('accuracy','precision', 'recall'))
display(df)

# df.assign(Generation=df.metric.map({'metric':'Metric','accuracy': 'Accuracy', 'precision': 'Precision', 'recall': 'Recall'}))
g = sns.catplot(data=df, x='sorter_name', y='score', hue='metric', kind='swarm', dodge=True, # ax=ax, 
                order=sorter_list, legend_out=True, s=4)
g.fig.set_size_inches(12,5)
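For reference, the three scores plotted here are commonly defined from counts of true positives (tp), false negatives (fn), and false positives (fp), as in the sketch below. These are the standard definitions. That spikecomparison computes exactly these is an assumption that this sketch does not check. The helper name and the counts are hypothetical.

```python
def gt_scores(tp, fn, fp):
    """Accuracy, precision and recall of one sorted unit against one GT unit.

    tp: GT spikes matched by the sorted unit
    fn: GT spikes the sorted unit missed
    fp: sorted spikes with no GT counterpart
    """
    accuracy = tp / (tp + fn + fp)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return accuracy, precision, recall


# hypothetical counts for one paired recording
acc, prec, rcl = gt_scores(tp=80, fn=10, fp=10)
print(round(acc, 3), round(prec, 3), round(rcl, 3))  # 0.8 0.889 0.889
```

Accuracy penalizes both kinds of error at once, so it is always at most the smaller of precision and recall.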
In [ ]:
sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=dataframes['perf_by_units'], x='precision', y='recall', col='sorter_name', 
                col_wrap=3, col_order=sorter_list, s=60)

Accuracy vs. SNR

In [ ]:
df = dataframes['perf_by_units']
df
In [ ]:
df = dataframes['perf_by_units']

# add snr to the by-unit table
df['snr'] = np.nan
for rec_name, gt_id in snr.index:
    # a single .loc call avoids chained assignment (df['snr'].loc[...] = ...)
    mask = (df['gt_unit_id']==gt_id) & (df['rec_name']==rec_name)
    df.loc[mask, 'snr'] = snr.at[(rec_name, gt_id), 'snr']

sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=df, x='snr', y='accuracy', col='sorter_name',
        col_wrap=3, col_order=sorter_list, s=80)
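The SNR column can also be attached with a single merge instead of a loop. The sketch below uses hypothetical toy tables. It assumes that `snr` has a (rec_name, gt_unit_id) MultiIndex and an 'snr' column, which is what the loop's `snr.at[(rec_name, gt_id), 'snr']` implies. It is named `snr_table` here so that the notebook's `snr` is not overwritten.

```python
import pandas as pd

# hypothetical per-unit performance table, shaped like dataframes['perf_by_units']
perf = pd.DataFrame({
    'rec_name': ['c14', 'c14', 'c26'],
    'sorter_name': ['kilosort2', 'tridesclous', 'kilosort2'],
    'gt_unit_id': [0, 0, 0],
    'accuracy': [0.9, 0.8, 0.7],
})
# hypothetical SNR table indexed by (rec_name, gt_unit_id)
snr_table = pd.DataFrame({
    'rec_name': ['c14', 'c26'],
    'gt_unit_id': [0, 0],
    'snr': [12.5, 7.1],
}).set_index(['rec_name', 'gt_unit_id'])

# a left merge copies the SNR of each GT unit onto every sorter row for that unit
perf_snr = perf.merge(snr_table.reset_index(), on=['rec_name', 'gt_unit_id'], how='left')
print(perf_snr[['rec_name', 'sorter_name', 'snr']])
```

A left merge keeps every row of the performance table in its original order, and a GT unit missing from the SNR table simply gets NaN.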

Does the recording's duration affect the quality of spike sorting?

This notebook investigates if and how the duration of the recording affects spike sorting.

Obviously, each sorter engine needs a minimum number of events to detect a "cluster", and therefore a unit. If a neuron doesn't fire enough during a recording, it won't be detected. The number of events per unit depends on the recording duration and on each unit's firing rate.
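To make this concrete, assume spikes follow a Poisson process (a simplifying assumption). A unit firing at rate r is then expected to emit r × T spikes in a recording of duration T, and the probability of falling below some minimum count can be computed directly. The 0.5 Hz firing rate and the 50-spike minimum below are hypothetical numbers chosen for illustration. They are not properties of any particular sorter.

```python
import math


def prob_fewer_than(k, lam):
    """P(N < k) for a Poisson-distributed spike count N with mean lam.

    Each term is evaluated in log space to avoid overflow/underflow of its factors.
    """
    return sum(math.exp(-lam + n * math.log(lam) - math.lgamma(n + 1)) for n in range(k))


rate_hz = 0.5                      # hypothetical firing rate of a sparsely firing unit
min_events = 50                    # hypothetical minimum cluster size for a sorter
durations = [60, 300, 600, 1800]   # seconds, as used in this notebook

expected = [rate_hz * d for d in durations]                # mean spike count per duration
p = [prob_fewer_than(min_events, lam) for lam in expected]
for d, lam, prob in zip(durations, expected, p):
    print(f'{d:>5} s: expected {lam:.0f} spikes, P(fewer than {min_events}) = {prob:.1e}')
```

Under these assumptions, such a unit almost certainly yields too few spikes in the 60 s recording, and almost never does in the longer ones.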

In order to test this phenomenon, we use the same dataset (with the same neurons and firing rates), but we vary the duration of the recording.

The simulated recording is generated with MEArec using a Neuronexus-32 probe. This specific dataset seems relatively easy to sort. The "SYNTH_MEAREC_NEURONEXUS" dataset in SpikeForest (which uses the same probe), in fact, shows quite good results for all sorters. The original duration is 600s (10 min).

Here we have generated a new but similar recording with a duration of 1800s. We then truncated it to 60s, 300s and 600s, and also kept the full 1800s recording (the original). The recording can be downloaded from Zenodo: https://doi.org/10.5281/zenodo.4058272

The dataset name is: recordings_10cells_Neuronexus-32_1800.0_10.0uV_2020-02-28.h5. It contains 10 neurons recorded on a Neuronexus-32 probe. The duration is 1800s and the noise level is 10uV.

Let's see whether spike sorters are robust to fewer events, and whether they can deal with long durations or instead end up finding too many events.

Author: Samuel Garcia, CRNL, Lyon

Requirements

For this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

Installation and imports

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw

from spikeinterface.comparison import GroundTruthStudy
/home/samuel/.virtualenvs/py36/lib/python3.6/site-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm
In [ ]:
# clone and install MATLAB sorters

# kilosort2
!git clone https://github.com/MouseLand/Kilosort2.git
kilosort2_path = './Kilosort2'
ss.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

# kilosort
!git clone https://github.com/cortex-lab/KiloSort.git
kilosort_path = './KiloSort'
ss.KilosortSorter.set_kilosort_path(kilosort_path)

# ironclust
!git clone https://github.com/flatironinstitute/ironclust.git
ironclust_path = './ironclust'
ss.IronclustSorter.set_ironclust_path(ironclust_path)
In [2]:
%matplotlib inline

# some matplotlib hack to prettify figure
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', figsize=(10.0, 8.0))  # figsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)    

Check spikeinterface and sorter versions

In [3]:
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.1
  * spikeextractor: 0.7.2
  * spiketoolkit: 0.5.2
  * spikesorters: 0.2.4
  * spikecomparison: 0.2.3
  * spikewidgets: 0.3.3

herdingspikes: 0.3.7+git.45665a2b6438
ironclust: 5.9.4
kilosort: git-cd040da1963d
kilosort2: git-67a42a87b866
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.2
tridesclous: 1.5.0

Set up global path

In [4]:
# Change this path to point to where you downloaded the dataset
p = Path('/home/samuel/Documents/DataSpikeSorting/mearec/')
study_folder = p / 'study_mearec_neuronexus_several_durations/'

Set up ground truth study

In [5]:
mearec_filename = p / 'recordings_10cells_Neuronexus-32_1800.0_10.0uV_2020-02-28.h5'
In [ ]:
rec = se.MEArecRecordingExtractor(mearec_filename, locs_2d=True)
gt_sorting = se.MEArecSortingExtractor(mearec_filename)

fs = rec.get_sampling_frequency()

gt_dict = {}
durations = [60, 300, 600, 1800]
for duration in durations:
    sub_rec = se.SubRecordingExtractor(rec, start_frame=0, end_frame=int(duration*fs))
    sub_sorting = se.SubSortingExtractor(gt_sorting, start_frame=0, end_frame=int(duration*fs))
    gt_dict[f'rec{duration}'] = (sub_rec, sub_sorting)
study = GroundTruthStudy.create(study_folder, gt_dict)

Run all sorters

In [7]:
sorter_list = ['herdingspikes', 'ironclust', 'kilosort2',  'kilosort',
               'mountainsort4', 'spykingcircus', 'tridesclous'] 
In [ ]:
study = GroundTruthStudy(study_folder)
sorter_params = {} 
study.run_sorters(sorter_list, sorter_params=sorter_params, mode='keep', verbose=True)

Get signal to noise ratio for all units

The units are the same in every recording, so their SNR is identical; we take it from the longest recording.

In [42]:
study = GroundTruthStudy(study_folder)
snr = study.get_units_snr(rec_name='rec1800')
snr
Out[42]:
gt_unit_id snr rec_name
0 22.270592 rec1800
1 9.222510 rec1800
2 5.939772 rec1800
3 14.853892 rec1800
4 7.504935 rec1800
5 9.803070 rec1800
6 16.042390 rec1800
7 11.883403 rec1800
8 7.490971 rec1800
9 5.482461 rec1800
In [43]:
fig, ax = plt.subplots()
ax.hist(snr['snr'].values, bins=10)
ax.set_xlabel('GT units SNR')

Run comparison with ground truth and retrieve result tables

In [102]:
# copy the sorter outputs into a centralized folder
# that holds all results
study.copy_sortings()

# run all comparisons to ground truth
study.run_comparisons(exhaustive_gt=True, match_score=0.1, overmerged_score=0.2)
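For intuition about `match_score`, here is a minimal sketch of an agreement score between two spike trains. As we understand it, spikecomparison scores each (GT unit, sorted unit) pair as n_match / (n_gt + n_sorted - n_match), where a match is a pair of spikes closer than a small time tolerance; `match_score=0.1` is then the minimum score for a pair to count as matched. The greedy one-to-one matching and the hypothetical `delta_frames` tolerance below are simplifying assumptions, not the library's implementation.

```python
import numpy as np


def count_matching_events(times1, times2, delta_frames=10):
    """Greedy one-to-one matching of two sorted spike-time arrays (in frames).

    Two spikes match if they are at most `delta_frames` apart; each spike
    is used at most once.
    """
    i = j = n_match = 0
    while i < len(times1) and j < len(times2):
        diff = times2[j] - times1[i]
        if abs(diff) <= delta_frames:
            n_match += 1
            i += 1
            j += 1
        elif diff > 0:
            i += 1  # spike in train 1 has no partner: move on
        else:
            j += 1  # spike in train 2 has no partner: move on
    return n_match


def agreement_score(times1, times2, delta_frames=10):
    """Agreement = matches / (n1 + n2 - matches), between 0 and 1."""
    n_match = count_matching_events(times1, times2, delta_frames)
    denom = len(times1) + len(times2) - n_match
    return n_match / denom if denom > 0 else 0.0


gt = np.array([100, 500, 900, 1300])
sorted_unit = np.array([102, 498, 1500])
print(agreement_score(gt, sorted_unit))  # 2 matches -> 2 / (4 + 3 - 2) = 0.4
```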
In [103]:
# retrieve results
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()

Run times

In [104]:
run_times = dataframes['run_times']
run_times
Out[104]:
rec_name sorter_name run_time
0 rec600 tridesclous 118.908889
1 rec300 kilosort2 44.356094
2 rec600 kilosort 119.256738
3 rec1800 mountainsort4 795.946418
4 rec60 herdingspikes 5.698961
5 rec60 spykingcircus 31.770160
6 rec60 kilosort 28.761853
7 rec60 tridesclous 19.454172
8 rec1800 herdingspikes 124.915830
9 rec60 mountainsort4 89.919816
10 rec600 spykingcircus 154.454786
11 rec300 spykingcircus 82.981423
12 rec600 kilosort2 64.552614
13 rec60 kilosort2 24.819330
14 rec600 mountainsort4 327.498347
15 rec300 mountainsort4 197.165915
16 rec300 herdingspikes 21.194948
17 rec1800 ironclust 98.909468
18 rec300 kilosort 67.272793
19 rec300 ironclust 27.962674
20 rec1800 kilosort2 144.794843
21 rec600 herdingspikes 44.548482
22 rec600 ironclust 40.769642
23 rec1800 kilosort 333.402659
24 rec300 tridesclous 66.892407
25 rec1800 tridesclous 383.363512
26 rec60 ironclust 19.337697
27 rec1800 spykingcircus 471.506204
In [105]:
# insert durations
run_times['duration']  = run_times['rec_name'].apply(lambda s: float(s.replace('rec', '')))
In [106]:
g = sns.catplot(data=run_times, x='duration', y='run_time', hue='sorter_name', kind='bar')

Accuracy vs duration

In [107]:
perf = dataframes['perf_by_units']
perf
Out[107]:
rec_name sorter_name gt_unit_id accuracy recall precision false_discovery_rate miss_rate
0 rec600 spykingcircus 0 0.980565 0.980565 1 0 0.0194346
1 rec600 spykingcircus 1 0.993337 0.993337 1 0 0.00666263
2 rec600 spykingcircus 2 0.832362 0.832824 0.999334 0.000666445 0.167176
3 rec600 spykingcircus 3 0.995868 0.995868 1 0 0.00413223
4 rec600 spykingcircus 4 0.9933 0.9933 1 0 0.00669975
... ... ... ... ... ... ... ... ...
275 rec1800 kilosort 5 0.998887 0.999026 0.999861 0.000139315 0.000974388
276 rec1800 kilosort 6 0.999045 0.999331 0.999713 0.000286834 0.000669024
277 rec1800 kilosort 7 0.99865 0.999285 0.999364 0.000635627 0.000715023
278 rec1800 kilosort 8 0.998279 0.998552 0.999726 0.000274187 0.00144757
279 rec1800 kilosort 9 0.995271 0.995552 0.999716 0.000283597 0.00444805

280 rows × 8 columns
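The columns of perf_by_units are linked through the counts of matched spikes (tp), missed spikes (fn) and extra spikes (fp). A minimal sketch with these standard definitions, which are consistent with the table (miss_rate = 1 - recall, false_discovery_rate = 1 - precision); the counts in the example are hypothetical.

```python
def unit_performance(tp, fn, fp):
    """Per-unit scores from matched (tp), missed (fn) and extra (fp) spike counts.

    Assumes tp > 0.
    """
    accuracy = tp / (tp + fn + fp)
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    return {
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision,
        'false_discovery_rate': 1 - precision,
        'miss_rate': 1 - recall,
    }


def accuracy_from_recall_precision(recall, precision):
    # (tp+fn)/tp + (tp+fp)/tp - 1 = (tp+fn+fp)/tp = 1/accuracy
    return 1.0 / (1.0 / recall + 1.0 / precision - 1.0)


# hypothetical counts: 950 matched spikes, 50 missed, 0 extra
print(unit_performance(tp=950, fn=50, fp=0))
# accuracy = recall = 0.95, precision = 1.0, miss_rate ~ 0.05
```

Because 1/accuracy = 1/recall + 1/precision - 1, accuracy can be recovered from the other two columns: with recall 0.832824 and precision 0.999334 (spykingcircus, unit 2, rec600 in the table above), it gives 0.8324, matching the accuracy column.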

In [108]:
# insert durations
perf['duration']  = perf['rec_name'].apply(lambda s: float(s.replace('rec', '')))
In [109]:
g = sns.catplot(data=perf, x='duration', y='accuracy', hue='sorter_name', kind='bar')

Count good, bad, false positive units vs duration

In [110]:
count_units = dataframes['count_units']
count_units
Out[110]:
rec_name sorter_name num_gt num_sorter num_well_detected num_redundant num_overmerged num_false_positive num_bad
0 rec600 spykingcircus 10 27 10 0 0 17 17
1 rec300 tridesclous 10 9 8 0 0 0 0
2 rec1800 mountainsort4 10 199 2 6 0 183 189
3 rec300 mountainsort4 10 104 2 9 0 85 94
4 rec60 kilosort2 10 13 9 0 1 3 3
5 rec60 kilosort 10 5 3 0 1 0 0
6 rec1800 herdingspikes 10 71 9 0 0 61 61
7 rec1800 tridesclous 10 11 9 0 0 1 1
8 rec600 herdingspikes 10 23 9 0 0 13 13
9 rec60 mountainsort4 10 60 4 4 0 46 50
10 rec300 spykingcircus 10 22 10 0 0 12 12
11 rec600 ironclust 10 13 9 0 0 3 3
12 rec600 mountainsort4 10 133 2 7 0 116 123
13 rec600 tridesclous 10 12 9 0 0 2 2
14 rec60 spykingcircus 10 20 9 0 0 10 10
15 rec300 kilosort2 10 16 10 0 0 6 6
16 rec600 kilosort 10 11 10 0 0 1 1
17 rec300 ironclust 10 12 8 0 0 2 2
18 rec1800 spykingcircus 10 37 9 1 0 26 27
19 rec1800 kilosort2 10 14 10 0 0 4 4
20 rec1800 ironclust 10 13 7 0 1 4 4
21 rec300 herdingspikes 10 15 9 0 0 5 5
22 rec300 kilosort 10 11 10 0 0 1 1
23 rec60 herdingspikes 10 10 8 0 0 0 0
24 rec60 tridesclous 10 10 9 0 0 0 0
25 rec600 kilosort2 10 16 10 0 0 6 6
26 rec60 ironclust 10 12 8 0 0 2 2
27 rec1800 kilosort 10 11 10 0 0 1 1
In [111]:
# insert durations
count_units['duration']  = count_units['rec_name'].apply(lambda s: float(s.replace('rec', '')))
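The unit categories in count_units come from matching GT units to sorted units. In the table above, num_bad equals num_false_positive + num_redundant in every row. The sketch below is a simplified illustration of how such counts can be derived from an agreement matrix; the rules and the 0.8 well-detected threshold are our assumptions, and spikecomparison's own rules (e.g. for overmerged units, not covered here) are more detailed.

```python
import numpy as np


def classify_units(agreement, match_score=0.1, well_detected_score=0.8):
    """Simplified unit counts from an agreement matrix.

    agreement[i, j] is the agreement score between GT unit i and sorted unit j.
    Illustrative rules only; spikecomparison's own rules are more detailed.
    """
    n_gt, n_sorted = agreement.shape
    best_sorted = agreement.argmax(axis=1)  # best sorted unit for each GT unit
    best_score = agreement.max(axis=1)

    # GT units recovered with a high score
    num_well_detected = int(np.sum(best_score >= well_detected_score))

    # sorted units that match no GT unit at all
    matched = agreement >= match_score
    num_false_positive = int(np.sum(~matched.any(axis=0)))

    # sorted units that match a GT unit but are not its best match
    best_set = {int(j) for j, s in zip(best_sorted, best_score) if s >= match_score}
    num_redundant = sum(1 for j in range(n_sorted)
                        if matched[:, j].any() and j not in best_set)

    return {
        'num_gt': n_gt,
        'num_sorter': n_sorted,
        'num_well_detected': num_well_detected,
        'num_false_positive': num_false_positive,
        'num_redundant': num_redundant,
        'num_bad': num_false_positive + num_redundant,
    }


# hypothetical scores: 2 GT units (rows) x 4 sorted units (columns)
agreement = np.array([
    [0.95, 0.30, 0.00, 0.02],
    [0.00, 0.00, 0.85, 0.05],
])
print(classify_units(agreement))
# {'num_gt': 2, 'num_sorter': 4, 'num_well_detected': 2,
#  'num_false_positive': 1, 'num_redundant': 1, 'num_bad': 2}
```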

num_well_detected vs duration

Higher is better.

In [112]:
g = sns.catplot(data=count_units, x='duration', y='num_well_detected', hue='sorter_name', kind='bar')

num_false_positive vs duration

Lower is better.

In [113]:
g = sns.catplot(data=count_units, x='duration', y='num_false_positive', hue='sorter_name', kind='bar')
In [114]:
# same plot as above, with the y-axis zoomed in
g = sns.catplot(data=count_units, x='duration', y='num_false_positive', hue='sorter_name', kind='bar')
g.fig.axes[0].set_ylim(0, 10)

num_redundant vs duration

Lower is better.

In [115]:
g = sns.catplot(data=count_units, x='duration', y='num_redundant', hue='sorter_name', kind='bar')

Conclusion

For this simple simulated dataset we have observed that:

  • Focusing on the average accuracy, all sorters have similar performance for long and short recordings. The only exception is Kilosort, which shows a clear drop in performance for the shortest duration (60s).

  • Surprisingly, some sorters (e.g. tridesclous, ironclust) perform better at 60s than at 300s. This could be specific to this dataset and needs further investigation.

  • Looking at the number of "num_false_positive" and "num_well_detected" the situation is the following:

    • kilosort is not affected by the duration
    • herdingspikes (one of the most affected): the longer the duration, the more "num_false_positive"
    • ironclust seems to have a slight increase in "num_false_positive" for longer durations
    • kilosort2 has random oscillations of "num_false_positive" across durations
    • tridesclous has a few more "num_false_positive" for long durations
    • mountainsort4 is heavily affected by the duration
    • spykingcircus is affected by long durations as more "num_false_positive" units are found

Spampinato mice retina mea252ch pair recording - part2


Part 2) Ground-truth comparison

In this set of notebooks, the dataset comes from paired juxtacellular/extracellular recordings of mouse retina in vitro. The MEA has 252 channels.

The official publication of this open dataset can be found at the following address: https://zenodo.org/record/1205233#.W9mq1HWLTIF

These datasets were used by Pierre Yger et al. in the following "spyking circus" paper: https://elifesciences.org/articles/34518

After inspecting the juxtacellular data, we found that some recordings are not of good enough quality to be considered "ground truth". To qualify as "ground truth", a unit must be stable in detection, peak signal-to-noise ratio (SNR) and amplitude.

At the end of our quality assessment ("spampinato-mice-retina-mea252ch-pair-recording-part1"), some files were excluded from this main study.

Author: Samuel Garcia, CRNL, Lyon

Requirements

To run this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

Installation and imports

  1. Create a folder basedir.
  2. Download all files from the Zenodo link.
  3. Move them into a subfolder basedir/original_files (20160415_patch2.tar.gz, ...).
  4. Execute this notebook cell by cell.
In [16]:
# import everything
import os, getpass
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path


import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw

from spikeinterface.comparison import GroundTruthStudy
In [ ]:
# clone and install MATLAB sorters

# kilosort2
!git clone https://github.com/MouseLand/Kilosort2.git
kilosort2_path = './Kilosort2'
ss.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

# kilosort
!git clone https://github.com/cortex-lab/KiloSort.git
kilosort_path = './KiloSort'
ss.KilosortSorter.set_kilosort_path(kilosort_path)

# ironclust
!git clone https://github.com/flatironinstitute/ironclust.git
ironclust_path = './ironclust'
ss.IronclustSorter.set_ironclust_path(ironclust_path)
In [17]:
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.9
  * spikeextractor: 0.8.4
  * spiketoolkit: 0.6.3
  * spikesorters: 0.3.3
  * spikecomparison: 0.2.6
  * spikewidgets: 0.4.3

herdingspikes: 0.3.7+git.45665a2b6438
ironclust: 5.9.8
kilosort: git-cd040da1963d
kilosort2: git-e243c934339e
spykingcircus: 0.9.7
tridesclous: 1.6.1.dev
In [18]:
# my working path
basedir = '/media/samuel/dataspikesorting/DataSpikeSortingHD2/Pierre/zenodo/'

# input file
recording_folder = basedir + 'original_files/'

# ground truth information
ground_truth_folder = basedir + 'ground_truth/'

# where output will be
study_folder = basedir + 'study_gt252/'

# sorter list
sorter_list = ['tridesclous']

# selected recordings (8/19)
rec_names = [
    '20160415_patch2',
    '20170803_patch1',
    '20160426_patch3', 
    '20170725_patch1',
    '20170621_patch1',
    '20160426_patch2', 
    '20170728_patch2',
    '20170713_patch1',
]
In [19]:
%matplotlib inline

# some matplotlib settings to prettify figures
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)    

Step 1 : Clean original dataset

Set up study

In this step:

  • we create a dict of (recording, sorting) pairs
  • and call GroundTruthStudy.create(study_folder, gt_dict)

Internally, spikeinterface copies all recordings and ground-truth sorting into an organised folder.

Important notes:

  • The files have 256 channels but only 252 are useful. The PRB file contains all the channels needed, so we need to explicitly use grouping_property='group' to make sure only the channels of the single group are taken into account.
  • This step only has to run once and takes some time because of the data copy.
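The reduction from 256 to 252 channels can be checked directly. In part 1 the useful MEA channels are selected with list(range(126)) + list(range(128, 254)); assuming the PRB file keeps the same channels, channels 126, 127, 254 and 255 are the ones dropped. A sketch on synthetic data:

```python
import numpy as np

N_TOTAL = 256
keep = list(range(126)) + list(range(128, 254))  # same selection as in part 1
dropped = sorted(set(range(N_TOTAL)) - set(keep))
print(len(keep), dropped)  # 252 [126, 127, 254, 255]

# applied to a (num_samples, 256) block of samples
raw = np.zeros((1000, N_TOTAL), dtype='uint16')
mea = raw[:, keep]
print(mea.shape)  # (1000, 252)
```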
In [ ]:
gt_dict = {}
for rec_name in rec_names:

    # find raw file
    dirname = recording_folder + rec_name + '/'
    for f in os.listdir(dirname):
        if f.endswith('.raw') and not f.endswith('juxta.raw'):
            mea_filename = dirname + f

    # raw files have an internal header offset that depends on the channel count;
    # it can be parsed from the accompanying .txt header file
    with open(mea_filename.replace('.raw', '.txt'), mode='r') as f:
        offset = int(re.findall(r'padding = (\d+)', f.read())[0])

    # recording
    rec = se.BinDatRecordingExtractor(mea_filename, 20000., 256, 'uint16', offset=offset, time_axis=0)

    # this reduces the channel count to 252
    rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

    # gt sorting
    gt_indexes = np.fromfile(ground_truth_folder + rec_name + '/juxta_peak_indexes.raw', dtype='int64')
    sorting_gt = se.NumpySortingExtractor()
    sorting_gt.set_times_labels(gt_indexes, np.zeros(gt_indexes.size, dtype='int64'))
    sorting_gt.set_sampling_frequency(20000.0)

    gt_dict[rec_name] = (rec, sorting_gt)


study = GroundTruthStudy.create(study_folder, gt_dict)
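Each MEA .raw file starts with a header, and the companion .txt file gives its size in bytes as `padding = <n>`; the samples follow as interleaved uint16 values (256 per time step). A self-contained sketch of that read path on a synthetic file; the file names, the header size and the native byte order are assumptions made for the example.

```python
import os
import re
import tempfile

import numpy as np

n_channels = 256
header_size = 64  # made-up header size for this synthetic example
data = np.arange(10 * n_channels, dtype='uint16').reshape(10, n_channels)

with tempfile.TemporaryDirectory() as tmp:
    raw_path = os.path.join(tmp, 'example.raw')
    txt_path = raw_path.replace('.raw', '.txt')

    # synthetic recording: header bytes, then interleaved uint16 samples
    with open(raw_path, 'wb') as f:
        f.write(b'\x00' * header_size)
        f.write(data.tobytes())
    with open(txt_path, 'w') as f:
        f.write(f'padding = {header_size}\n')

    # parse the offset from the .txt header, then map the samples read-only
    with open(txt_path) as f:
        offset = int(re.findall(r'padding = (\d+)', f.read())[0])
    sigs = np.memmap(raw_path, dtype='uint16', mode='r', offset=offset).reshape(-1, n_channels)
    shape = sigs.shape
    same = bool(np.array_equal(sigs, data))
    del sigs  # release the memory map before the folder is deleted

print(offset, shape, same)  # 64 (10, 256) True
```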

Get signal to noise ratio for all units

In [20]:
study = GroundTruthStudy(study_folder)
snr = study.concat_all_snr()
snr
Out[20]:
rec_name gt_unit_id snr
20160415_patch2 0 7.647741
20170803_patch1 0 13.230232
20160426_patch3 0 7.270211
20170725_patch1 0 28.532387
20170621_patch1 0 15.133425
20160426_patch2 0 4.661795
20170728_patch2 0 23.652615
20170713_patch1 0 14.934497
In [21]:
fig, ax = plt.subplots()
ax.hist(snr['snr'].values, bins=np.arange(0, 40, 5))
ax.set_xlabel('GT units SNR')

Run all sorters

In [22]:
sorter_list = ['herdingspikes', 'ironclust', 'kilosort2',  
               'spykingcircus', 'tridesclous']
In [ ]:
study = GroundTruthStudy(study_folder)

study.run_sorters(sorter_list, mode='keep', verbose=False)

Run comparison with ground truth and retrieve result tables

In [23]:
# copy the sorter outputs into a centralized folder
# that holds all results
study.copy_sortings()
In [24]:
# run all comparisons to ground truth
# exhaustive_gt=False because this is a paired recording with only one GT unit
study.run_comparisons(exhaustive_gt=False, match_score=0.1, overmerged_score=0.2)
In [25]:
# retrieve results
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
Run times
In [27]:
dataframes['run_times'].set_index(['rec_name', 'sorter_name']).unstack('sorter_name')
Out[27]:
run_time
sorter_name herdingspikes ironclust kilosort2 spykingcircus tridesclous
rec_name
20160415_patch2 142.888554 108.045776 229.796090 745.985181 391.462280
20160426_patch2 85.068679 78.040375 50.278513 315.495001 109.498236
20160426_patch3 72.955193 65.033961 46.270309 265.166210 96.272756
20170621_patch1 162.408083 118.569352 190.309174 603.793946 343.447162
20170713_patch1 169.336252 133.652825 246.374550 799.341534 464.072922
20170725_patch1 152.863248 101.050823 201.603548 710.756077 390.276237
20170728_patch2 150.883882 99.282168 150.312809 487.173840 309.182291
20170803_patch1 148.814701 114.053511 216.422934 753.743665 433.630640
In [28]:
sns.set_palette(sns.color_palette("Set1"))

g = sns.catplot(data=dataframes['run_times'], x='sorter_name', y='run_time',
                hue="rec_name",
                order=sorter_list,
                kind='bar',
                legend=False)
g.fig.set_size_inches(12, 5)

# catplot creates its own figure, so work on its axes
# (`fig` is not defined in this cell)
ax = g.fig.axes[0]
ax.set_ylabel('Run time (s)')
ax.set_xticklabels(sorter_list, rotation=40, ha='right');

Accuracy/precision/recall scores per sorter

In [29]:
sns.set_palette(sns.color_palette("Set1"))

# long format: one row per (recording, sorter, metric)
df = pd.melt(dataframes['perf_by_units'], id_vars=['rec_name', 'sorter_name'],
             var_name='metric', value_name='score',
             value_vars=('accuracy', 'precision', 'recall'))
display(df)

g = sns.catplot(data=df, x='sorter_name', y='score', hue='metric', kind='swarm', dodge=True,
                order=sorter_list, legend_out=True, s=4)
g.fig.set_size_inches(12, 5)
rec_name sorter_name metric score
0 20160426_patch3 herdingspikes accuracy 0.694297
1 20170725_patch1 herdingspikes accuracy 0.919708
2 20170713_patch1 spykingcircus accuracy 0.933659
3 20160426_patch3 kilosort2 accuracy 0.703873
4 20160426_patch2 tridesclous accuracy 0.426621
... ... ... ... ...
115 20170803_patch1 spykingcircus recall 0.970153
116 20160415_patch2 spykingcircus recall 0.937962
117 20170728_patch2 herdingspikes recall 0.994103
118 20170728_patch2 kilosort2 recall 0.998736
119 20170713_patch1 herdingspikes recall 0.973463

120 rows × 4 columns

In [30]:
sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=dataframes['perf_by_units'], x='precision', y='recall', col='sorter_name', 
                col_wrap=3, col_order=sorter_list, s=60)

Accuracy vs SNR

In [31]:
df = dataframes['perf_by_units']
df
Out[31]:
rec_name sorter_name gt_unit_id accuracy recall precision false_discovery_rate miss_rate
0 20160426_patch3 herdingspikes 0 0.694297 0.698403 0.991604 0.00839631 0.301597
1 20170725_patch1 herdingspikes 0 0.919708 0.994737 0.924205 0.0757946 0.00526316
2 20170713_patch1 spykingcircus 0 0.933659 0.933659 1 0 0.0663413
3 20160426_patch3 kilosort2 0 0.703873 0.999409 0.704167 0.295833 0.000591366
4 20160426_patch2 tridesclous 0 0.426621 0.426621 1 0 0.573379
5 20160415_patch2 ironclust 0 0.683278 0.987194 0.689388 0.310612 0.0128059
6 20170621_patch1 kilosort2 0 0.411829 0.985783 0.414289 0.585711 0.0142171
7 20170803_patch1 kilosort2 0 0.998299 0.998822 0.999476 0.000523972 0.00117816
8 20170713_patch1 tridesclous 0 0.988047 0.995882 0.9921 0.00790033 0.00411774
9 20170728_patch2 ironclust 0 0.999789 1 0.999789 0.000210571 0
10 20170713_patch1 kilosort2 0 0.998627 0.998627 1 0 0.00137258
11 20170621_patch1 herdingspikes 0 0.897158 0.960753 0.931289 0.0687112 0.0392471
12 20160426_patch2 spykingcircus 0 0.993174 0.993174 1 0 0.00682594
13 20170725_patch1 kilosort2 0 1 1 1 0 0
14 20170621_patch1 spykingcircus 0 0.984591 0.985182 0.999391 0.000609385 0.0148178
15 20160426_patch3 ironclust 0 0.956397 0.985807 0.96975 0.0302501 0.0141928
16 20170725_patch1 tridesclous 0 0.994737 0.994737 1 0 0.00526316
17 20160415_patch2 herdingspikes 0 0.694369 0.803643 0.836245 0.163755 0.196357
18 20160426_patch2 herdingspikes 0 0 0 0 0 0
19 20160426_patch3 tridesclous 0 0.77836 0.791248 0.979502 0.0204978 0.208752
20 20160415_patch2 kilosort2 0 0.973838 0.995731 0.977921 0.0220794 0.00426864
21 20170621_patch1 tridesclous 0 0.975485 0.995995 0.979327 0.0206734 0.00400481
22 20160415_patch2 tridesclous 0 0.692271 0.818156 0.818156 0.181844 0.181844
23 20170725_patch1 ironclust 0 0.979058 0.984211 0.994681 0.00531915 0.0157895
24 20160426_patch3 spykingcircus 0 0.977568 0.979302 0.998192 0.00180832 0.0206978
25 20170803_patch1 ironclust 0 0.997775 0.998167 0.999607 0.000393288 0.0018327
26 20170803_patch1 tridesclous 0 0.989504 0.999607 0.989889 0.0101115 0.000392722
27 20160426_patch2 kilosort2 0 0.997725 0.997725 1 0 0.00227531
28 20160426_patch2 ironclust 0 0.993197 0.996587 0.996587 0.00341297 0.00341297
29 20170713_patch1 ironclust 0 0.94115 0.941437 0.999676 0.000323887 0.0585634
30 20170803_patch1 herdingspikes 0 0.939895 0.960073 0.978128 0.0218725 0.0399267
31 20170728_patch2 spykingcircus 0 0.852991 0.852991 1 0 0.147009
32 20170621_patch1 ironclust 0 0.914097 0.997197 0.916452 0.083548 0.00280336
33 20170728_patch2 tridesclous 0 1 1 1 0 0
34 20170725_patch1 spykingcircus 0 0.952632 0.952632 1 0 0.0473684
35 20170803_patch1 spykingcircus 0 0.970153 0.970153 1 0 0.0298468
36 20160415_patch2 spykingcircus 0 0.916064 0.937962 0.975148 0.0248521 0.0620376
37 20170728_patch2 herdingspikes 0 0.981085 0.994103 0.986828 0.0131716 0.00589722
38 20170728_patch2 kilosort2 0 0.998736 0.998736 1 0 0.00126369
39 20170713_patch1 herdingspikes 0 0.968295 0.973463 0.994547 0.00545341 0.0265365
In [32]:
df = dataframes['perf_by_units']

# add snr to the by-unit table
# (df.loc[mask, column] assigns in place and avoids pandas' SettingWithCopyWarning)
df['snr'] = np.nan
for rec_name, gt_id in snr.index:
    mask = (df['gt_unit_id'] == gt_id) & (df['rec_name'] == rec_name)
    df.loc[mask, 'snr'] = snr.at[(rec_name, gt_id), 'snr']

sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=df, x='snr', y='accuracy', col='sorter_name',
                col_wrap=3, col_order=sorter_list, s=80)
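The loop above fills the snr column one (rec_name, gt_unit_id) pair at a time. The same join can be written as a single pandas merge; a sketch on toy tables, assuming snr is indexed by (rec_name, gt_unit_id) as in the Out[20] table (the recording and sorter names below are made up):

```python
import pandas as pd

# toy stand-ins for dataframes['perf_by_units'] and study.concat_all_snr()
perf = pd.DataFrame({
    'rec_name': ['rec_a', 'rec_b', 'rec_a'],
    'sorter_name': ['sorter1', 'sorter1', 'sorter2'],
    'gt_unit_id': [0, 0, 0],
    'accuracy': [0.9, 0.5, 0.8],
})
snr = pd.DataFrame(
    {'snr': [7.6, 13.2]},
    index=pd.MultiIndex.from_tuples([('rec_a', 0), ('rec_b', 0)],
                                    names=['rec_name', 'gt_unit_id']),
)

# one left join on the two key columns replaces the explicit loop
perf_with_snr = perf.merge(snr.reset_index(), on=['rec_name', 'gt_unit_id'], how='left')
print(perf_with_snr[['rec_name', 'sorter_name', 'snr']])
```

A left join keeps every row of the performance table, in its original order, and leaves snr as NaN for any unit without an SNR entry.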

Spampinato mice retina mea252ch pair recording - part1


Part 1) Cleaning the ground-truth data

In this set of notebooks, the dataset comes from paired juxtacellular/extracellular recordings of mouse retina in vitro. The MEA has 252 channels.

The official publication of this open dataset can be found at the following address: https://zenodo.org/record/1205233#.W9mq1HWLTIF

These datasets were used by Pierre Yger et al. in the following "spyking circus" paper: https://elifesciences.org/articles/34518

After inspecting the juxtacellular data, we found that some recordings are not of good enough quality to be considered "ground truth". To qualify as "ground truth", a unit must be stable in detection, peak signal-to-noise ratio (SNR) and amplitude.

At the end of our quality assessment, some files were excluded from the main study shown in "spampinato-mice-retina-mea252ch-pair-recording-part2".

Quality assessment details

First, we have to run the script detect_ground_truth_spike_on_juxta.py.

This script:

  • unzips the downloaded data
  • runs a juxtacellular spike detection
  • generates figures to manually check the juxtacellular quality
  • computes the peak SNR on the max channel of the MEA.

Before running the script, we need:

  • to create a folder basedir
  • to create a subfolder basedir/original_files containing all the downloaded archives (20160415_patch2.tar.gz, ...)

Then we can run the script detect_ground_truth_spike_on_juxta.py.

Afterwards, we can:

  • inspect the explanatory figures in each folder.

Author: Samuel Garcia, CRNL, Lyon

Criteria to keep or remove a file

Having a very reliable ground truth is crucial, as all the spike sorting performance metrics that follow rest on the assumption that the ground truth is indeed correct.

In the following script we choose a high threshold value for peak detection: thresh = med + 8*mad, where:

  • med is the median of the signal (the baseline),
  • mad is the median absolute deviation (a robust estimate of the standard deviation),
  • 8 is a fairly high relative threshold, chosen to avoid false positives.
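The threshold rule above only needs NumPy. A minimal sketch on synthetic noise, assuming the 1.4826 scaling that turns the MAD into a standard-deviation estimate for Gaussian noise (the same factor used in plot_juxta_amplitude below). Since juxtacellular spikes are negative peaks, the sketch keeps local minima lying more than 8 MADs below the baseline (for a baseline close to zero, this is the -thresh line drawn in the figures). The simple local-minimum test is a stand-in for tridesclous' detect_peaks_in_chunk, not a reimplementation of it.

```python
import numpy as np


def median_mad(sig):
    """Median and MAD scaled by 1.4826 (robust std estimate for Gaussian noise)."""
    med = np.median(sig)
    mad = np.median(np.abs(sig - med)) * 1.4826
    return med, mad


def detect_negative_peaks(sig, n_mad=8.0):
    """Indices of local minima lying more than n_mad MADs below the baseline."""
    med, mad = median_mad(sig)
    below = sig < med - n_mad * mad
    # local minimum: lower than the previous sample, not higher than the next one
    is_min = np.zeros(sig.size, dtype=bool)
    is_min[1:-1] = (sig[1:-1] < sig[:-2]) & (sig[1:-1] <= sig[2:])
    return np.flatnonzero(below & is_min)


rng = np.random.default_rng(0)
sig = rng.normal(0.0, 1.0, 20000)  # Gaussian noise, std = 1
sig[[5000, 12000]] -= 20.0         # two large negative "spikes"
print(detect_negative_peaks(sig))  # expected: indices 5000 and 12000
```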

Two main criteria were used to keep a recording:

  • the peak values of the juxtacellular action potentials must have a Gaussian distribution:
    • a truncated Gaussian suggests that false negatives (misses) corrupt the "ground truth",
    • a multi-modal distribution suggests either that an amplitude drift occurred or that two (or more) cells were present.
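One way to put a number on the "truncated Gaussian" criterion (our illustration; the lists below were built by visual inspection) is to fit a Gaussian to the detected peak amplitudes and compute how much of it falls on the undetected side of the threshold. The helper name and the median/MAD fit are assumptions.

```python
import math

import numpy as np


def expected_missed_fraction(peak_amplitudes, threshold):
    """Fraction of a Gaussian fitted to negative peak amplitudes that lies
    above (i.e. is less negative than) the detection threshold.

    Median and scaled MAD are used as robust estimates of the mean and std.
    """
    mu = np.median(peak_amplitudes)
    sigma = np.median(np.abs(peak_amplitudes - mu)) * 1.4826
    z = (threshold - mu) / sigma
    # 1 - Phi(z), with Phi the standard normal CDF
    return 0.5 * (1.0 - math.erf(z / math.sqrt(2.0)))


rng = np.random.default_rng(1)
well_separated = rng.normal(-60.0, 5.0, 2000)  # far below a threshold at -20
print(expected_missed_fraction(well_separated, threshold=-20.0))  # ~0

near_threshold = rng.normal(-22.0, 5.0, 2000)
near_threshold = near_threshold[near_threshold < -20.0]  # detection truncates the tail
print(expected_missed_fraction(near_threshold, threshold=-20.0))  # clearly > 0
```

The fitted fraction underestimates the true loss when the distribution is strongly truncated (the truncation also shrinks the fitted width), so it is better read as a warning sign than as an estimate of the number of missed spikes.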

List of accepted recordings (8)

'20160415_patch2',
'20170803_patch1',
'20160426_patch3', 
'20170725_patch1',
'20170621_patch1',
'20160426_patch2', 
'20170728_patch2',
'20170713_patch1',

List of rejected recordings (11)

'20170706_patch2'
'20170629_patch2'
'20170622_patch2'
'20170726_patch1'
'20170706_patch1'
'20170706_patch3'
'20170627_patch1'
'20170630_patch1'
'20170629_patch3'
'20170623_patch1'
'20170622_patch1'

(Some readers may think that we are too strict, but we prefer to be strict to ensure reliable final results. Feel free to modify this list using your own criteria.)

In [2]:
import matplotlib

import os, shutil
import zipfile, tarfile
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# path
basedir = '/media/samuel/dataspikesorting/DataSpikeSortingHD2/Pierre/zenodo/'
recording_folder = basedir + 'original_files/'
ground_truth_folder = basedir + 'ground_truth/'

%matplotlib notebook

Step 1 : Properly re-detect juxtacellular peaks

In [ ]:
# these tridesclous utilities are imported only for juxtacellular detection, to keep this script simple
from tridesclous.peakdetector import detect_peaks_in_chunk
from tridesclous.tools import median_mad
from tridesclous.waveformtools import extract_chunks
In [ ]:
rec_names = ['20170629_patch3', '20170728_patch2', '20170630_patch1', '20160426_patch2', '20170621_patch1',
             '20170627_patch1', '20170706_patch3', '20170706_patch1', '20170726_patch1', '20170725_patch1',
             '20160426_patch3', '20170622_patch1', '20170623_patch1', '20170622_patch2', '20170629_patch2',
             '20170713_patch1', '20160415_patch2', '20170706_patch2', '20170803_patch1']

Unzip all

In [ ]:
# extract all archives into recording_folder (skip those already extracted)
for rec_name in rec_names:
    filename = recording_folder + rec_name + '.tar.gz'

    if os.path.isdir(recording_folder + rec_name):
        continue
    print('unzip', rec_name)
    with tarfile.open(filename, mode='r|gz') as t:
        t.extractall(recording_folder + rec_name)

Detect ground-truth spikes on the juxtacellular signal

In [ ]:
if not os.path.exists(ground_truth_folder):
    os.mkdir(ground_truth_folder)

gt_info = pd.DataFrame(index=rec_names)

for rec_name in rec_names:
    print('detect_juxta: ', rec_name)

    # get juxta signal (read-only memory map)
    dirname = recording_folder + rec_name + '/'
    for f in os.listdir(dirname):
        if f.endswith('juxta.raw'):
            juxta_filename = dirname + f
    juxta_sig = np.memmap(juxta_filename, dtype='float32', mode='r')

    # get mea signals; the .txt header gives the byte offset of the samples
    for f in os.listdir(dirname):
        if f.endswith('.raw') and not f.endswith('juxta.raw'):
            mea_filename = dirname + f
    with open(mea_filename.replace('.raw', '.txt'), mode='r') as f:
        offset = int(re.findall(r'padding = (\d+)', f.read())[0])
    mea_sigs = np.memmap(mea_filename, dtype='uint16', mode='r', offset=offset).reshape(-1, 256)

    # select only the 252 MEA channels (see PRB file)
    mea_sigs = mea_sigs[:, list(range(126)) + list(range(128, 254))]

    gt_folder = ground_truth_folder + rec_name + '/'
    os.makedirs(gt_folder, exist_ok=True)

    # detect spikes
    med, mad = median_mad(juxta_sig)
    thresh = med + 8*mad
    gt_indexes = detect_peaks_in_chunk(juxta_sig[:, None], k=10, thresh=thresh, peak_sign='-')
    gt_indexes = gt_indexes.astype('int64')
    gt_indexes.tofile(gt_folder + 'juxta_peak_indexes.raw')

    # save some figures for visual checking
    sr = 20000.
    times = np.arange(juxta_sig.size) / sr

    fig, ax = plt.subplots()
    ax.plot(times, juxta_sig)
    ax.plot(times[gt_indexes], juxta_sig[gt_indexes], ls='None', color='r', marker='o')
    ax.set_xlim(0, 10)
    ax.axhline(-thresh, color='k', ls='--')
    ax.set_title('juxta detection - ' + rec_name)
    fig.savefig(gt_folder + 'juxta detection.png')
    plt.close(fig)

    fig, ax = plt.subplots()
    count, bins = np.histogram(juxta_sig[gt_indexes], bins=np.arange(np.min(juxta_sig[gt_indexes]), 0, 0.5))
    ax.plot(bins[:-1], count)
    ax.axvline(-thresh, color='k', ls='--')
    ax.set_title('juxta peak amplitude - ' + rec_name)
    fig.savefig(gt_folder + 'juxta peak amplitude.png')
    plt.close(fig)

    # extract waveforms for only 150 randomly chosen peaks to limit RAM usage
    n_left, n_right = -45, 60
    some_gt_indexes = np.random.choice(gt_indexes, size=150)
    waveforms = extract_chunks(mea_sigs, some_gt_indexes + n_left, n_right - n_left)
    wf_median, wf_mad = median_mad(waveforms, axis=0)

    # find the channel where the median waveform has its largest (most negative) peak
    max_on_channel = np.argmin(np.min(wf_median, axis=0), axis=0)

    # baseline (median) and MAD (robust STD) of the MEA signal on that channel,
    # used to estimate the SNR
    mea_median, mea_mad = median_mad(mea_sigs[:, max_on_channel], axis=0)
    baseline = mea_median

    peak_value = np.min(wf_median[:, max_on_channel])
    peak_value = peak_value - baseline
    peak_snr = np.abs(peak_value / mea_mad)

    # store everything in the DataFrame
    gt_info.at[rec_name, 'nb_spike'] = gt_indexes.size
    gt_info.at[rec_name, 'max_on_channel'] = max_on_channel
    gt_info.at[rec_name, 'peak_value'] = peak_value
    gt_info.at[rec_name, 'peak_snr'] = peak_snr
    gt_info.at[rec_name, 'noise_mad'] = mea_mad

    fig, ax = plt.subplots()
    ax.plot(wf_median.T.flatten())
    fig.savefig(gt_folder + 'GT waveforms flatten.png')
    plt.close(fig)

    fig, ax = plt.subplots()
    ax.plot(wf_median)
    ax.axvline(-n_left)
    fig.savefig(gt_folder + 'GT waveforms.png')
    plt.close(fig)


gt_info.to_excel(ground_truth_folder + 'gt_info.xlsx')
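The peak SNR stored in gt_info is the most negative value of the median GT waveform on its best channel, minus that channel's baseline, divided by the channel's MAD. A self-contained sketch of the same computation on synthetic data, with plain NumPy replacing the tridesclous helpers; the local median_mad assumes the usual 1.4826 MAD scaling.

```python
import numpy as np


def median_mad(data, axis=0):
    """Median and MAD (scaled by 1.4826) along an axis."""
    med = np.median(data, axis=axis)
    mad = np.median(np.abs(data - med), axis=axis) * 1.4826
    return med, mad


def peak_snr(sigs, spike_indexes, n_left=-45, n_right=60):
    """Peak SNR of a unit on its best channel.

    sigs: (num_samples, num_channels) array; spike_indexes: peak sample indexes.
    """
    # cut one waveform per spike: shape (num_spikes, width, num_channels)
    waveforms = np.stack([sigs[i + n_left:i + n_right] for i in spike_indexes])
    wf_median = np.median(waveforms, axis=0)  # (width, num_channels)

    # best channel = the one with the most negative median peak
    best_channel = int(np.argmin(wf_median.min(axis=0)))

    baseline, noise_mad = median_mad(sigs[:, best_channel])
    peak_value = wf_median[:, best_channel].min() - baseline
    return best_channel, abs(peak_value / noise_mad)


rng = np.random.default_rng(2)
sigs = rng.normal(0.0, 2.0, size=(50000, 4))  # 4 channels of noise, std = 2
spikes = np.arange(1000, 49000, 2000)         # 24 spike times
sigs[spikes, 2] -= 30.0                       # spikes of -30 on channel 2 only
best_channel, snr = peak_snr(sigs, spikes)
print(best_channel, round(snr, 1))  # channel 2, SNR close to 30 / 2 = 15
```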

Step 2 : Check juxtacellular quality

In [3]:
# 2 simple functions

def get_juxta_filename(rec_name):
    # find the juxta file
    dirname = recording_folder + rec_name + '/'
    for f in os.listdir(dirname):
        if f.endswith('juxta.raw'):
            return dirname + f

def plot_juxta_amplitude(rec_name):
    juxta_filename = get_juxta_filename(rec_name)
    juxta_sig = np.memmap(juxta_filename, dtype='float32', mode='r')

    # baseline and robust noise estimate (MAD scaled to a std)
    med = np.median(juxta_sig)
    mad = np.median(np.abs(juxta_sig - med)) * 1.4826
    thresh = med + 8*mad

    gt_indexes = np.fromfile(ground_truth_folder + rec_name + '/juxta_peak_indexes.raw', dtype='int64')
    gt_amplitudes = juxta_sig[gt_indexes]

    fig, axs = plt.subplots(nrows=2)
    fig.suptitle('juxta peak amplitude - ' + rec_name)

    # top: histogram of peak amplitudes
    count, bins = np.histogram(gt_amplitudes, bins=np.arange(np.min(gt_amplitudes), 0, 0.5))
    ax = axs[0]
    ax.plot(bins[:-1], count)
    ax.axvline(-thresh, color='r', ls='--')
    ax.axvline(med, color='k', ls='-')
    for i in range(1, 6):
        ax.axvspan(med - i * mad, med + i * mad, color='k', alpha=0.05)

    # bottom: peak amplitude vs time (sample index)
    ax = axs[1]
    ax.plot(gt_indexes, gt_amplitudes, ls='None', marker='o')
    ax.axhline(-thresh, color='r', ls='--')
    for i in range(1, 6):
        ax.axhspan(med - i * mad, med + i * mad, color='k', alpha=0.05)

Why are some recordings not kept?

In the following figures:

  • the black vertical line is the baseline (median) of juxta-cellular trace,
  • the grey areas represent 1, 2, 3, 4, 5 MAD (robust STD),
  • the red line is the detection threshold.

Figure for 20170706_patch2

For this cell, too few events are detected.

In [7]:
plot_juxta_amplitude('20170706_patch2')

Figure for 20170629_patch2

Here the peak amplitude distribution crosses the detection threshold. Missed events are obvious in the middle part of the recording.

In [8]:
plot_juxta_amplitude('20170629_patch2')
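The truncation described for this recording can also be quantified. The hypothetical helper below measures the fraction of detected (negative) peaks that lie within one MAD of the detection threshold: a unit whose amplitude distribution is well separated from the threshold gives a fraction near zero, whereas a distribution cut by the threshold gives a large one. This is an illustrative heuristic with an arbitrary 1-MAD margin, not the criterion used in this notebook, where the recordings were judged from the figures.

```python
import numpy as np

def fraction_near_threshold(peak_amplitudes, threshold, mad, margin_in_mad=1.0):
    """Fraction of negative peak amplitudes lying within `margin_in_mad` MADs past `threshold`.

    `threshold` is the negative detection level (the red line, i.e. -thresh),
    so detected peaks are expected to satisfy amplitude <= threshold.
    """
    amps = np.asarray(peak_amplitudes, dtype=float)
    if amps.size == 0:
        return float("nan")
    near = (amps <= threshold) & (amps >= threshold - margin_in_mad * mad)
    return float(np.mean(near))

# Toy amplitudes: two of the five peaks sit just past a threshold of -10.
print(fraction_near_threshold([-20, -15, -10.5, -10.2, -30], threshold=-10, mad=1.0))
```

With `thresh`, `mad` and `gt_amplitudes` computed as in `plot_juxta_amplitude`, the call would be `fraction_near_threshold(gt_amplitudes, -thresh, mad)`.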

Figure for 20170622_patch2

Here the peak amplitude distribution crosses the detection threshold and too few events were detected.

In [9]:
plot_juxta_amplitude('20170622_patch2')

Figure for 20170726_patch1

Here again the peak amplitude distribution crosses the detection threshold. Some spikes are clearly missed at the beginning.

In [11]:
plot_juxta_amplitude('20170726_patch1')

Figure for 20170706_patch1

Spikes are obviously missed.

In [12]:
plot_juxta_amplitude('20170706_patch1')

Figure for 20170706_patch3

Spikes are obviously missed.

In [13]:
plot_juxta_amplitude('20170706_patch3')

Figure for 20170627_patch1

Spikes are obviously missed.

In [14]:
plot_juxta_amplitude('20170627_patch1')

Figure for 20170630_patch1

Spikes are suspected to be missing at the beginning and at the end of the recording.

In [15]:
plot_juxta_amplitude('20170630_patch1')

Figure for 20170629_patch3

Spikes are obviously missed.

In [16]:
plot_juxta_amplitude('20170629_patch3')

Figure for 20170623_patch1 : NO

Here the right tail of the amplitude distribution is too close to the detection threshold, and some spikes are suspected to be missing.

In [18]:
plot_juxta_amplitude('20170623_patch1')

List of clean ground-truth recordings

Figure for 20170713_patch1 : OK, but borderline

The distribution shows two clear peaks, suggesting possible electrode movement.

In [19]:
plot_juxta_amplitude('20170713_patch1')

Figure for 20160415_patch2 : OK

A ground truth unit we can trust!!

In [20]:
plot_juxta_amplitude('20160415_patch2')

Figure for 20170803_patch1 : OK

In [21]:
plot_juxta_amplitude('20170803_patch1')

Figure for 20160426_patch3 : OK

A ground truth we can trust!!

In [23]:
plot_juxta_amplitude('20160426_patch3')

Figure for 20170725_patch1 : OK, but borderline

A ground truth we can trust, but there is a suspicious change in amplitude in the middle of the recording.

In [25]:
plot_juxta_amplitude('20170725_patch1')

Figure for 20170621_patch1 : OK

In [26]:
plot_juxta_amplitude('20170621_patch1')

Figure for 20160426_patch2 : OK

In [27]:
plot_juxta_amplitude('20160426_patch2')

Figure for 20170728_patch2 : OK

OK, but there is some movement at the end.

In [28]:
plot_juxta_amplitude('20170728_patch2')

Conclusion

  • 11 out of 19 files have been excluded from further ground-truth analysis.
  • 8 out of 19 files are kept for ground-truth analysis.

For paired-recording ground truth, the ground truth itself has to be carefully verified.

The original ground-truth spike indexes provided by the authors for the 19 files cannot be trusted as-is for a fair spike sorting comparison.

Example of parameter optimization

Example of parameter optimization with SpikeInterface

This notebook is an example of parameter optimization with SpikeInterface. See also this notebook on ground-truth comparison that analyzes the same dataset.

The dataset is simulated using MEArec and it can be downloaded from Zenodo: https://doi.org/10.5281/zenodo.4058272

The dataset name is: recordings_50cells_SqMEA-10-15_600.0_10.0uV_21-01-2020_18-12.h5. It contains 50 neurons recorded on a 10x10 MEA with 15um pitch. The duration is 600s and the noise level is 10uV.

Many sorters have an adjacency radius parameter that helps them group spikes that peak on several channels.

  • ironclust uses adjacency_radius=50 by default
  • spykingcircus uses adjacency_radius=200 by default
  • herdingspikes uses probe_neighbor_radius=90 by default

Since channels are spaced 15um apart in this dataset, the radius parameter can affect how each sorter detects and sorts spikes that are recorded on several channels.
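To see how a radius translates into a channel neighborhood on this probe, the sketch below counts the channels within a given radius of a reference channel. It assumes the 10x10 grid with 15um pitch described above and distances measured between channel centers; each sorter may define its neighborhood slightly differently, and the helper names are hypothetical. Under this assumption the array diagonal is about 191um, so a 200um radius around any channel already includes all 100 channels, while 20um only reaches the 4 nearest neighbors.

```python
import numpy as np

def grid_positions(n_rows=10, n_cols=10, pitch_um=15.0):
    """Channel (x, y) positions in um for a square grid, in row-major order."""
    xx, yy = np.meshgrid(np.arange(n_cols) * pitch_um, np.arange(n_rows) * pitch_um)
    return np.column_stack([xx.ravel(), yy.ravel()])

def n_channels_within(positions, channel_index, radius_um):
    """Number of channels (reference channel included) within radius_um."""
    dist = np.linalg.norm(positions - positions[channel_index], axis=1)
    return int(np.count_nonzero(dist <= radius_um))

positions = grid_positions()
center = 4 * 10 + 4  # channel in row 4, column 4 (near the middle of the array)
corner = 0           # channel in row 0, column 0
for radius in [20, 50, 100, 150, 200]:
    print(radius, n_channels_within(positions, center, radius),
          n_channels_within(positions, corner, radius))
```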

Author: Samuel Garcia, CRNL, Lyon

Requirements

For this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

Installation and imports

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import spikeinterface.comparison as sc
In [ ]:
# clone and install MATLAB sorters

# ironclust
!git clone https://github.com/flatironinstitute/ironclust.git
ironclust_path = './ironclust'
ss.IronClustSorter.set_ironclust_path(ironclust_path)

Check spikeinterface version and sorter version

For reproducibility, let's check the versions of each sorter and of the spikeinterface subpackages.

In [3]:
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.1
  * spikeextractor: 0.7.1
  * spiketoolkit: 0.5.1
  * spikesorters: 0.2.2
  * spikecomparison: 0.2.2
  * spikewidgets: 0.3.2

herdingspikes: 0.3.3+git.f5232ac3520d
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.0
tridesclous: 1.4.2

Set up global paths

In [12]:
# Change this path to point to where you downloaded the dataset
p = Path('/home/samuel/Documents/DataSpikeSorting/mearec/')
study_folder = p / 'study_mearec_SqMEA1015um/'
param_folder = p / 'param_search'

mearec_filename = p / 'recordings_50cells_SqMEA-10-15_600.0_10.0uV_21-01-2020_18-12.h5'

Run sorters with different parameters

Let's try several radii: 20, 50, 100, 150 and 200um.

Results are saved to files.

In [13]:
sorter_names = ['ironclust', 'spykingcircus', 'herdingspikes']
radius_list = [20, 50, 100, 150, 200]
In [ ]:
rec = se.MEArecRecordingExtractor(mearec_filename)

# name of the radius parameter for each sorter
param_names = {
    'ironclust': 'adjacency_radius',
    'spykingcircus': 'adjacency_radius',
    'herdingspikes': 'probe_neighbor_radius',
}

for sorter_name in sorter_names:
    for radius in radius_list:
        params = {param_names[sorter_name]: radius}
        # run the sorter with this radius, then save its output as an .npz file
        sorting = ss.run_sorter(sorter_name, rec,
                                output_folder=param_folder / f'{sorter_name}_{radius}',
                                delete_output_folder=True,
                                **params)
        se.NpzSortingExtractor.write_sorting(sorting, param_folder / f'{sorter_name}_{radius}.npz')

Retrieve results and compare to ground truth

In [14]:
study = sc.GroundTruthStudy(study_folder)
snr = study.get_units_snr()
In [16]:
gt_sorting = se.MEArecSortingExtractor(mearec_filename)

fig, axs = plt.subplots(nrows=len(sorter_names), ncols=len(radius_list), figsize=(20, 10))
for r, sorter_name in enumerate(sorter_names):
    for c, radius in enumerate(radius_list):
        # load the saved sorting and compare it to ground truth
        sorting = se.NpzSortingExtractor(param_folder / f'{sorter_name}_{radius}.npz')
        comp = sc.compare_sorter_to_ground_truth(gt_sorting, sorting)
        perfs = comp.get_performance(method='by_unit')
        # accuracy of each GT unit vs its SNR
        # (assumes snr and perfs list the GT units in the same order)
        axs[r, c].scatter(snr['snr'], perfs['accuracy'], s=10)
    axs[r, 0].set_ylabel(sorter_name)

for c, radius in enumerate(radius_list):
    axs[0, c].set_title(f'radius {radius}')
    axs[-1, c].set_xlabel('SNR')
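The scatter grid can also be summarized as a single table with one row per sorter and one column per radius. A minimal pandas sketch, assuming the per-unit accuracies from the loop above have been collected into a long-format DataFrame with hypothetical columns `sorter_name`, `radius` and `accuracy` (for example by tagging each `perfs` table and concatenating them). The 0.8 cut-off for a well-detected unit is an assumption, and the toy numbers are made up, not results from this dataset.

```python
import pandas as pd

def summarize_sweep(perf_long, well_detected_score=0.8):
    """Mean accuracy and number of well-detected GT units per (sorter, radius)."""
    grouped = perf_long.groupby(["sorter_name", "radius"])["accuracy"]
    summary = grouped.agg(
        mean_accuracy="mean",
        n_well_detected=lambda acc: int((acc >= well_detected_score).sum()),
    )
    return summary.reset_index()

# Toy values for illustration only (not results from this dataset).
toy = pd.DataFrame({
    "sorter_name": ["ironclust"] * 4 + ["herdingspikes"] * 4,
    "radius": [20, 20, 200, 200, 20, 20, 200, 200],
    "accuracy": [0.5, 0.9, 0.85, 0.95, 0.4, 0.6, 0.9, 0.8],
})
table = summarize_sweep(toy)
print(table.pivot(index="sorter_name", columns="radius", values="mean_accuracy"))
```

Reporting both the mean accuracy and the count of well-detected units helps, since a few very poorly sorted units can pull the mean down even when most units are sorted well.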

Conclusion

In this preliminary example, we found that:

  • spykingcircus is hardly affected by the adjacency radius parameter
  • herdingspikes is only affected (worse results) at a very small radius (20um)
  • ironclust tends to prefer larger radii, as the best performance is achieved with 150-200um

Example of ground truth comparison study

Example of ground-truth comparison with SpikeInterface

This notebook shows how SpikeInterface is used to perform a ground-truth comparison for multiple spike sorters.

The dataset is simulated using MEArec and it can be downloaded from Zenodo: https://doi.org/10.5281/zenodo.4058272

The dataset name is: recordings_50cells_SqMEA-10-15_600.0_10.0uV_21-01-2020_18-12.h5. It contains 50 neurons recorded on a 10x10 MEA with 15um pitch. The duration is 600s and the noise level is 10uV.

Author: Samuel Garcia, CRNL, Lyon

Requirements

For this notebook you will need the following Python packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • spikeinterface

To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path


import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw

from spikeinterface.comparison import GroundTruthStudy
In [ ]:
# clone and install MATLAB sorters

# kilosort2
!git clone https://github.com/MouseLand/Kilosort2.git
kilosort2_path = './Kilosort2'
ss.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

# kilosort
!git clone https://github.com/cortex-lab/KiloSort.git
kilosort_path = './KiloSort'
ss.KilosortSorter.set_kilosort_path(kilosort_path)

# ironclust
!git clone https://github.com/flatironinstitute/ironclust.git
ironclust_path = './ironclust'
ss.IronClustSorter.set_ironclust_path(ironclust_path)
In [35]:
%matplotlib inline

# some matplotlib hack to prettify figure
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', figsize=(10.0, 8.0))  # figsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

def clear_axes(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)    

Check spikeinterface version and sorter version

For reproducibility, let's check the versions of each sorter and of the spikeinterface subpackages.

In [5]:
si.print_spikeinterface_version()
ss.print_sorter_versions()
spikeinterface: 0.9.1
  * spikeextractor: 0.7.1
  * spiketoolkit: 0.5.1
  * spikesorters: 0.2.4.dev0
  * spikecomparison: 0.2.2
  * spikewidgets: 0.3.2

herdingspikes: 0.3.3+git.f5232ac3520d
ironclust: 5.7.3
kilosort: git-cd040da1963d
kilosort2: git-67a42a87b866
klusta: 3.0.16
mountainsort4: unknown
spykingcircus: 0.9.2
tridesclous: 1.5.0

Set up global paths

In [6]:
# Change this path to point to where you downloaded the dataset
p = Path('/home/samuel/Documents/DataSpikeSorting/mearec/')
study_folder = p / 'study_mearec_SqMEA1015um/'

Set up the ground-truth study

In [10]:
# load the recording (traces) and the ground-truth sorting (spike trains and units) from the MEArec file
mearec_filename = p / 'recordings_50cells_SqMEA-10-15_600.0_10.0uV_21-01-2020_18-12.h5'
rec0 = se.MEArecRecordingExtractor(mearec_filename)
gt_sorting0 = se.MEArecSortingExtractor(mearec_filename)

# a study can contain several recording/GT sorting pairs;
# here there is only one
gt_dict = {'rec0': (rec0, gt_sorting0)}

study = GroundTruthStudy.create(study_folder, gt_dict)

Run all sorters

In [13]:
sorter_list = ['herdingspikes', 'ironclust', 'kilosort2',  'kilosort',
                'spykingcircus', 'tridesclous'] 
In [ ]:
study = GroundTruthStudy(study_folder)

# here we will take default params for each sorter engine
# but we could change it
sorter_params = {} 

study.run_sorters(sorter_list, sorter_params=sorter_params, mode='keep', verbose=True)

Get signal to noise ratio for all units

In [7]:
study = GroundTruthStudy(study_folder)
snr = study.get_units_snr()
snr.head(10)
Out[7]:
                  snr  rec_name
gt_unit_id
0            5.837606      rec0
1            6.735090      rec0
2            7.249846      rec0
3           21.647741      rec0
4           10.273477      rec0
5           15.612622      rec0
6           13.773396      rec0
7            9.518503      rec0
8           12.626004      rec0
9            6.680142      rec0
In [36]:
fig, ax = plt.subplots()
ax.hist(snr['snr'].values, bins=20)
ax.set_xlabel('GT units SNR')

Run comparison with ground truth and retrieve result tables

In [9]:
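# copy the output of each sorter into a centralized folder
# that holds all the results of the study
study.copy_sortings()

# run all comparisons to ground truth
# (exhaustive_gt=True: the simulated ground truth contains every unit,
# so unmatched sorted units can be counted as false positives)
study.run_comparisons(exhaustive_gt=True, match_score=0.1, overmerged_score=0.2)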
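`run_comparisons` matches each ground-truth unit to sorted units with an agreement score: the number of matched spikes divided by the number of spikes in either unit, n_match / (n_gt + n_sorted - n_match), so `match_score=0.1` sets the minimum agreement for two units to be matched. The sketch below reimplements this score with a greedy one-to-one matching of spike times within a tolerance window; the greedy matching, the helper names and the 10-sample default tolerance are simplifying assumptions, not SpikeInterface's exact implementation.

```python
import numpy as np

def count_matching_events(times1, times2, delta):
    """Greedy one-to-one matching of spike times that differ by at most `delta` samples."""
    t1, t2 = np.sort(times1), np.sort(times2)
    i = j = n_match = 0
    while i < len(t1) and j < len(t2):
        diff = t2[j] - t1[i]
        if abs(diff) <= delta:
            n_match += 1
            i += 1
            j += 1
        elif diff < 0:
            j += 1  # spike of train 2 is too early to match: skip it
        else:
            i += 1  # spike of train 1 is too early to match: skip it
    return n_match

def agreement_score(times1, times2, delta=10):
    """n_match / (n1 + n2 - n_match)."""
    n_match = count_matching_events(times1, times2, delta)
    n_total = len(times1) + len(times2) - n_match
    return n_match / n_total if n_total else 0.0

gt_times = np.array([100, 200, 300, 400])
sorted_times = np.array([102, 205, 350, 401, 500])
print(agreement_score(gt_times, sorted_times))  # -> 0.5 (3 matches out of 6 distinct events)
```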
In [10]:
# retrieve the comparison objects and the aggregated result tables
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()

Run times

In [11]:
dataframes['run_times']
Out[11]:
   rec_name    sorter_name    run_time
0      rec0       kilosort  359.392333
1      rec0  spykingcircus  742.992190
2      rec0  herdingspikes  148.370694
3      rec0      ironclust  128.119042
4      rec0      kilosort2  148.978660
5      rec0    tridesclous  616.314811
In [37]:
# set the palette before plotting so that it applies to the bars
sns.set_palette(sns.color_palette("Set1"))

fig, ax = plt.subplots()
sns.barplot(data=dataframes['run_times'], x='sorter_name', y='run_time', ax=ax, order=sorter_list)
ax.set_ylabel('Run time (s)');
ax.set_xlabel(None);
ax.set_xticklabels(sorter_list, rotation=40, ha='right');

Agreement matrix

In [38]:
for (rec_name, sorter_name), comp in comparisons.items():
    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax)
    fig.suptitle(rec_name+'   '+ sorter_name)

Accuracy/precision/recall scores per sorter

In [44]:
sns.set_palette(sns.color_palette("Set1"))

df = pd.melt(dataframes['perf_by_units'], id_vars='sorter_name', var_name='metric', value_name='score', 
        value_vars=('accuracy','precision', 'recall'))

g = sns.catplot(data=df, x='sorter_name', y='score', hue='metric', kind='swarm', dodge=True,
                order=sorter_list, legend_out=True, s=4)
g.fig.set_size_inches(15,8)
ax = g.axes[0, 0]
#ax.legend(['Accuracy', 'Precision', 'Recall'], labelspacing=0.2,
#          bbox_to_anchor=(1, 0.5), loc=2, borderaxespad=0., frameon=False)
ax.set_xticklabels(sorter_list, rotation=30, ha='center')
ax.set_xlabel(None)
ax.set_ylabel('Score')
clear_axes(ax)

A wise man (to be honest, a reviewer of our paper) suggested to us that the swarm plot is not very informative for benchmarking sorters. This is true: plotting precision against recall lets us check the balance between false positive spikes and false negative spikes. Here is another possible way to present the results.

In [21]:
sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=dataframes['perf_by_units'], x='precision', y='recall', col='sorter_name', 
                col_wrap=3, col_order=sorter_list, s=60)
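The scores in both plots derive from per-unit spike counts: true positives (TP), false negatives (FN, missed spikes) and false positives (FP, extra spikes). Precision is TP / (TP + FP), recall is TP / (TP + FN) and accuracy is TP / (TP + FN + FP), so accuracy can never exceed either precision or recall. A minimal sketch with made-up counts; `unit_scores` is a hypothetical helper.

```python
def unit_scores(tp, fn, fp):
    """Return precision, recall and accuracy for one matched GT unit."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = tp / (tp + fn + fp) if tp + fn + fp else 0.0
    return {"precision": precision, "recall": recall, "accuracy": accuracy}

# Made-up counts: few extra spikes but many missed ones, so this unit
# would sit at high precision and low recall in the plot above.
print(unit_scores(tp=60, fn=40, fp=5))
```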

Count units

  • well detected
  • false positive
  • redundant
  • overmerged
In [34]:
sns.set_palette(sns.color_palette("Set2"))

df = pd.melt(dataframes['count_units'], id_vars='sorter_name', var_name='metric', value_name='score', 
        value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
g = sns.catplot(x='sorter_name', y='score', hue='metric', data=df,
                height=6, kind="bar", order=sorter_list)

ax = g.axes[0,0]
ax.set_xticklabels(sorter_list, rotation=30, ha='right')

ax.legend(bbox_to_anchor=(0.9, 0.95), borderaxespad=0., frameon=False, labelspacing=0.2)
for t, l in zip(ax.legend_.texts,("Well detected", "False positive", "Redundant", "Overmerged")):
    t.set_text(l)


ax.set_xlabel(None)
ax.set_ylabel('Number of units')
clear_axes(ax)
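The four categories can be approximated from the agreement matrix between GT units and sorted units. The sketch below is a simplified, hypothetical version: a GT unit is well detected if its accuracy reaches 0.8; a sorted unit is overmerged if it agrees above `overmerged_score` with at least two GT units, redundant if it is not the best match of any GT unit but still agrees above `redundant_score` with one, and a false positive if it matches no GT unit above `match_score`. These rules and thresholds are assumptions and may differ from SpikeInterface's implementation; `match_score` and `overmerged_score` mirror the values passed to `run_comparisons` above.

```python
import numpy as np

def count_units(agreement, accuracy, match_score=0.1, well_detected_score=0.8,
                redundant_score=0.2, overmerged_score=0.2):
    """Simplified unit counts from a (n_gt, n_sorted) agreement matrix.

    `accuracy` holds, for each GT unit, the accuracy of its best match.
    Categories are made mutually exclusive here for simplicity.
    """
    agreement = np.asarray(agreement, dtype=float)
    accuracy = np.asarray(accuracy, dtype=float)

    # Sorted units that are the best match of at least one GT unit.
    best = agreement.argmax(axis=1)
    has_match = agreement.max(axis=1) >= match_score
    best_matches = set(best[has_match].tolist())

    counts = {
        "num_well_detected": int(np.sum(accuracy >= well_detected_score)),
        "num_false_positive": 0,
        "num_redundant": 0,
        "num_overmerged": 0,
    }
    for k in range(agreement.shape[1]):
        col = agreement[:, k]
        if np.sum(col > overmerged_score) >= 2:
            counts["num_overmerged"] += 1      # agrees with several GT units
        elif k not in best_matches and col.max() > redundant_score:
            counts["num_redundant"] += 1       # duplicates a GT unit matched elsewhere
        elif col.max() < match_score:
            counts["num_false_positive"] += 1  # matches no GT unit
    return counts

# Toy agreement matrix: 3 GT units (rows) x 5 sorted units (columns).
agreement = np.array([
    [0.9, 0.0, 0.3, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0, 0.0],
    [0.0, 0.4, 0.0, 0.0, 0.05],
])
print(count_units(agreement, accuracy=[0.9, 0.5, 0.4]))
```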

Accuracy vs SNR

In [30]:
df = dataframes['perf_by_units']

# add the SNR of each GT unit to the by-unit table
# (assigning through df.loc avoids pandas chained assignment)
df['snr'] = np.nan
for gt_id in snr.index:
    df.loc[df['gt_unit_id'] == gt_id, 'snr'] = snr.at[gt_id, 'snr']

sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=df, x='snr', y='accuracy', col='sorter_name',
                col_wrap=3, col_order=sorter_list, s=60)
for i, ax in enumerate(g.axes):
    ax.set_title(sorter_list[i])
    ax.set_xlabel('')

g.axes[0].set_xlabel('SNR')
g.axes[0].set_ylabel('Accuracy');