spikeinterface peak localization

peak localization in spikeinterface

spikeinterface includes several methods for unit or peak localization:

  • 'center_of_mass': classic and fast localization (used, for instance, by herdingspikes). It is quite accurate on square MEAs but has strong artifacts when units lie on the border of the probe, so on linear probes it gives poor results on the x axis. (A toy sketch of this weighted average is given just after this list.)
  • 'monopolar_triangulation' with optimizer='least_square': this method is from Julien Boussard and Erdem Varol from the Paninski lab. It has been presented at NeurIPS, see also here.
  • 'monopolar_triangulation' with optimizer='minimize_with_log_penality': an improvement on the previous method by the same team, not published yet.
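As a toy illustration of the 'center_of_mass' principle (a sketch only, not the spikeinterface implementation): the estimated position is simply the average of the channel positions weighted by the peak amplitude on each channel. The arrays below are hypothetical values for a few local channels.

import numpy as np

# hypothetical peak amplitudes on 4 neighboring channels and their (x, y) positions in um
amplitudes = np.abs(np.array([12., 55., 80., 40.]))
channel_positions = np.array([[0., 0.], [0., 20.], [0., 40.], [0., 60.]])

# weighted average of channel positions = center of mass
com = (amplitudes[:, None] * channel_positions).sum(axis=0) / amplitudes.sum()
print(com)  # -> roughly [0., 35.8]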

Here is an example of how to use them.

In [1]:
%load_ext autoreload
%autoreload 2
In [4]:
from pathlib import Path

import spikeinterface.full as si

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20, 12)
In [5]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP1'
preprocess_folder = base_folder / 'dataset1_NP1_preprocessed'
peak_folder = base_folder / 'dataset1_NP1_peaks'

peak_folder.mkdir(exist_ok=True)
In [6]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_duration='1s',
    progress_bar=True,
)
In [11]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[11]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1957.191s
In [12]:
fig, ax = plt.subplots()
si.plot_probe_map(rec, ax=ax)
ax.set_ylim(-150, 200)
Out[12]:
(-150.0, 200.0)

preprocess

This takes about 4 min for 30 min of signal.

In [13]:
if not preprocess_folder.exists():
    rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
    rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
    rec_preprocessed.save(folder=preprocess_folder, **job_kwargs)
rec_preprocessed = si.load_extractor(preprocess_folder)
write_binary_recording with n_jobs 40  chunk_size 30000
write_binary_recording: 100%|██████████| 1958/1958 [03:09<00:00, 10.34it/s]
In [14]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[14]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f894ea61af0>

estimate noise

In [15]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=np.arange(0,10, 1))
ax.set_title('noise across channel')
Out[15]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

This takes about 1 min 10 s.

In [16]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [18]:
if not (peak_folder / 'peaks.npy').exists():
    peaks = detect_peaks(
        rec_preprocessed,
        method='locally_exclusive',
        local_radius_um=100,
        peak_sign='neg',
        detect_threshold=5,
        n_shifts=5,
        noise_levels=noise_levels,
        **job_kwargs,
    )
    np.save(peak_folder / 'peaks.npy', peaks)
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
detect peaks: 100%|██████████| 1958/1958 [01:11<00:00, 27.26it/s]
(4041179,)

localize peaks

We use 3 methods:

  • 'center_of_mass': 9 s
  • 'monopolar_triangulation' legacy (least_square): 26 min
  • 'monopolar_triangulation' log barrier: 16 min
In [19]:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
In [21]:
if not (peak_folder / 'peak_locations_center_of_mass.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='center_of_mass',
        method_kwargs={'local_radius_um': 100.},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_center_of_mass.npy', peak_locations)
    print(peak_locations.shape)
localize peaks: 100%|██████████| 1958/1958 [00:08<00:00, 218.72it/s]
(4041179,)
In [36]:
if not (peak_folder / 'peak_locations_monopolar_triangulation_legacy.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'least_square'},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_monopolar_triangulation_legacy.npy', peak_locations)
    print(peak_locations.shape)
localize peaks: 100%|██████████| 1958/1958 [26:42<00:00,  1.22it/s] 
(4041179,)
In [23]:
if not (peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'minimize_with_log_penality'},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy', peak_locations)
    print(peak_locations.shape)
localize peaks: 100%|██████████| 1958/1958 [16:15<00:00,  2.01it/s]
(4041179,)
In [24]:
# load back
# peak_locations = np.load(peak_folder / 'peak_locations_center_of_mass.npy')
# peak_locations = np.load(peak_folder / 'peak_locations_monopolar_triangulation_legacy.npy')
peak_locations = np.load(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy')
print(peak_locations)
[( 18.81849235, 1782.84538913,  78.17532357, 1696.96239445)
 ( 31.90279769, 3847.75061369, 134.79844077, 1716.03155721)
 (-23.12038001, 2632.87834759,  87.76916268, 2633.62546695) ...
 ( 40.0839554 , 1977.83852796,  26.50998809, 1092.53885299)
 (-51.40036701, 1772.34521905, 170.65660676, 2533.03617278)
 ( 54.3813594 , 1182.28971165,  87.35020554, 1303.53392431)]

plot on probe

In [38]:
from probeinterface.plotting import plot_probe  # plot_probe() is used below

for name in ('center_of_mass', 'monopolar_triangulation_legacy', 'monopolar_triangulation_log_limit'):

    peak_locations = np.load(peak_folder / f'peak_locations_{name}.npy')

    probe = rec_preprocessed.get_probe()

    fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
    ax = axs[0]
    plot_probe(probe, ax=ax)
    ax.scatter(peak_locations['x'], peak_locations['y'], color='k', s=1, alpha=0.002)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(name)
    if 'z' in peak_locations.dtype.fields:
        ax = axs[1]
        ax.scatter(peak_locations['z'], peak_locations['y'], color='k', s=1, alpha=0.002)
        ax.set_xlabel('z')
    ax.set_ylim(1500, 2500)

plot peak depth vs time

In [39]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
ax.set_ylim(1300, 2500)
Out[39]:
(1300.0, 2500.0)

conclusion

spikeinterface clustering

spikeinterface clustering

The clustering step remains the central step of spike sorting. Historically, this step was separated into two distinct parts: feature reduction and clustering. In spikeinterface, we decided to regroup these two steps in the same module. This allows one to compute feature reduction on-the-fly and avoid long computations and the storage of large features.

The clustering step takes the recording and detected (and optionally selected) peaks as input and returns a label for every peak.

At the moment, the implementation is quite experimental. The following methods have been implemented:

  • "position_clustering": use HDBSCAN on peak locations.
  • "sliding_hdbscan": clustering approach from tridesclous, with sliding spatial windows. PCA and HDBSCAN are run on local/sparse waveforms.
  • "position_pca_clustering": this method tries to use peak locations for a first clustering step and then perform further splits using PCA + HDBSCAN

Different methods may need different inputs (for instance, some of them require peak locations and some do not). A minimal sketch of the 'position_clustering' idea is given below.
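As a minimal, hypothetical sketch of the 'position_clustering' idea (not the spikeinterface implementation), one can run HDBSCAN directly on the estimated peak positions. This assumes the hdbscan package is installed and that peak_locations is the structured array returned by localize_peaks().

import numpy as np
import hdbscan

# stack the x/y fields of the structured array into a (n_peaks, 2) matrix
positions = np.column_stack([peak_locations['x'], peak_locations['y']])

clusterer = hdbscan.HDBSCAN(min_cluster_size=20, allow_single_cluster=True)
peak_labels = clusterer.fit_predict(positions)   # -1 means noise / unassigned
possible_labels = np.unique(peak_labels[peak_labels >= 0])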

For this we will use a dataset simulated with MEArec on a 32-channel Neuronexus-like probe.

Here we will also use the select_peaks() function to subsample a smaller number of peaks.

In [1]:
# %matplotlib widget
%matplotlib inline
In [2]:
%load_ext autoreload
%autoreload 2
In [3]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [18]:
base_folder = Path('/mnt/data/sam/DataSpikeSorting/mearec_template_matching')
mearec_file = base_folder / 'recordings_collision_15cells_Neuronexus-32_1800s.h5'
rec_folder = base_folder /'Preprocessed_recording_15cells_Neuronexus-32_1800s'
peak_folder = base_folder / 'Peak_recording_15cells_Neuronexus'

clustering_path = base_folder / 'Clustering_recording_15cells_Neuronexus'

peak_folder.mkdir(exist_ok=True)
In [5]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_duration='1s',
    progress_bar=True,
)

Preprocess

In [6]:
# load already cache or compute
if rec_folder.exists():
    rec_preprocessed = si.load_extractor(rec_folder)
else:
    recording, gt_sorting = si.read_mearec(mearec_file)
    recording = si.bandpass_filter(recording, dtype='float32')
    recording = si.common_reference(recording)
    rec_preprocessed = recording.save(folder=rec_folder, n_jobs=20, chunk_size=30000, progress_bar=True)

estimate noise

In [7]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=15)
ax.set_title('noise across channel')
Out[7]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

In [8]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [9]:
if not (peak_folder / 'peaks.npy').exists():
    peaks = detect_peaks(
        rec_preprocessed,
        method='locally_exclusive',
        local_radius_um=100,
        peak_sign='neg',
        detect_threshold=5,
        n_shifts=10,
        noise_levels=noise_levels,
        **job_kwargs,
    )
    np.save(peak_folder / 'peaks.npy', peaks)
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
(271519,)

select peaks

In [10]:
from spikeinterface.sortingcomponents.peak_selection import select_peaks
In [11]:
if not (peak_folder / 'some_peaks.npy').exists():
    some_peaks = select_peaks(peaks, method='uniform', select_per_channel=True, n_peaks=500, seed=None)
    np.save(peak_folder / 'some_peaks.npy', some_peaks)
some_peaks = np.load(peak_folder / 'some_peaks.npy')
print('some_peaks.size', some_peaks.size)
some_peaks.size 13424

localize peaks (on sub selection)

In [12]:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
In [14]:
some_peak_locations = localize_peaks(rec_preprocessed, some_peaks,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'minimize_with_log_penality'},
        # method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'least_square'},
        **job_kwargs)
np.save(peak_folder / 'some_peak_locations.npy', some_peak_locations)
localize peaks: 100%|██████████| 1800/1800 [00:03<00:00, 513.78it/s]
In [15]:
some_peak_locations = np.load(peak_folder / f'some_peak_locations.npy')

probe = rec_preprocessed.get_probe()

fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
ax = axs[0]
si.plot_probe_map(rec_preprocessed, ax=ax)
ax.scatter(some_peak_locations['x'], some_peak_locations['y'], color='k', s=1, alpha=0.2)
ax.set_xlabel('x')
ax.set_ylabel('y')
if 'z' in some_peak_locations.dtype.fields:
    ax = axs[1]
    ax.scatter(some_peak_locations['z'], some_peak_locations['y'], color='k', s=1, alpha=0.2)
    ax.set_xlabel('z')
# ax.set_ylim(1500, 2500)

clustering (on sub selection)

In [19]:
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
In [43]:
method_kwargs = dict(
    peak_locations=some_peak_locations,
    hdbscan_params_spatial = {"min_cluster_size" : 20,  "allow_single_cluster" : True, 'metric' : 'l2'},
    probability_thr = 0,
    apply_norm=True,
    #~ debug=True,
    debug=False,
    tmp_folder=clustering_path,
    n_components_by_channel=4,
    n_components=4,
    job_kwargs = {"n_jobs" : 2, "chunk_size" : 30000, "progress_bar" : True},
    waveform_mode="shared_memory",
    #~ waveform_mode="memmap",
)

t0 = time.perf_counter()
possible_labels, peak_labels = find_cluster_from_peaks(rec_preprocessed, some_peaks, 
        method='position_pca_clustering', method_kwargs=method_kwargs)
t1 = time.perf_counter()
print('position_pca_clustering', t1 -t0)
extract waveforms shared_memory: 100%|██████████| 1800/1800 [00:00<00:00, 4997.14it/s]
extract waveforms shared_memory: 100%|██████████| 1800/1800 [00:00<00:00, 5033.09it/s]
position_pca_clustering 12.23708628397435
In [44]:
print(possible_labels)
[ 3  4  6 10 11 14 16 18 19 20 22 28 29 32 33]
In [45]:
import distinctipy
def plot_cluster_on_probe(rec, possible_labels, peak_labels):
    possible_colors = distinctipy.get_colors(possible_labels.size)

    colors = np.zeros((peak_labels.size, 3))
    for i, k in enumerate(possible_labels):
        mask = peak_labels == k
        colors[mask, :] = possible_colors[i]

    fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
    ax = axs[0]
    si.plot_probe_map(rec, ax=ax)
    ax.scatter(some_peak_locations['x'], some_peak_locations['y'], s=1, c=colors, alpha=0.5)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if 'z' in some_peak_locations.dtype.fields:
        ax = axs[1]
        ax.scatter(some_peak_locations['z'], some_peak_locations['y'], s=1, c=colors, alpha=0.5)
        ax.set_xlabel('z')
In [46]:
plot_cluster_on_probe(rec_preprocessed, possible_labels, peak_labels)

let's try another method

In [47]:
method_kwargs = dict(
)

t0 = time.perf_counter()
possible_labels, peak_labels = find_cluster_from_peaks(rec_preprocessed, some_peaks, 
        method='sliding_hdbscan', method_kwargs=method_kwargs)
t1 = time.perf_counter()
print('sliding_hdbscan', t1 - t0)
sliding_hdbscan 19.997920085676014
In [48]:
print(possible_labels)
[ 1  2  3  4  5  6  7  8  9 10 11 16 20]
In [49]:
plot_cluster_on_probe(rec_preprocessed, possible_labels, peak_labels)

spikeinterface motion estimation / correction

motion estimation in spikeinterface

In 2021, the SpikeInterface project started to implement sortingcomponents, a modular module for spike sorting steps.

Here is an overview of our progress integrating motion (aka drift) estimation and correction.

This notebook is based on the open dataset from Nick Steinmetz published in 2021, "Imposed motion datasets" from Steinmetz et al. Science 2021: https://figshare.com/articles/dataset/_Imposed_motion_datasets_from_Steinmetz_et_al_Science_2021/14024495

The motion estimation is done in several modular steps:

  1. detect peaks
  2. localize peaks
  3. estimate motion:
    • rigid or non-rigid
    • "decentralize" by Erdem Varol and Julien Boussard (DOI: 10.1109/ICASSP39728.2021.9414145)
    • "motion cloud" by Julien Boussard (not implemented yet)

Here we will show this chain:

  • detect peaks > localize peaks with "monopolar_triangulation" > estimate motion with "decentralize" (a toy numpy sketch of this idea follows)
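Before running the real thing, here is a toy numpy sketch of the "decentralize" idea (an illustration under simplifying assumptions, not the spikeinterface implementation): a displacement is estimated between every pair of time bins by cross-correlating their activity histograms, and a rigid motion is then obtained as the least-squares solution of all pairwise constraints.

import numpy as np

def toy_pairwise_displacement(motion_hist, bin_um):
    # motion_hist: (n_time_bins, n_spatial_bins) peak-count histogram
    n_t, n_s = motion_hist.shape
    lags = (np.arange(2 * n_s - 1) - (n_s - 1)) * bin_um
    D = np.zeros((n_t, n_t))
    for i in range(n_t):
        for j in range(n_t):
            cc = np.correlate(motion_hist[i], motion_hist[j], mode='full')
            D[i, j] = lags[np.argmax(cc)]  # sign depends on the correlation convention
    return D

def toy_rigid_motion(D):
    # least-squares solution of p_i - p_j ~= D_ij, up to a constant offset
    p = D.mean(axis=1)
    return p - p.mean()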
In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
from pathlib import Path

import spikeinterface.full as si

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20, 12)

from probeinterface.plotting import plot_probe
In [3]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP1'
preprocess_folder = base_folder / 'dataset1_NP1_preprocessed'
peak_folder = base_folder / 'dataset1_NP1_peaks'

peak_folder.mkdir(exist_ok=True)
In [4]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_memory='10M',
    progress_bar=True,
)
In [5]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[5]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1957.191s
In [7]:
fig, ax = plt.subplots()
si.plot_probe_map(rec, ax=ax)
ax.set_ylim(-150, 200)
Out[7]:
(-150.0, 200.0)

preprocess

This takes about 4 min for 30 min of signal.

In [8]:
if not preprocess_folder.exists():
    rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
    rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
    rec_preprocessed.save(folder=preprocess_folder, **job_kwargs)
rec_preprocessed = si.load_extractor(preprocess_folder)
In [9]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[9]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f72adda0520>

estimate noise

In [13]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=10)
ax.set_title('noise across channel')
Out[13]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

This takes about 1 min 30 s.

In [11]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [14]:
if not (peak_folder / 'peaks.npy').exists():
    peaks = detect_peaks(
        rec_preprocessed,
        method='locally_exclusive',
        local_radius_um=100,
        peak_sign='neg',
        detect_threshold=5,
        n_shifts=5,
        noise_levels=noise_levels,
        **job_kwargs,
    )
    np.save(peak_folder / 'peaks.npy', peaks)
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
(4041179,)

localize peaks

Here we choose 'monopolar_triangulation' with the log-barrier optimizer.

In [18]:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
In [16]:
if not (peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'minimize_with_log_penality'},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy', peak_locations)
    print(peak_locations.shape)
peak_locations = np.load(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy')
In [17]:
print(peak_locations)
[( 18.81849235, 1782.84538913,  78.17532357, 1696.96239445)
 ( 31.90279769, 3847.75061369, 134.79844077, 1716.03155721)
 (-23.12038001, 2632.87834759,  87.76916268, 2633.62546695) ...
 ( 40.0839554 , 1977.83852796,  26.50998809, 1092.53885299)
 (-51.40036701, 1772.34521905, 170.65660676, 2533.03617278)
 ( 54.3813594 , 1182.28971165,  87.35020554, 1303.53392431)]

plot on probe

In [22]:
fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
ax = axs[0]
si.plot_probe_map(rec_preprocessed, ax=ax)
ax.scatter(peak_locations['x'], peak_locations['y'], color='k', s=1, alpha=0.002)
ax.set_xlabel('x')
ax.set_ylabel('y')
if 'z' in peak_locations.dtype.fields:
    ax = axs[1]
    ax.scatter(peak_locations['z'], peak_locations['y'], color='k', s=1, alpha=0.002)
    ax.set_xlabel('z')
    ax.set_xlim(0, 150)
ax.set_ylim(1800, 2500)
Out[22]:
(1800.0, 2500.0)

plot peak depth vs time

In [23]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
ax.set_ylim(1300, 2500)
Out[23]:
(1300.0, 2500.0)

motion estimation: rigid, with the decentralized method

In [25]:
from spikeinterface.sortingcomponents.motion_estimation import (
    estimate_motion,
    make_motion_histogram,
    compute_pairwise_displacement,
    compute_global_displacement
)
In [45]:
bin_um = 5
bin_duration_s=5.

motion_histogram, temporal_bins, spatial_bins = make_motion_histogram(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations, 
    bin_um=bin_um,
    bin_duration_s=bin_duration_s,
    direction='y',
    weight_with_amplitude=False,
)
print(motion_histogram.shape, temporal_bins.size, spatial_bins.size)
(392, 784) 393 785
In [32]:
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1])
im = ax.imshow(
    motion_histogram.T,
    interpolation='nearest',
    origin='lower',
    aspect='auto',
    extent=extent,
)
im.set_clim(0, 30)
ax.set_ylim(1300, 2500)
ax.set_xlabel('time[s]')
ax.set_ylabel('depth[um]')
Out[32]:
Text(0, 0.5, 'depth[um]')

pairwise displacement from the motion histogram

In [39]:
pairwise_displacement, pairwise_displacement_weight = compute_pairwise_displacement(motion_histogram, bin_um, method='conv2d', )
np.save(peak_folder / 'pairwise_displacement_conv2d.npy', pairwise_displacement)
In [40]:
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], temporal_bins[0], temporal_bins[-1])
# extent = None
im = ax.imshow(
    pairwise_displacement,
    interpolation='nearest',
    cmap='PiYG',
    origin='lower',
    aspect='auto',
    extent=extent,
)
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im)
Out[40]:
<matplotlib.colorbar.Colorbar at 0x7f7238013400>

estimate motion (rigid) from the pairwise displacement

In [43]:
motion = compute_global_displacement(pairwise_displacement)

motion = compute_global_displacement(pairwise_displacement,convergence_method='gradient_descent',)
# motion = compute_global_displacement(pairwise_displacement, pairwise_displacement_weight=pairwise_displacement_weight, convergence_method='lsqr_robust',)
In [47]:
fig, ax = plt.subplots()
ax.plot(temporal_bins[:-1], motion)
Out[47]:
[<matplotlib.lines.Line2D at 0x7f7238624a60>]

motion estimation with a single function

Internally estimate_motion() does:

  • make_motion_histogram()
  • compute_pairwise_displacement()
  • compute_global_displacement()
In [58]:
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.widgets import plot_pairwise_displacement, plot_displacement
In [59]:
method='decentralized_registration'
method_kwargs = dict(

     pairwise_displacement_method='conv2d',
    # convergence_method='gradient_descent',
    convergence_method='lsqr_robust',
    
)

# method='decentralized_registration'
# method_kwargs = dict(
#     pairwise_displacement_method='phase_cross_correlation',
#     convergence_method='lsqr_robust',
# )


motion, temporal_bins, spatial_bins, extra_check = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=5.,
    bin_um=10.,
    method=method,
    method_kwargs=method_kwargs,
    non_rigid_kwargs=None,
    output_extra_check=True,
    progress_bar=True,
    verbose=False,
)
100%|██████████| 392/392 [00:09<00:00, 40.71it/s]
In [60]:
plot_pairwise_displacement(motion, temporal_bins, spatial_bins, extra_check, ncols=4)
Out[60]:
<spikeinterface.widgets.drift.PairwiseDisplacementWidget at 0x7f72427cf4c0>
In [61]:
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=True)
Out[61]:
<spikeinterface.widgets.drift.DisplacementWidget at 0x7f72384dd8b0>
In [62]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=False, ax=ax)
Out[62]:
<spikeinterface.widgets.drift.DisplacementWidget at 0x7f72381deca0>
In [ ]:
 

motion estimation non rigid

In [64]:
# method='decentralized_registration'
# method_kwargs = dict()
#     pairwise_displacement_method='conv2d',
#     convergence_method='gradient_descent',
# )

method='decentralized_registration'
method_kwargs = dict(
    pairwise_displacement_method='conv2d',
    convergence_method='lsqr_robust',
)


# method='decentralized_registration'
# method_kwargs = dict(
#     pairwise_displacement_method='phase_cross_correlation',
#     convergence_method='lsqr_robust',
# )


motion, temporal_bins, spatial_bins, extra_check = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=5.,
    bin_um=5.,
    method=method,
    method_kwargs=method_kwargs,
    non_rigid_kwargs=dict(bin_step_um=200, signam=3),
    output_extra_check=True,
    progress_bar=True,
    verbose=False,
)
100%|██████████| 392/392 [00:18<00:00, 21.18it/s]
100%|██████████| 392/392 [00:20<00:00, 19.59it/s]
100%|██████████| 392/392 [00:18<00:00, 21.08it/s]
100%|██████████| 392/392 [00:18<00:00, 21.04it/s]
100%|██████████| 392/392 [00:20<00:00, 19.49it/s]
100%|██████████| 392/392 [00:18<00:00, 21.12it/s]
100%|██████████| 392/392 [00:20<00:00, 19.47it/s]
100%|██████████| 392/392 [00:18<00:00, 21.08it/s]
100%|██████████| 392/392 [00:20<00:00, 19.57it/s]
100%|██████████| 392/392 [00:18<00:00, 21.12it/s]
100%|██████████| 392/392 [00:20<00:00, 19.59it/s]
100%|██████████| 392/392 [00:18<00:00, 21.12it/s]
100%|██████████| 392/392 [00:20<00:00, 19.59it/s]
100%|██████████| 392/392 [00:18<00:00, 21.07it/s]
100%|██████████| 392/392 [00:20<00:00, 19.54it/s]
100%|██████████| 392/392 [00:18<00:00, 21.04it/s]
100%|██████████| 392/392 [00:19<00:00, 19.60it/s]
100%|██████████| 392/392 [00:18<00:00, 21.09it/s]
100%|██████████| 392/392 [00:19<00:00, 19.61it/s]
100%|██████████| 392/392 [00:18<00:00, 21.25it/s]
In [65]:
fig, ax = plt.subplots()
for win in extra_check['non_rigid_windows']:
    ax.plot(win, extra_check['spatial_hist_bins'][:-1])
In [66]:
plot_pairwise_displacement(motion, temporal_bins, spatial_bins, extra_check, ncols=4)
Out[66]:
<spikeinterface.widgets.drift.PairwiseDisplacementWidget at 0x7f722db68dc0>
In [67]:
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=True)
Out[67]:
<spikeinterface.widgets.drift.DisplacementWidget at 0x7f722b6d6d60>
In [69]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=False, ax=ax)
ax.set_ylim(0, 2000)
Out[69]:
(0.0, 2000.0)
In [70]:
fig, ax = plt.subplots()
ax.plot(temporal_bins, motion)
Out[70]:
[<matplotlib.lines.Line2D at 0x7f722b466790>, ... (20 Line2D objects, one per non-rigid spatial window)]
In [71]:
fig, ax = plt.subplots()
im = ax.imshow(motion.T,
    interpolation='nearest',
    cmap='PiYG',
    origin='lower',
    aspect='auto',
    # extent=extent,
)
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im)
Out[71]:
<matplotlib.colorbar.Colorbar at 0x7f722b336a30>

spikeinterface peak detection

peak detection in spikeinterface

Author : Samuel Garcia

spikeinterface implements several methods for peak detection.

peak detection can be used:

  1. as a first step of a spike sorting chain
  2. as a first step for estimating motion (aka drift)

Here we will illustrate how this works and also how, in conjunction with the preprocessing module, we can run this detection with or without caching the preprocessed traces to disk.

This example runs on Neuropixels 1 and Neuropixels 2 recordings made by Nick Steinmetz (see here).

In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
# %matplotlib widget
%matplotlib inline
In [3]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import spikeinterface.full as si

open dataset

In [4]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP2'
preprocess_folder_bin = base_folder / 'dataset1_NP2_preprocessed_binary'
preprocess_folder_zarr = base_folder / 'dataset1_NP2_preprocessed_zarr'
In [5]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[5]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s
In [6]:
fig, ax = plt.subplots(figsize=(7, 20))
si.plot_probe_map(rec, with_channel_ids=True, ax=ax)
ax.set_ylim(-150, 200)
Out[6]:
(-150.0, 200.0)

preprocess

Here we will apply filtering + CMR.

To demonstrate the flexibility, we will work on 3 objects:

  • the lazy object rec_preprocessed
  • the cached object in binary format rec_preprocessed_cached_binary
  • the cached object in zarr format rec_preprocessed_cached_zarr

Caching to binary takes about 3 min 50 s; caching to zarr takes about 3 min 36 s (see the timings below).

In [7]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=20,
    chunk_duration='1s',
    progress_bar=True,
)
In [8]:
# create the lazy object
rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
print(rec_preprocessed)
CommonReferenceRecording: 384 channels - 1 segments - 30.0kHz - 1956.954s
In [9]:
# if not exists yet cache to binary
if preprocess_folder_bin.exists():
    rec_preprocessed_cached_binary = si.load_extractor(preprocess_folder_bin)
else:
    # cache to binary
    rec_preprocessed_cached_binary = rec_preprocessed.save(folder=preprocess_folder_bin, format='binary', **job_kwargs)
write_binary_recording with n_jobs 20  chunk_size 30000
write_binary_recording: 100%|██████████| 1957/1957 [03:50<00:00,  8.49it/s]
In [10]:
print(rec_preprocessed_cached_binary)
BinaryRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s
  file_paths: ['/mnt/data/sam/DataSpikeSorting/imposed_motion_nick/dataset1_NP2_preprocessed_binary/traces_cached_seg0.raw']
In [11]:
if preprocess_folder_zarr.exists():
    rec_preprocessed_cached_zarr = si.load_extractor(preprocess_folder_zarr)
else:
    # cache to zarr
    rec_preprocessed_cached_zarr = rec_preprocessed.save(zarr_path=preprocess_folder_zarr,  format='zarr', **job_kwargs)
Using default zarr compressor: Blosc(cname='zstd', clevel=5, shuffle=BITSHUFFLE, blocksize=0). To use a different compressor, use the 'compressor' argument
write_zarr_recording with n_jobs 20  chunk_size 30000
write_zarr_recording: 100%|██████████| 1957/1957 [03:36<00:00,  9.04it/s]
Skipping field contact_plane_axes: only 1D and 2D arrays can be serialized
In [16]:
print(rec_preprocessed_cached_zarr)
ZarrRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s

show some traces

In [9]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[9]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f003e94c880>

estimate noise

We need some estimation of the noise.

Very important: we must estimate the noise with return_scaled=False because detect_peaks() works on the raw data (very often int16).

In [39]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=10)
ax.set_title('noise across channel')
Out[39]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

We have 2 methods in spikeinterface, both implemented with numba:

  • 'by_channel': peaks are detected on each channel independently
  • 'locally_exclusive': if a unit fires on several channels, only the best peak on the best channel is kept; this is controlled by local_radius_um (a toy per-channel sketch follows this list)
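As a toy sketch of the 'by_channel' principle (pure numpy, much slower than the numba implementation and not the actual spikeinterface code): a sample is kept as a negative peak if it crosses -detect_threshold * noise and is the local minimum within +/- n_shifts samples. The 'locally_exclusive' variant additionally keeps, among channels closer than local_radius_um, only the channel with the best peak.

import numpy as np

def naive_detect_by_channel(traces, noise_levels, detect_threshold=5, n_shifts=5):
    # traces: (n_samples, n_channels) array
    thresholds = -detect_threshold * noise_levels
    peaks = []
    for chan in range(traces.shape[1]):
        x = traces[:, chan]
        for s in range(n_shifts, x.size - n_shifts):
            if x[s] < thresholds[chan] and x[s] == x[s - n_shifts:s + n_shifts + 1].min():
                peaks.append((s, chan, x[s]))
    dtype = [('sample_ind', 'int64'), ('channel_ind', 'int64'), ('amplitude', 'float64')]
    return np.array(peaks, dtype=dtype)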
In [34]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [40]:
peaks = detect_peaks(rec_preprocessed,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [02:09<00:00, 15.09it/s]
(2531770,)

compare compute time with cached version

When we detect peaks on the lazy object, every trace chunk is loaded, preprocessed, and then peaks are detected on it.

When we detect peaks on a cached version, the trace chunks are read directly from the saved copy.

In [41]:
peaks = detect_peaks(rec_preprocessed_cached_binary,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [01:30<00:00, 21.55it/s]
(2528737,)
In [42]:
peaks = detect_peaks(rec_preprocessed_cached_zarr,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [01:28<00:00, 22.23it/s]
(2528737,)

Conclusion

Running peak detection on the lazy version vs the cached version is an important choice.

detect_peaks() is a bit faster on the cached version (1:30) than on the lazy version (2:00).

But the total time of save() + detect_peaks() is slower (3:30 + 1:30 = 5:00)!

Here, writing to disk is clearly a waste of time.
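As a sketch (these cells were not part of the original run), the two strategies could also be timed explicitly with time.perf_counter, using exactly the same arguments as above:

import time

t0 = time.perf_counter()
detect_peaks(rec_preprocessed, method='locally_exclusive', local_radius_um=100,
             peak_sign='neg', detect_threshold=5, n_shifts=5,
             noise_levels=noise_levels, **job_kwargs)
t_lazy = time.perf_counter() - t0

t0 = time.perf_counter()
detect_peaks(rec_preprocessed_cached_binary, method='locally_exclusive', local_radius_um=100,
             peak_sign='neg', detect_threshold=5, n_shifts=5,
             noise_levels=noise_levels, **job_kwargs)
t_cached = time.perf_counter() - t0

print(f'lazy: {t_lazy:.1f}s   cached: {t_cached:.1f}s')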

So the benefit of caching totally depends on:

  1. the complexity of the preprocessing chain
  2. the writing capability of the disk
  3. how many times the preprocessed recording will be consumed!

spikeinterface template matching

spikeinterface template matching

Template matching is the final step used in many tools (kilosort, spyking-circus, yass, tridesclous, hdsort...)

In this step, given a catalogue (aka dictionary) of templates (aka atoms), algorithms explain the traces as a linear sum of templates plus residual noise.
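As a hedged sketch of this "linear sum of templates plus residual" idea (an illustration only, using the spike vector dtype shown later in this notebook and a hypothetical templates array of shape (n_units, n_samples_template, n_channels)):

import numpy as np

def reconstruct_traces(spikes, templates, n_samples, n_before):
    # spikes: structured array with fields 'sample_ind', 'cluster_ind', 'amplitude'
    n_units, width, n_channels = templates.shape
    recon = np.zeros((n_samples, n_channels), dtype='float32')
    for spike in spikes:
        start = spike['sample_ind'] - n_before
        if start < 0 or start + width > n_samples:
            continue  # skip border spikes in this toy version
        recon[start:start + width, :] += spike['amplitude'] * templates[spike['cluster_ind']]
    return recon

# residual = traces - reconstruct_traces(spikes, templates, traces.shape[0], n_before)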

We have started to implement several template matching procedures in spikeinterface.

Here is a small demo and also some benchmarks to compare the performance of these algorithms.

For this we will use a dataset simulated with MEArec on a 32-channel Neuronexus-like probe. Then we will compute the true templates using the true sorting. These true templates will be used by the different methods. Finally, we will apply the ground-truth comparison procedure to evaluate this step in isolation.

In [1]:
# %matplotlib widget
%matplotlib inline
In [2]:
%load_ext autoreload
%autoreload 2
In [3]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [4]:
base_folder = Path('/mnt/data/sam/DataSpikeSorting/mearec_template_matching')
mearec_file = base_folder / 'recordings_collision_15cells_Neuronexus-32_1800s.h5'
wf_folder = base_folder / 'Waveforms_recording_15cells_Neuronexus-32_1800s'
rec_folder = base_folder /'Preprocessed_recording_15cells_Neuronexus-32_1800s'

open and preprocess

In [5]:
# load already cache or compute
if rec_folder.exists():
    recording = si.load_extractor(rec_folder)
else:
    recording, gt_sorting = si.read_mearec(mearec_file)
    recording = si.bandpass_filter(recording, dtype='float32')
    recording = si.common_reference(recording)
    recording = recording.save(folder=rec_folder, n_jobs=20, chunk_size=30000, progress_bar=True)

construct true templates

In [6]:
_, gt_sorting = si.read_mearec(mearec_file)
recording = si.load_extractor(rec_folder)
In [7]:
we = si.extract_waveforms(recording, gt_sorting, wf_folder, load_if_exists=True,
                           ms_before=2.5, ms_after=3.5, max_spikes_per_unit=500,
                           n_jobs=20, chunk_size=30000, progress_bar=True)
print(we)
WaveformExtractor: 32 channels - 15 units - 1 segments
  before:75 after:105 n_per_units:500
In [8]:
metrics = si.compute_quality_metrics(we, metric_names=['snr'], load_if_exists=True)
metrics
Out[8]:
snr
#0 42.573563
#1 23.475538
#2 11.677200
#3 8.544864
#4 61.134110
#5 49.281887
#6 31.793837
#7 36.275745
#8 12.932632
#9 39.769770
#10 8.230338
#11 14.968547
#12 12.002127
#13 12.905783
#14 20.285872

run several method of template matching

A single function is used for all of them: find_spikes_from_templates().

In [9]:
from spikeinterface.sortingcomponents.template_matching import find_spikes_from_templates
In [10]:
# Some methods need the noise level (for internal detection)
noise_levels = si.get_noise_levels(recording, return_scaled=False)
noise_levels
Out[10]:
array([3.9969404, 3.9896376, 3.8046541, 3.5555122, 3.3091464, 3.257736 ,
       3.6201818, 3.9503036, 4.079712 , 4.2103205, 3.8557687, 3.9278026,
       3.8464408, 3.651188 , 3.4105062, 3.2170172, 3.3981993, 3.7377162,
       3.9932737, 4.1710896, 4.2710056, 4.296086 , 3.7716963, 3.7748668,
       3.6391177, 3.4687228, 3.3020885, 3.3594728, 3.6073673, 3.8444421,
       4.0852304, 4.234068 ], dtype=float32)
In [11]:
## this method support parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_size=30000,
    progress_bar=True
)
In [16]:
# let's build a dict for handling the parameters of each method
methods = {}
methods['naive'] =  ('naive', 
                    {'waveform_extractor' : we})
methods['tridesclous'] =  ('tridesclous',
                           {'waveform_extractor' : we,
                            'noise_levels' : noise_levels,
                            'num_closest' :3})
methods['circus'] =  ('circus',
                      {'waveform_extractor' : we,
                       'noise_levels' : noise_levels})
methods['circus-omp'] =  ('circus-omp',
                          {'waveform_extractor' : we,
                           'noise_levels' : noise_levels})


spikes_by_methods = {}
for name,  (method, method_kwargs) in methods.items():
    spikes = find_spikes_from_templates(recording, method=method, method_kwargs=method_kwargs, **job_kwargs)
    spikes_by_methods[name] = spikes
find spikes (naive): 100%|██████████| 1800/1800 [00:06<00:00, 294.34it/s]
find spikes (tridesclous): 100%|██████████| 1800/1800 [00:06<00:00, 277.50it/s]
[1] compute overlaps: 100%|██████████| 180/180 [00:00<00:00, 978.20it/s]
[2] compute amplitudes: 100%|██████████| 15/15 [00:01<00:00,  9.17it/s]
find spikes (circus): 100%|██████████| 1800/1800 [00:04<00:00, 386.06it/s]
find spikes (circus-omp): 100%|██████████| 1800/1800 [00:28<00:00, 63.54it/s]
In [17]:
## the output of every method is a numpy array with a structured dtype

spikes = spikes_by_methods['tridesclous']
print(spikes.dtype)
print(spikes.shape)
print(spikes[:5])
[('sample_ind', '<i8'), ('channel_ind', '<i8'), ('cluster_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]
(234977,)
[( 59,  0,  4, 1., 0) (309, 21,  8, 1., 0) (371, 13,  3, 1., 0)
 (623, 30, 14, 1., 0) (713, 31, 13, 1., 0)]

check performances method by method

For this:

  1. we transform the spikes vector into a sorting object
  2. use the compare_sorter_to_ground_truth() function to compute performances
  3. plot agreement matrix
  4. plot accuracy vs snr
  5. plot collision vs similarity

Note:

  • as we provide the true template list, every agreement matrix is supposed to be square! The performance can be read on the diagonal: a perfect matching would have only ones on the diagonal.
  • the dataset here is one of the datasets used in the collision paper, so we can also make a finer benchmark by inspecting collisions.
In [18]:
# load metrics for snr on true template
metrics = we.load_extension('quality_metrics').get_metrics()
In [20]:
templates = we.get_all_templates()

comparisons = {}
for name,  (method, method_kwargs) in methods.items():
    spikes = spikes_by_methods[name]

    sorting = si.NumpySorting.from_times_labels(spikes['sample_ind'], spikes['cluster_ind'], recording.get_sampling_frequency())
    print(sorting)

    comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting)
    

    fig, axs = plt.subplots(ncols=2)
    si.plot_agreement_matrix(comp, ax=axs[0])
    si.plot_sorting_performance(comp, metrics, performance_name='accuracy', metric_name='snr', ax=axs[1], color='g')
    si.plot_sorting_performance(comp, metrics, performance_name='recall', metric_name='snr', ax=axs[1], color='b')
    si.plot_sorting_performance(comp, metrics, performance_name='precision', metric_name='snr', ax=axs[1], color='r')
    axs[0].set_title(name)
    axs[1].set_ylim(0.8, 1.1)
    axs[1].legend(['accuracy', 'recall', 'precision'])
    
    comp = si.CollisionGTComparison(gt_sorting, sorting)
    comparisons[name] = comp
    fig, ax = plt.subplots()
    si.plot_comparison_collision_by_similarity(comp, templates, figure=fig)
    fig.suptitle(name)

plt.show()
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz

comparison of methods 2 by 2

In [ ]:
names = list(comparisons.keys())
n = len(names)

for r, name0 in enumerate(names):
    for c, name1 in enumerate(names):
        if r<=c:
            continue

        fig, ax = plt.subplots()
        val0 = comparisons[name0].get_performance()['accuracy']
        val1 = comparisons[name1].get_performance()['accuracy']
        ax.scatter(val0, val1)
        ax.set_xlabel(name0)
        ax.set_ylabel(name1)
        ax.plot([0,1], [0, 1], color='k')
        ax.set_title('accuracy')
        ax.set_xlim(0.6, 1)
        ax.set_ylim(0.6, 1)

conclusion

  • tridesclous and circus-omp are the clear winners in terms of performance
  • tridesclous is the fastest
  • improvements are still needed, because the performance is far from perfect!

spikeinterface destripe

destripe processing in spikeinterface

Author : Samuel Garcia

Olivier Winter has developed for IBL a standard pre-processing chain in ibllib to clean the traces before spike sorting. See this.

This procedure is called "destripe". It removes artifacts that are present on all channels (common noise).

The main idea is to have this:

  1. filter
  2. align samples (phase shift)
  3. remove common noise
  4. apply spatial filter and bad channel interpolation

Except for step 4, all other steps are available in spikeinterface.

spikeinterface.toolkit.preprocessing proposes classes and functions to build what we call a lazy chain of processing.

Here is an example with 4 files kindly provided by Olivier Winter to illustrate the spikeinterface implementation of this destripe procedure.

In [42]:
# %matplotlib widget
%matplotlib inline
In [19]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [20]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [21]:
base_folder = Path('/media/samuel/dataspikesorting/DataSpikeSorting/olivier_destripe/')

folder1 = base_folder / '4c04120d-523a-4795-ba8f-49dbb8d9f63a'
folder2 = base_folder / '68f06c5f-8566-4a4f-a4b1-ab8398724913'
folder3 = base_folder / '8413c5c6-b42b-4ec6-b751-881a54413628'
folder4 = base_folder / 'f74a6b9a-b8a5-4c80-9c30-7dd4cdbb48c0'
data_folders = [folder1, folder2, folder3, folder4]

Build the preprocessing chain

In spikeinterface we have:

  • bandpass_filter()
  • common_reference(): removes common noise (global or local) by subtracting the median (or average)
  • phase_shift(): compensates the ADC sampling shift across channels by applying the inverse shift in the Fourier domain (a toy Fourier-domain sketch follows below)

These can be combined to get more or less the same result as the "destripe".
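As a toy Fourier-domain sketch of what phase_shift() does conceptually (a fractional-sample delay applied per channel; this is an illustration, not the spikeinterface code):

import numpy as np

def fractional_shift(x, shift_samples):
    # delay a 1D trace by a (possibly fractional) number of samples
    n = x.size
    X = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(n)  # in cycles per sample
    X *= np.exp(-2j * np.pi * freqs * shift_samples)
    return np.fft.irfft(X, n)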

Here we will compare 2 preprocessing:

  1. filter > cmr
  2. filter > phase_shift > cmr

Step 4 (the k-filter) is not implemented yet, but this should be done soon.

In [38]:
# let's have a function that builds the chain and plots intermediate results

def preprocess_steps(rec, time_range=None, clim=(-80, 80), figsize=(15, 10)):
    
    # chain 1. : filter + cmr
    rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000., dtype='float32')
    rec_cmr = si.common_reference(rec_filtered, reference='global', operator='median')
    
    # chain 2.. : filter + phase_shift + cmr
    rec_pshift = si.phase_shift(rec_filtered)
    rec_cmr2 = si.common_reference(rec_pshift, reference='global', operator='median')
    
    
    # rec
    fig, axs = plt.subplots(ncols=5, sharex=True, sharey=True, figsize=figsize)
    
    ax = axs[0]
    ax.set_title('raw')
    si.plot_timeseries(rec, ax=ax,  with_colorbar=False) # clim=clim,
    
    # filter
    
    ax = axs[1]
    ax.set_title('filtered')
    si.plot_timeseries(rec_filtered, ax=axs[1], clim=clim, with_colorbar=False)
    
    # filter + cmr
    
    # rec_preprocessed
    ax = axs[2]
    ax.set_title('filtered + cmr')
    si.plot_timeseries(rec_cmr, ax=axs[2], clim=clim, with_colorbar=False)
    
    # filter + phase_shift
    
    ax = axs[3]
    ax.set_title('filtered + phase_shift')
    si.plot_timeseries(rec_pshift, ax=ax, clim=clim, with_colorbar=False)
    
    # filtered + phase_shift + cmr
    
    ax = axs[4]
    ax.set_title('filtered + phase_shift + cmr')
    si.plot_timeseries(rec_cmr2, ax=ax, clim=clim, with_colorbar=True)

    # optionally a time range can be given
    if time_range is not None:
        ax.set_xlim(*time_range)

dataset 1

In [30]:
rec = si.read_cbin_ibl(folder1)
preprocess_steps(rec)
In [31]:
# zoom on a stripe
preprocess_steps(rec, time_range=(0.95, 0.97))

dataset 2

In [32]:
rec = si.read_cbin_ibl(folder2)
preprocess_steps(rec)
In [33]:
rec = si.read_cbin_ibl(folder2)
preprocess_steps(rec, time_range=(0.2, .3))

dataset3

In [34]:
rec = si.read_cbin_ibl(folder3)
preprocess_steps(rec)
In [35]:
rec = si.read_cbin_ibl(folder3)
preprocess_steps(rec, time_range=(0.797, .801))

dataset 4

In [39]:
rec = si.read_cbin_ibl(folder4)
preprocess_steps(rec, clim=(-50, 50))
In [41]:
preprocess_steps(rec, clim=(-50, 50), time_range=(0.368, .375))

conclusion

Here we demonstrated the modular way of building a preprocessing chain directly in spikeinterface. This is particularly useful because:

  1. the same preprocessing can be applied for different sorters
  2. the preprocessing can be cached in parallel using rec.save(...) in binary or zarr format (a short sketch follows)
  3. every step can be parameterized depending on the input dataset and the compute resources available
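As a short sketch of points 1 and 2 (the chain and folder1 are the ones used in this notebook, while the cache folder name is hypothetical):

import spikeinterface.full as si

rec = si.read_cbin_ibl(folder1)
rec = si.bandpass_filter(rec, freq_min=300., freq_max=6000., dtype='float32')
rec = si.phase_shift(rec)
rec = si.common_reference(rec, reference='global', operator='median')

# cache the whole lazy chain in parallel, in binary (or zarr) format
rec_cached = rec.save(folder=base_folder / 'destriped_cache', format='binary',
                      n_jobs=20, chunk_duration='1s', progress_bar=True)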

Collision paper spike sorting performance

Spike sorting performance against spike collisions (figure 2-3-5)

In this notebook, we describe how to generate the figures for all the studies, i.e. for all rate and correlation levels, in a systematic manner. However, while by default the figures were saved as .pdf, here we modify the ranges of rates and correlations to display only a single figure. Feel free to modify these ranges in the scripts in order to reproduce the other figures.

In [1]:
import numpy as np
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import gridspec

import MEArec as mr
import spikeinterface.full as si
In [2]:
study_base_folder = Path('../data/study/')

Plot global spike sorting performance (Figure 2)

In [1]:
res = {}

rate_levels = [5]
corr_levels = [0]

for rate_level in rate_levels:
    for corr_level in corr_levels:

        fig = plt.figure(figsize=(15,5))
        gs = gridspec.GridSpec(2, 3, figure=fig)

        study_folder = study_base_folder / f'20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        
        study = si.GroundTruthStudy(study_folder)
        study.run_comparisons(exhaustive_gt=True)

        ax_1 = plt.subplot(gs[0, 0])
        ax_2 = plt.subplot(gs[0, 1:])
        ax_3 = plt.subplot(gs[1, 1:])
        ax_4 = plt.subplot(gs[1, 0])

        for ax in [ax_1, ax_2, ax_3, ax_4]:
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

        ax_2.tick_params(labelbottom=False)
        ax_2.set_xlabel('')

        si.plot_gt_study_run_times(study, ax=ax_1)
        si.plot_gt_study_unit_counts(study, ax=ax_2)
        si.plot_gt_study_performances_averages(study, ax=ax_3)
        si.plot_gt_study_performances_by_template_similarity(study, ax=ax_4)

        plt.tight_layout()

Plot collision recall as function of the lags (Figure 3)

In [2]:
for rate_level in rate_levels:
    for corr_level in corr_levels:
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)

        for rec_name in res[(rate_level, corr_level)].rec_names:
            res[(rate_level, corr_level)].compute_waveforms(rec_name)

        si.plot_study_comparison_collision_by_similarity(res[(rate_level, corr_level)], 
                                                         show_legend=False, ylim=(0.4, 1))
        plt.tight_layout()
Python 3.6.9 (default, Jan 26 2021, 15:33:00) 
[GCC 8.4.0] does not support parallel processing
(the same message is repeated for each comparison)

Plot collision recall as function of the lag and/or cosine similarity (supplementary figures)

In [3]:
gs = gridspec.GridSpec(len(rate_levels), len(corr_levels))
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        ax = plt.subplot(gs[i, j])

        if i == 0 and j == 0:
            show_legend = True
        else:
            show_legend = False
        si.plot_study_comparison_collision_by_similarity_range(res[(rate_level, corr_level)], show_legend=show_legend, similarity_range=[0.5, 1], ax=ax, ylim=(0.3, 1))

        ax.set_title(f'Rate {rate_level} Hz, Corr {corr_level}' )
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if rate_level != rate_levels[-1]:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel('')
        else:
            ax.set_xlabel('lags (ms)')

        if corr_level != corr_levels[0]:
            ax.tick_params(labelleft=False)
            ax.set_ylabel('')
        else:
            ax.set_ylabel('collision accuracy')
In [4]:
gs = gridspec.GridSpec(len(rate_levels), len(corr_levels))
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        ax = plt.subplot(gs[i, j])

        if i == 0 and j == 0:
            show_legend = True
        else:
            show_legend = False
        si.plot_study_comparison_collision_by_similarity_ranges(res[(rate_level, corr_level)], show_legend=show_legend, ax=ax)

        ax.set_title(f'Rate {rate_level} Hz, Corr {corr_level}' )
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if rate_level != rate_levels[-1]:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel('')
        else:
            ax.set_xlabel('similarity')

        if corr_level != corr_levels[0]:
            ax.tick_params(labelleft=False)
            ax.set_ylabel('')
        else:
            ax.set_ylabel('collision accuracy')

Plot average collision recall over multiple conditions, as function of the lags (Figure 5)

In [9]:
rate_levels = [5,10,15]
corr_levels = [0, 0.1, 0.2]

gs = gridspec.GridSpec(1, 2)
ax = plt.subplot(gs[0, 0])
curves = {}
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)
        res[(rate_level, corr_level)].precompute_scores_by_similarities()
        
        for sorter_name in res[(rate_level, corr_level)].sorter_names:
            data = res[(rate_level, corr_level)].get_mean_over_similarity_range([0.5, 1], sorter_name)
            if not sorter_name in curves:
                curves[sorter_name] = [data]
            else:
                curves[sorter_name] += [data]

lags = res[(rate_level, corr_level)].get_lags()
for sorter_name in res[(rate_level, corr_level)].sorter_names:
    curves[sorter_name] = np.array(curves[sorter_name])
    mean_sorter = curves[sorter_name].mean(0)
    std_sorter = curves[sorter_name].std(0)
    ax.plot(lags[:-1] + (lags[1]-lags[0]) / 2, mean_sorter, label=sorter_name)
    ax.fill_between(lags[:-1] + (lags[1]-lags[0]) / 2, mean_sorter-std_sorter,mean_sorter+std_sorter, alpha=0.2)

ax.legend()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('lag (ms)')
ax.set_ylabel('collision accuracy')
Out[9]:
Text(0, 0.5, 'collision accuracy')

Plotting the average collision recall over multiple conditions, as function of the similarity

In [5]:
rate_levels = [5,10,15]
corr_levels = [0, 0.1, 0.2]
res = {}
gs = gridspec.GridSpec(1, 2)
ax = plt.subplot(gs[0, 0])
curves = {}
similarity_ranges = np.linspace(-0.4, 1, 8)
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)
        res[(rate_level, corr_level)].precompute_scores_by_similarities()
        
        for sorter_name in res[(rate_level, corr_level)].sorter_names:

            all_similarities = res[(rate_level, corr_level)].all_similarities[sorter_name]
            all_recall_scores = res[(rate_level, corr_level)].all_recall_scores[sorter_name]

            order = np.argsort(all_similarities)
            all_similarities = all_similarities[order]
            all_recall_scores = all_recall_scores[order, :]

            mean_recall_scores = []
            std_recall_scores = []
            for k in range(similarity_ranges.size - 1):
                cmin, cmax = similarity_ranges[k], similarity_ranges[k + 1]
                amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
                value = np.mean(all_recall_scores[amin:amax])
                mean_recall_scores += [np.nan_to_num(value)]

            xaxis = np.diff(similarity_ranges)/2 + similarity_ranges[:-1]

            data = mean_recall_scores
            if not sorter_name in curves:
                curves[sorter_name] = [data]
            else:
                curves[sorter_name] += [data]

for sorter_name in res[(rate_level, corr_level)].sorter_names:
    curves[sorter_name] = np.array(curves[sorter_name])
    mean_sorter = curves[sorter_name].mean(0)
    std_sorter = curves[sorter_name].std(0)
    ax.plot(xaxis, mean_sorter, label=sorter_name)
    ax.fill_between(xaxis, mean_sorter-std_sorter,mean_sorter+std_sorter, alpha=0.2)


ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('cosine similarity')
#ax.set_ylabel('collision accuracy')
#ax.set_yticks([])

plt.tight_layout()
/home/cure/.local/lib/python3.6/site-packages/numpy/core/fromnumeric.py:3373: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/cure/.local/lib/python3.6/site-packages/numpy/core/_methods.py:170: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)

Collision paper simulated recordings

Simulated recordings overview (figure 1)

This notebook reproduces Figure 1 of the manuscript: "How do spike collisions affect spike sorting performance?"

To run this notebook, you first need to run the generate_recordings.ipynb notebook.

In [1]:
import shutil
import sys
from pathlib import Path

import numpy as np
import scipy.spatial

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import MEArec as mr
import spikeinterface.full as si


my_cmap = plt.get_cmap('winter')
cNorm  = colors.Normalize(vmin=0, vmax=1)
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)
In [ ]:
sys.path.append("../utils")

from generation_utils import generation_params
from study_utils import generate_study
In [ ]:
recordings_folder = Path('../data/recordings/')
In [2]:
# define some parameters

nb_traces = 10 # for panel I
window_ms = 20 #for CC plots
bin_ms = 0.2 # for CC plots
n_cell = 20 #
lag_time = generation_params['lag_time']*1000
corr_level = 0 # to select the appropriate recording if several (run generation first)
rate_level = 5 # to select the appropriate recording if several (run generation first)
In [8]:
# The paper repository uses the plotting.py helper script to assemble multi-panel figures; here each panel is created separately.
figA, axA = plt.subplots()

# We load the file
rec_file = recordings_folder / f'rec0_20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32.h5'

mearec_object = mr.load_recordings(rec_file)
rec = si.MEArecRecordingExtractor(rec_file)
sorting_gt = si.MEArecSortingExtractor(rec_file)

waveforms_path = Path('.') / 'tmp'
waveforms_path.mkdir(exist_ok=True)

waveforms = si.extract_waveforms(rec, sorting_gt, waveforms_path, ms_before=3, ms_after=3)

original_templates = waveforms.get_all_templates()
snrs = np.array([i for i in si.compute_snrs(waveforms).values()])
rates = np.array([i for i in si.compute_firing_rate(waveforms).values()])


## Plotting the probe layout and the cell positions
si.plot_unit_localization(waveforms, ax=axA)
axA.set_ylabel('y (um)')
axA.set_xlabel('x (um)')
In [9]:
figB, axB = plt.subplots(ncols=3, figsize=(12, 7))

colors = {'#0' : 'k', '#16' : 'r'}

similarities = si.compute_template_similarity(waveforms)

## Plotting example of pair with selected similarity
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[0]], unit_ids=['#0'], unit_colors=colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[0]], unit_ids=['#16'], unit_colors=colors)
axB[0].set_title('(#0, #16) similarity %02g' %similarities[0, 16])   

colors = {'#0' : 'k', '#10' : 'r'}
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[1]], unit_ids=['#0'], unit_colors=colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[1]], unit_ids=['#10'], unit_colors=colors)    
axB[1].set_title('(#0, #10) similarity %02g' %similarities[0, 10])

colors = {'#0' : 'k', '#1' : 'r'}
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[2]], unit_ids=['#0'], unit_colors=colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[2]], unit_ids=['#1'], unit_colors=colors)    
axB[2].set_title('(#0, #1) similarity %02g' %similarities[0, 1])
figB.tight_layout()
In [11]:
figC, axC = plt.subplots()

## Plotting the similarity matrix
im = axC.imshow(similarities, cmap='viridis',
                aspect='auto',
                origin='lower',
                interpolation='none',
                extent=(-0.5, n_cell-1+0.5, -0.5, n_cell-1+0.5))
axC.set_xlabel('# cell')
axC.set_ylabel('# cell')
plt.colorbar(im, ax=axC, label='cosine similarity')
Out[11]:
<matplotlib.colorbar.Colorbar at 0x7f6c27a39208>
In [12]:
figDE, axDE = plt.subplots(nrows=2)

centers = np.array([v for v in si.compute_unit_centers_of_mass(waveforms).values()])
real_centers = mearec_object.template_locations[:]

distances = scipy.spatial.distance_matrix(centers, centers)
real_distances =  scipy.spatial.distance_matrix(real_centers, real_centers)

# Plotting the distribution of similarities as function of distance (either real or estimated)
axDE[0].plot(distances.flatten(), similarities.flatten(), '.', label='Center of Mass')
axDE[0].plot(real_distances.flatten(), similarities.flatten(), '.', label='Real position')
axDE[0].legend()
axDE[0].set_xlabel('distances (um)')
axDE[0].set_ylabel('cosine similarity')

x, y = np.histogram(similarities.flatten(), 10)
axDE[1].bar(y[1:], x/float(x.sum()), width=y[1]-y[0])
axDE[1].set_xlabel('cosine similarity')
axDE[1].set_ylabel('probability')
Out[12]:
Text(0, 0.5, 'probability')
In [14]:
## Cross-correlograms between the first three units (panel F); the final figure was assembled separately
w = si.plot_crosscorrelograms(sorting_gt, ['#%s' %i for i in range(0,3)], 
                              bin_ms=bin_ms, window_ms=window_ms, symmetrize=True)
figF = w.figure
Out[14]:
<spikeinterface.widgets.correlograms.CrossCorrelogramsWidget at 0x7f6bae84ae80>
In [16]:
figGH, axGH = plt.subplots(nrows=2)

ccs, lags = si.compute_correlograms(sorting_gt, bin_ms=bin_ms, window_ms=window_ms, symmetrize=True)
ccs = ccs.reshape(n_cell**2, ccs.shape[2])
mask = np.ones(n_cell**2).astype(np.bool)
mask[np.arange(0, n_cell**2, n_cell) + np.arange(n_cell)] = False
mean_cc = np.mean(ccs[mask], 0)
std_cc = np.std(ccs[mask], 0)

## Plotting the average CC
xaxis = (lags[:-1] - lags[:-1].mean())
axGH[0].plot(xaxis, mean_cc, lw=2, c='r')
axGH[0].fill_between(xaxis, mean_cc-std_cc,mean_cc+std_cc, color='0.5', alpha=0.5)
axGH[0].set_xlabel('time (ms)')
axGH[0].set_ylabel('cross correlation')
ymin, ymax = axGH[0].get_ylim()
axGH[0].plot([-lag_time,-lag_time],[ymin,ymax],'k--')
axGH[0].plot([lag_time,lag_time],[ymin,ymax],'k--')

mask = np.zeros(n_cell**2).astype(np.bool)
mask[np.arange(0, n_cell**2, n_cell) + np.arange(n_cell)] = True
mean_cc = np.mean(ccs[mask], 0)
std_cc = np.std(ccs[mask], 0)

xaxis = (lags[:-1] - lags[:-1].mean())
axGH[1].plot(xaxis, mean_cc, lw=2, c='r')
axGH[1].fill_between(xaxis, mean_cc-std_cc,mean_cc+std_cc, color='0.5', alpha=0.5)
axGH[1].set_ylabel('auto correlation')
ymin, ymax = axGH[1].get_ylim()
axGH[1].plot([-lag_time,-lag_time],[ymin,ymax],'k--')
axGH[1].plot([lag_time,lag_time],[ymin,ymax],'k--')
Out[16]:
[<matplotlib.lines.Line2D at 0x7f6c702a1358>]
In [18]:
## Plotting timeseries
w = si.plot_timeseries(rec, time_range=(5,5.1), channel_ids=['%s' %i for i in range(1,nb_traces)], color='k')
figI = w.figure
Out[18]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f6badf61f60>

Collision paper generate recordings

Generation of the recordings

In this notebook, we will generate all the recordings with MEArec that are necessary to populate the study and compare the sorters. First, we need to create a function that, given a dictionary of parameters, generates a single recording. The recording parameters can be defined as follows.

In [4]:
import os
import sys
import shutil
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# neo / quantities are used below to build the ground-truth spike trains
import neo
import quantities as pq

import MEArec as mr
import spikeinterface.full as si
In [8]:
sys.path.append('../utils/')

from corr_spike_trains import CorrelatedSpikeGenerator
# note: the `simpleaxis` and `make_matching_events` helpers used in the sanity plots
# below are assumed to come from the same project utilities
In [3]:
generation_params = {
    'probe' : 'Neuronexus-32', #layout of the probe used
    'duration' : 30*60, #total duration of the recording
    'n_cell' : 20, # number of cells that will be injected
    'fs' : 30000., # sampling rate
    'lag_time' : 0.002,  # half refractory period, in seconds (2 ms)
    'make_plots' : True,
    'generate_recording' : True,
    'noise_level' : 5,
    'templates_seed' : 42,
    'noise_seed' : 42,
    'global_path' : os.path.abspath('../'),
    'study_number' : 0,
    'save_plots' : True,
    'method' : 'brette', # 'poisson' | 'brette'
    'corr_level' : 0,
    'rate_level' : 5, #Hz
    'nb_recordings' : 5
}

With these parameters, we will create 20 neurons, and the correlation levels will be generated via the mixture process of [Brette et al, 2009]. The function to generate a single recording is defined as follows. It assumes that you have a file named ../data/templates/templates_{probe}_100.h5 containing all the pre-generated templates that MEArec will use.
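
If that templates file is missing, it can be pre-generated with MEArec itself. The following is only a minimal sketch, assuming MEArec's bundled cell models and default template parameters; the `probe`, `n` and output filename values are chosen to match the file expected above and may need adjusting.

import MEArec as mr

# assumption: use MEArec's built-in cell models and default template parameters
cell_models_folder = mr.get_default_cell_models_folder()
templates_params = mr.get_default_templates_params()
templates_params['probe'] = 'Neuronexus-32'
templates_params['n'] = 100  # the '_100' suffix in the expected filename suggests 100 templates per cell model

tempgen = mr.gen_templates(cell_models_folder, params=templates_params)
mr.save_template_generator(tempgen, filename='../data/templates/templates_Neuronexus-32_100.h5')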

In [5]:
def generate_single_recording(params=generation_params):

    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] == None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings') 

    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    template_filename = os.path.join(paths['templates'], f'templates_{probe}_100.h5')
    recording_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')
    plot_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.pdf')

    spikerate = params['rate_level']
    n_spike_alone = int(spikerate * params['duration'])

    print('Total target rate:', params['rate_level'], "Hz")
    print('Basal rate:', spikerate, "Hz")


    # collision lag range
    lag_sample = int(params['lag_time'] * params['fs'])

    refactory_period = 2 * params['lag_time']

    spiketimes = []

    if params['method'] == 'poisson':
        print('Spike trains generated as independent poisson sources')
        
        for i in range(params['n_cell']):
            
            #~ n = n_spike_alone + n_collision_by_pair * (params['n_cell'] - i - 1)
            n = n_spike_alone
            #~ times = np.random.rand(n_spike_alone) * params['duration']
            times = np.random.rand(n) * params['duration']
            
            times = np.sort(times)
            spiketimes.append(times)

    elif params['method'] == 'brette':
        print('Spike trains generated as compound mixtures')
        C = np.ones((params['n_cell'], params['n_cell']))
        C = params['corr_level'] * np.maximum(C, C.T)
        #np.fill_diagonal(C, 0*np.ones(params['n_cell']))

        rates = params['rate_level'] * np.ones(params['n_cell'])

        cor_spk = CorrelatedSpikeGenerator(C, rates, params['n_cell'])
        cor_spk.find_mixture(iter=1e4)
        res = cor_spk.mixture_process(tauc=refactory_period/2, t=params['duration'])
        
        # collect the spike times of each cell from the mixture process
        for i in range(params['n_cell']):
            #~ print(spiketimes[i])
            mask = res[:, 0] == i
            times = res[mask, 1]
            times = np.sort(times)
            mask = (times > 0) * (times < params['duration'])
            times = times[mask]
            spiketimes.append(times)


    # remove refractory period violations
    for i in range(params['n_cell']):
        times = spiketimes[i]
        ind, = np.nonzero(np.diff(times) < refactory_period)
        ind += 1
        times = np.delete(times, ind)
        assert np.sum(np.diff(times) < refactory_period) ==0
        spiketimes[i] = times

    # make neo spiketrains
    spiketrains = []
    for i in range(params['n_cell']):
        mask = np.where(spiketimes[i] > 0)
        spiketimes[i] = spiketimes[i][mask] 
        spiketrain = neo.SpikeTrain(spiketimes[i], units='s', t_start=0*pq.s, t_stop=params['duration']*pq.s)
        spiketrain.annotate(cell_type='E')
        spiketrains.append(spiketrain)

    # check with sanity plot here
    if params['make_plots']:
        
        # count number of spike per units
        fig, axs = plt.subplots(2, 2)
        count = [st.size for st in spiketrains]
        ax = axs[0, 0]
        simpleaxis(ax)
        pairs = []
        collision_count_by_pair = []
        collision_count_by_units = np.zeros(n_cell)
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                pairs.append(f'{i}-{j}')
                collision_count_by_pair.append(matching_event.size)
                collision_count_by_units[i] += matching_event.size
                collision_count_by_units[j] += matching_event.size
        ax.plot(np.arange(len(collision_count_by_pair)), collision_count_by_pair)
        ax.set_xticks(np.arange(len(collision_count_by_pair)))
        ax.set_xticklabels(pairs)
        ax.set_ylim(0, max(collision_count_by_pair) * 1.1)
        ax.set_ylabel('# Collisions')
        ax.set_xlabel('Pairs')

        # count number of spike per units
        count_total = np.array([st.size for st in spiketrains])
        count_not_collision = count_total - collision_count_by_units

        ax = axs[1, 0]
        simpleaxis(ax)
        ax.bar(np.arange(n_cell).astype(np.int)+1, count_not_collision, color='g')
        ax.bar(np.arange(n_cell).astype(np.int)+1, collision_count_by_units, bottom =count_not_collision, color='r')
        ax.set_ylabel('# spikes')
        ax.set_xlabel('Cell id')
        ax.legend(('Not colliding', 'Colliding'), loc='best')

        # cross corrlogram
        ax = axs[0, 1]
        simpleaxis(ax)
        counts = []
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                
                #~ ax = axs[i, j]
                all_lag = matching_event['delta_frame']  / params['fs']
                count, bins  = np.histogram(all_lag, bins=np.arange(-params['lag_time'], params['lag_time'], params['lag_time']/20))
                #~ ax.bar(bins[:-1], count, bins[1] - bins[0])
                ax.plot(1000 * bins[:-1], count, c='0.5')
                counts += [count]
        counts = np.array(counts)
        counts = np.mean(counts, 0)
        ax.plot(1000 * bins[:-1], counts, c='r')
        ax.set_xlabel('Lags [ms]')
        ax.set_ylabel('# Collisions')

        ax = axs[1, 1]
        simpleaxis(ax)
        ratios = []
        for i in range(n_cell):
            nb_spikes = len(spiketrains[i])
            nb_collisions = 0
            times1 = spiketrains[i].rescale('s').magnitude
            for j in list(range(0, i)) + list(range(i+1, n_cell)):
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                nb_collisions += matching_event.size

            if nb_collisions > 0:
                ratios += [nb_spikes / nb_collisions]
            else:
                ratios += [0]

        ax.bar([0], [np.mean(ratios)], yerr=[np.std(ratios)])
        ax.set_ylabel('# spikes / # collisions')
        plt.tight_layout()

        if params['save_plots']:
            plt.savefig(plot_filename)
        else:
            plt.show()
        plt.close()

    if params['generate_recording']:
        spgen = mr.SpikeTrainGenerator(spiketrains=spiketrains)
        rec_params = mr.get_default_recordings_params()
        rec_params['recordings']['fs'] = params['fs']
        rec_params['recordings']['sync_rate'] = None
        rec_params['recordings']['sync_jitter'] = 5
        rec_params['recordings']['noise_level'] = params['noise_level']
        rec_params['recordings']['filter'] = False
        rec_params['spiketrains']['duration'] = params['duration']
        rec_params['spiketrains']['n_exc'] = params['n_cell']
        rec_params['spiketrains']['n_inh'] = 0
        rec_params['recordings']['chunk_duration'] = 10.
        rec_params['templates']['n_overlap_pairs'] = None
        rec_params['templates']['min_dist'] = 0
        rec_params['seeds']['templates'] = params['templates_seed']
        rec_params['seeds']['noise'] = params['noise_seed']
        recgen = mr.gen_recordings(params=rec_params, spgen=spgen, templates=template_filename, verbose=True)
        mr.save_recording_generator(recgen, filename=recording_filename)

Once this function is created, we can create an additional function that will generate several recordings, with different suffix/seeds:

In [6]:
def generate_recordings(params=generation_params):
    for i in range(params['nb_recordings']):
        generation_params['study_number'] = i
        generation_params['templates_seed'] = i
        generation_params['noise_seed'] = i
        generate_single_recording(generation_params)

And now we have all the required tools to create our recordings. By default, they will all be saved in the folder ../data/recordings/
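
As a quick sanity check (just a sketch, assuming the default folder layout used in this notebook), the generated files can be listed; their names follow the pattern built in generate_single_recording above.

from pathlib import Path

recordings_folder = Path('../data/recordings')
# filenames follow rec{i}_{n_cell}cells_{noise}noise_{corr}corr_{rate}rate_{probe}.h5
for f in sorted(recordings_folder.glob('rec*_Neuronexus-32.h5')):
    print(f.name)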

In [7]:
## Provide the different rate and correlations levels you want to generate
rate_levels = [5, 10, 15]
corr_levels = [0, 0.1, 0.2]
generation_params['nb_recordings'] = 5 #Number of recordings per conditions
In [ ]:
result = {}

for rate_level in rate_levels:
    for corr_level in corr_levels:

        generation_params['rate_level'] = rate_level
        generation_params['corr_level'] = corr_level
        generate_recordings(generation_params)

Generation of the study objects

Since the recordings have been generated, we now need to create Study objects for spikeinterface and run the sorters on all these recordings. Be aware that, by default, this can create quite a large amount of data if you have numerous rate/correlation levels, recordings, and/or sorters. First, we need to tell spikeinterface how to find the sorters.

In [11]:
ironclust_path = '/media/cure/Secondary/pierre/softwares/ironclust'
kilosort1_path = '/media/cure/Secondary/pierre/softwares/Kilosort-1.0'
kilosort2_path = '/media/cure/Secondary/pierre/softwares/Kilosort-2.0'
kilosort3_path = '/media/cure/Secondary/pierre/softwares/Kilosort-3.0'
hdsort_path = '/media/cure/Secondary/pierre/softwares/HDsort'
os.environ["KILOSORT_PATH"] = kilosort1_path
os.environ["KILOSORT2_PATH"] = kilosort2_path
os.environ["KILOSORT3_PATH"] = kilosort3_path
os.environ['IRONCLUST_PATH'] = ironclust_path
os.environ['HDSORT_PATH'] = hdsort_path
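
To double-check that spikeinterface picks up these external sorters, one can list what is detected in the current environment. This is a minimal sketch: `installed_sorters` is assumed to be exposed by `spikeinterface.full`, and `print_sorter_versions` is the same call used inside generate_study below.

import spikeinterface.full as si

print(si.installed_sorters())   # names of the sorters found locally
si.print_sorter_versions()      # versions, as also printed inside generate_study()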

And then we need to create a function that, given a list of recordings, creates a study and runs all the sorters.

In [13]:
def generate_study(params, keep_data=True):
    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] == None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings')
    paths['study'] = os.path.join(paths['data'], 'study')
    
    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    paths['mearec_filename'] = []

    study_folder = os.path.join(paths['study'], f'{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}')
    study_folder = Path(study_folder)

    if params['reset_study'] and os.path.exists(study_folder):
        shutil.rmtree(study_folder)

    print('Available sorters:')
    si.print_sorter_versions()

    gt_dict = {}

    if not os.path.exists(study_folder):

        for i in range(params['nb_recordings']):
            paths['mearec_filename'] += [os.path.join(paths['recordings'], f'rec{i}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')]

        print('Available recordings:')
        print(paths['mearec_filename'])

        
        for count, file in enumerate(paths['mearec_filename']):
            rec  = si.MEArecRecordingExtractor(file)
            sorting_gt = si.MEArecSortingExtractor(file)
            gt_dict['rec%d' %count] = (rec, sorting_gt)

        study = si.GroundTruthStudy.create(study_folder, gt_dict, n_jobs=-1, chunk_memory='1G', progress_bar=True)
        study.run_sorters(params['sorter_list'], verbose=False, docker_images=params['docker_images'])
        print("Study created!")
    else:
        study = si.GroundTruthStudy(study_folder)
        if params['relaunch'] == 'all':
            if_exist = 'overwrite'
        elif params['relaunch'] == 'some':
            if_exist = 'keep'

        if params['relaunch'] in ['all', 'some']:
            study.run_sorters(params['sorter_list'], verbose=False, mode_if_folder_exists=if_exist, docker_images=params['docker_images'])
            print("Study loaded!")

    study.copy_sortings()

    if not keep_data:

        for sorter in params['sorter_list']:

            for rec in ['rec%d' %i for i in range(params['nb_recordings'])]:
                sorter_path = os.path.join(study_folder, 'sorter_folders', rec, sorter)
                if os.path.exists(sorter_path):
                    for f in os.listdir(sorter_path):
                        if f != 'spikeinterface_log.json':
                            full_file = os.path.join(sorter_path, f)
                            try:
                                if os.path.isdir(full_file):
                                    shutil.rmtree(full_file)
                                else:
                                    os.remove(full_file)
                            except Exception:
                                pass
        for file in paths['mearec_filename']:
            os.remove(file)

    return study

This function takes a dictionary of inputs (the same as for generating the recordings) and, looping over all the recordings for a given condition (probe, rate, correlation level), creates a study in the path ../data/study/, running all the sorters on the recordings. This can take a lot of time, depending on the number of recordings/sorters. Note also that, by default, the original recordings generated by MEArec are kept, and thus duplicated in the study folder. If you want to delete the original recordings (they are not needed for further analysis), you can set keep_data=False.

In [14]:
study_params = generation_params.copy()
study_params['sorter_list'] = ['yass', 'kilosort', 'kilosort2', 'kilosort3', 'spykingcircus', 'tridesclous', 'ironclust', 'herdingspikes', 'hdsort']
study_params['docker_images'] = {'yass' : 'spikeinterface/yass-base:2.0.0'} #If some sorters are installed via docker
study_params['relaunch'] = 'all' #If you want to relaunch the sorters. 
study_params['reset_study'] = False #If you want to reset the study (delete everything)
In [ ]:
all_studies = {}
for rate_level in rate_levels:
    for corr_level in corr_levels:

        study_params['rate_level'] = rate_level
        study_params['corr_level'] = corr_level
        all_studies[corr_level, rate_level] = generate_study(study_params)

And this is it! Now you should have several studies, each of them with several recordings that have been analyzed by several sorters, in a structured manner (as a function of rate/correlation levels).
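
From here, the analysis notebooks shown earlier can load each study back. A minimal sketch for one condition (the folder name mirrors the pattern used by generate_study):

import spikeinterface.full as si

study_folder = '../data/study/20cells_5noise_0corr_5rate_Neuronexus-32'
study = si.CollisionGTStudy(study_folder)
study.run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)
study.precompute_scores_by_similarities()
print(study.sorter_names)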

probeinterface paper figures

Figure for probeinterface paper

Here is a notebook to reproduce the figures for the paper

ProbeInterface: a unified framework for probe handling in extracellular electrophysiology

In [2]:
from probeinterface import plotting, io, Probe, ProbeGroup, get_probe
from probeinterface.plotting import plot_probe_group

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
In [3]:
# create contact positions
positions = np.zeros((32, 2))
positions[:, 0] = [0] * 8 + [50] * 8 + [200] * 8 + [250] * 8
positions[:, 1] = list(range(0, 400, 50)) * 4
# create an empty probe object with coordinates in um
probe0 = Probe(ndim=2, si_units='um')
# set contacts
probe0.set_contacts(positions=positions, shapes='circle',shape_params={'radius': 10})
# create probe shape (optional)
polygon = [(-20, 480), (-20, -30), (20, -110), (70, -30), (70, 450),
           (180, 450), (180, -30), (220, -110), (270, -30), (270, 480)]
probe0.set_planar_contour(polygon)
In [4]:
# duplicate the probe and move it horizontally
probe1 = probe0.copy()
# move probe by 600 um in x direction
probe1.move([600, 0])

# Create a probegroup
probegroup = ProbeGroup()
probegroup.add_probe(probe0)
probegroup.add_probe(probe1)
In [5]:
fig2, ax2 = plt.subplots(figsize=(10,7))
plot_probe_group(probegroup, ax=ax2)
In [6]:
fig2.savefig("fig2.pdf")
In [7]:
probe0 = get_probe('cambridgeneurotech', 'ASSY-156-P-1')
probe1 = get_probe('neuronexus', 'A1x32-Poly3-10mm-50-177')
probe1.move([1000, -100])

probegroup = ProbeGroup()
probegroup.add_probe(probe0)
probegroup.add_probe(probe1)

fig3, ax3 = plt.subplots(figsize=(10,7))
plot_probe_group(probegroup, ax=ax3)
In [8]:
fig3.savefig("fig3.pdf")
In [9]:
manufacturer = 'cambridgeneurotech'
probe_name = 'ASSY-156-P-1'

probe = get_probe(manufacturer, probe_name)
print(probe)
cambridgeneurotech - ASSY-156-P-1 - 64ch - 4shanks
In [10]:
probe.wiring_to_device('ASSY-156>RHD2164')

fig4, ax4 = plt.subplots(figsize=(12,7))
plotting.plot_probe(probe, with_device_index=True, with_contact_id=True, title=False, ax=ax4)
ax4.set_xlim(-100, 400)
ax4.set_ylim(-150, 100)
Out[10]:
(-150.0, 100.0)
In [11]:
fig4.savefig("fig4.pdf")
In [12]:
probe.device_channel_indices
Out[12]:
array([47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31,
       30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14,
       13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0, 63, 62, 61,
       60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48])
In [13]:
probe.to_dataframe(complete=True)
Out[13]:
x y contact_shapes width height shank_ids contact_ids device_channel_indices si_units plane_axis_x_0 plane_axis_x_1 plane_axis_y_0 plane_axis_y_1
0 522.5 137.5 rect 11.0 15.0 2 1 47 um 1.0 0.0 0.0 1.0
1 500.0 50.0 rect 11.0 15.0 2 2 46 um 1.0 0.0 0.0 1.0
2 522.5 187.5 rect 11.0 15.0 2 3 45 um 1.0 0.0 0.0 1.0
3 500.0 125.0 rect 11.0 15.0 2 4 44 um 1.0 0.0 0.0 1.0
4 772.5 112.5 rect 11.0 15.0 3 5 43 um 1.0 0.0 0.0 1.0
... ... ... ... ... ... ... ... ... ... ... ... ... ...
59 772.5 37.5 rect 11.0 15.0 3 60 52 um 1.0 0.0 0.0 1.0
60 750.0 150.0 rect 11.0 15.0 3 61 51 um 1.0 0.0 0.0 1.0
61 750.0 50.0 rect 11.0 15.0 3 62 50 um 1.0 0.0 0.0 1.0
62 750.0 125.0 rect 11.0 15.0 3 63 49 um 1.0 0.0 0.0 1.0
63 772.5 12.5 rect 11.0 15.0 3 64 48 um 1.0 0.0 0.0 1.0

64 rows × 13 columns
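
The `io` module imported at the top of this notebook is not used in the figure cells. As a small hedged sketch, the wired probe can be saved to and reloaded from the ProbeInterface JSON format (assuming the `write_probeinterface` / `read_probeinterface` helpers; the filename is hypothetical, and reading returns a ProbeGroup).

from probeinterface import io

# save the wired probe to the ProbeInterface JSON format (hypothetical filename)
io.write_probeinterface('ASSY-156-P-1_wired.json', probe)

# reading returns a ProbeGroup, even when a single Probe was written
probegroup_loaded = io.read_probeinterface('ASSY-156-P-1_wired.json')
probe_loaded = probegroup_loaded.probes[0]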