Spampinato mice retina mea252ch pair recording - part2
Spampinato mice retina mea252ch pair recording - 2¶
Part 2) Ground-truth comparison¶
In this set of notebooks, the dataset comes from paired juxtacellular/extracellular recordings of mouse retina in vitro. The MEA has 252 channels.
The official publication of this open dataset can be found at the following address: https://zenodo.org/record/1205233#.W9mq1HWLTIF
These datasets were used by Pierre Yger et al. in the following "Spyking Circus" paper: https://elifesciences.org/articles/34518
After inspecting the juxtacellular data, we found that some recordings are not of good enough quality to be considered "ground truth". To qualify as "ground truth", a unit is required to be stable in detection, peak signal-to-noise ratio (SNR) and amplitude.
At the end of our quality assessment ("spampinato-mice-retina-mea252ch-pair-recording-part1"), some files were removed from this main study.
Author: Samuel Garcia, CRNL, Lyon
Requirements¶
For this notebook you will need the following Python packages:
- numpy
- pandas
- matplotlib
- seaborn
- spikeinterface
To run the MATLAB-based sorters, you would also need a MATLAB license. For other sorters, please refer to the documentation on how to install sorters.
Installation and imports¶
- create a folder basedir
- download all files from the zenodo link
- move them into a subfolder basedir/original_files (20160415_patch2.tar.gz, ...)
- and then execute this notebook cell by cell
# import everything
import os, getpass
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
from spikeinterface.comparison import GroundTruthStudy
# clone and install MATLAB sorters
# Each repository is cloned into the current working directory, and the
# matching spikeinterface sorter class is then pointed at that checkout.
# NOTE(review): `!git clone` is an IPython shell escape, so this cell only runs
# inside IPython/Jupyter. Re-running it when the folder already exists makes
# git report an error; presumably harmless, since the path is still set
# afterwards. TODO confirm.
# kilosort2
!git clone https://github.com/MouseLand/Kilosort2.git
kilosort2_path = './Kilosort2'
ss.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)
# kilosort
!git clone https://github.com/cortex-lab/KiloSort.git
kilosort_path = './KiloSort'
ss.KilosortSorter.set_kilosort_path(kilosort_path)
# ironclust
!git clone https://github.com/flatironinstitute/ironclust.git
ironclust_path = './ironclust'
ss.IronclustSorter.set_ironclust_path(ironclust_path)
# print library / sorter versions so the results can be tied to a setup
si.print_spikeinterface_version()
ss.print_sorter_versions()
# --- paths (adapt basedir to your machine) ----------------------------------
basedir = '/media/samuel/dataspikesorting/DataSpikeSortingHD2/Pierre/zenodo/'
# raw MEA + juxtacellular files, one subfolder per recording
recording_folder = f'{basedir}original_files/'
# ground-truth spike indexes, one subfolder per recording
ground_truth_folder = f'{basedir}ground_truth/'
# where the GroundTruthStudy (copied data and results) is written
study_folder = f'{basedir}study_gt252/'

# --- sorters -----------------------------------------------------------------
sorter_list = ['tridesclous']

# --- recordings kept after the part-1 quality assessment (8 out of 19) -------
rec_names = [
    '20160415_patch2',
    '20170803_patch1',
    '20160426_patch3',
    '20170725_patch1',
    '20170621_patch1',
    '20160426_patch2',
    '20170728_patch2',
    '20170713_patch1',
]
%matplotlib inline
# some matplotlib hack to prettify figure
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16
plt.rc('font', size=SMALL_SIZE) # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
def clear_axes(ax):
    """Hide the top and right spines of *ax*, leaving an open-frame look."""
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
Step 1 : Clean original dataset¶
Setup study¶
In this step:
- we create a dict of (recordings+sorting)
- and call the function
GroundTruthStudy.create(study_folder, gt_dict)
Internally, spikeinterface copies all recordings and ground-truth sorting into an organised folder.
Important note:
- the files have 256 channels but only 252 are useful. The PRB file contains all the needed channels, so we need to explicitly use grouping_property='group' to be sure to only take into account the channels in the unique group.
- This step only has to run once and takes some time because of the copy.
# Build {rec_name: (recording, ground-truth sorting)} and create the study.
# The raw MEA files are read as 256-channel uint16 data at 20 kHz.
sampling_frequency = 20000.


def _find_mea_raw_file(dirname):
    """Return the path of the extracellular ``.raw`` file inside *dirname*.

    Files ending in ``juxta.raw`` are ignored. Exactly one candidate must
    remain. Failing loudly avoids two problems: silently reusing the previous
    recording's file when none is found, and picking an arbitrary file (from
    ``os.listdir`` order) when several are found.

    Raises:
        FileNotFoundError: no candidate file in *dirname*.
        ValueError: more than one candidate file in *dirname*.
    """
    candidates = sorted(
        name for name in os.listdir(dirname)
        if name.endswith('.raw') and not name.endswith('juxta.raw')
    )
    if not candidates:
        raise FileNotFoundError(f'no extracellular .raw file found in {dirname}')
    if len(candidates) > 1:
        raise ValueError(f'several extracellular .raw files in {dirname}: {candidates}')
    return dirname + candidates[0]


def _read_raw_offset(header_filename):
    """Return the first ``padding = <n>`` value of a raw file's text header.

    Raises:
        ValueError: the header contains no ``padding = <n>`` entry.
    """
    with open(header_filename, mode='r') as header_file:
        match = re.search(r'padding = (\d+)', header_file.read())
    if match is None:
        raise ValueError(f'no "padding = <n>" entry found in {header_filename}')
    return int(match.group(1))


gt_dict = {}
for rec_name in rec_names:
    dirname = recording_folder + rec_name + '/'
    mea_filename = _find_mea_raw_file(dirname)

    # raw files have an internal offset that depends on the channel count;
    # it is stored in the sibling .txt header (only the extension is swapped)
    offset = _read_raw_offset(os.path.splitext(mea_filename)[0] + '.txt')

    # recording
    rec = se.BinDatRecordingExtractor(mea_filename, sampling_frequency, 256, 'uint16',
                                      offset=offset, time_axis=0)
    # this reduces the channel count to 252
    rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

    # gt sorting: one unit (label 0) built from the juxtacellular peak indexes
    gt_indexes = np.fromfile(ground_truth_folder + rec_name + '/juxta_peak_indexes.raw',
                             dtype='int64')
    sorting_gt = se.NumpySortingExtractor()
    sorting_gt.set_times_labels(gt_indexes, np.zeros(gt_indexes.size, dtype='int64'))
    sorting_gt.set_sampling_frequency(sampling_frequency)
    gt_dict[rec_name] = (rec, sorting_gt)

study = GroundTruthStudy.create(study_folder, gt_dict)
Get the signal-to-noise ratio (SNR) for all units¶
study = GroundTruthStudy(study_folder)
snr = study.concat_all_snr()
snr
# Histogram of GT-unit SNRs in bins of width 5 covering 0..40.
# np.arange excludes its stop value, so the stop must be 40 + bin width;
# otherwise the last edge is 35 and units with SNR in (35, 40] are dropped.
# Values outside [0, 40] are still not counted.
snr_bin_width = 5
fig, ax = plt.subplots()
ax.hist(snr['snr'].values, bins=np.arange(0, 40 + snr_bin_width, snr_bin_width))
ax.set_xlabel('GT units SNR')
Run all sorters¶
# Sorters run on every recording of the study (overrides the single-sorter
# list set in the configuration cell).
sorter_list = [
    'herdingspikes',
    'ironclust',
    'kilosort2',
    'spykingcircus',
    'tridesclous',
]
study = GroundTruthStudy(study_folder)
# NOTE(review): mode='keep' presumably reuses sorter outputs that already
# exist in the study folder. TODO confirm against the spikeinterface docs.
study.run_sorters(sorter_list, verbose=False, mode='keep')
Run comparison with ground truth and retrieve result tables¶
# Gather each sorter's results into the study's central folder; this is
# necessary before running the comparisons.
study.copy_sortings()

# Compare every sorting with the ground truth. These are pair recordings with
# a single GT unit each, hence exhaustive_gt=False.
study.run_comparisons(
    exhaustive_gt=False,
    overmerged_score=0.2,
    match_score=0.1,
)

# Retrieve the comparison objects and the aggregated result tables.
comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
## Run times
# rows indexed by recording, sorters spread across the columns
run_times_table = dataframes['run_times'].set_index(['rec_name', 'sorter_name'])
run_times_table.unstack(level='sorter_name')
sns.set_palette(sns.color_palette("Set1"))
# Bar plot of run time per sorter, one bar per recording.
g = sns.catplot(data=dataframes['run_times'], x='sorter_name', y='run_time',
                hue="rec_name",
                order=sorter_list,
                kind='bar',
                legend=False)
g.fig.set_size_inches(12, 5)
# Rotate the sorter names on THIS catplot's axes. The global `fig` belongs to
# the earlier SNR histogram, so using it would relabel the wrong figure.
g.fig.axes[0].set_xticklabels(sorter_list, rotation=40, ha='right');
Accuracy/precision/recall scores per sorter¶
sns.set_palette(sns.color_palette("Set1"))

# Long format: one row per (unit, metric); the metric name goes in 'metric'
# and its value in 'score'.
df = dataframes['perf_by_units'].melt(
    id_vars=['rec_name', 'sorter_name'],
    value_vars=('accuracy', 'precision', 'recall'),
    var_name='metric',
    value_name='score',
)
display(df)

# Swarm plot of the three scores per sorter, dodged by metric.
g = sns.catplot(data=df, kind='swarm', x='sorter_name', y='score',
                hue='metric', dodge=True, order=sorter_list,
                legend_out=True, s=4)
g.fig.set_size_inches(12, 5)

# Precision vs recall scatter, one panel per sorter.
sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=dataframes['perf_by_units'], x='precision', y='recall',
                col='sorter_name', col_order=sorter_list, col_wrap=3, s=60)
Accuracy vs SNR¶
df = dataframes['perf_by_units']
df
# Attach each ground-truth unit's SNR to the per-unit performance table.
# df is the same object as dataframes['perf_by_units'], so the new column is
# added to that table too.
df = dataframes['perf_by_units']
# float column; rows with no matching (rec_name, gt_unit_id) stay NaN
df['snr'] = np.nan
for rec_name, gt_id in snr.index:
    unit_rows = (df['gt_unit_id'] == gt_id) & (df['rec_name'] == rec_name)
    # Use a single .loc[rows, column] assignment. The chained form
    # df['snr'].loc[...] = ... writes into a temporary under pandas
    # Copy-on-Write and never reaches df.
    df.loc[unit_rows, 'snr'] = snr.at[(rec_name, gt_id), 'snr']

sns.set_palette(sns.color_palette("deep"))
g = sns.relplot(data=df, x='snr', y='accuracy', col='sorter_name',
                col_wrap=3, col_order=sorter_list, s=80)