Compare old vs new spikeinterface API


Author: Samuel Garcia, 29 March 2021

In spring 2021, the spikeinterface team planned a "big refactoring" of the spikeinterface tool suite.

Main changes are:

  • use neo as much as possible for extractors
  • handle multi-segment recordings
  • improve performance (pre- and post-processing)
  • add a WaveformExtractor class
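To make the multi-segment idea concrete, here is a minimal numpy sketch of what a segment-aware recording looks like; the function names are illustrative stand-ins, not the actual spikeinterface API:

```python
import numpy as np

# Model a multi-segment recording as a list of (num_samples, num_channels)
# arrays, one per segment. Hypothetical helpers, for illustration only.
segments = [np.zeros((30000, 384), dtype='int16'),
            np.zeros((15000, 384), dtype='int16')]

def get_num_segments(segments):
    # one entry per acquisition segment
    return len(segments)

def get_traces(segments, segment_index, start_frame, end_frame):
    # slice a frame range from one segment, mirroring a per-segment query
    return segments[segment_index][start_frame:end_frame, :]

print(get_num_segments(segments))              # 2
print(get_traces(segments, 1, 0, 100).shape)   # (100, 384)
```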

Here I will benchmark 2 aspects of the "new API":

  • filtering with 10 workers on a multi-core machine
  • waveform extraction with 1 worker vs 10 workers

The benchmark is done on a 10 min SpikeGLX file with 384 channels.

The sorting is done with kilosort3.

My machine has 2 Intel(R) Xeon(R) Silver 4210 CPUs @ 2.20GHz, with 20 cores each.

In [5]:
from pathlib import Path
import shutil
import time
import matplotlib.pyplot as plt

base_folder = Path('/mnt/data/sam/DataSpikeSorting/eduarda_arthur') 
data_folder = base_folder / 'raw_awake'

Filter with OLD API

Here we:

  1. open the file
  2. lazy filter
  3. cache it
  4. dump to json

The "cache" step is in fact the "compute and save" step.
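The chunked "compute and save" idea behind the cache step can be sketched in plain numpy; the scaling step below is a hypothetical stand-in for the real bandpass filter, and the file layout is only an assumption:

```python
import os
import tempfile
import numpy as np

# Toy recording: 10000 frames x 4 channels of float32 traces.
num_channels = 4
traces = np.random.randn(10_000, num_channels).astype('float32')

# Frames per chunk; in the real API this is derived from chunk_mb / total_memory.
chunk_size = 2_500

# Write the "filtered" traces chunk by chunk into an int16 binary file.
path = os.path.join(tempfile.mkdtemp(), 'filtered.dat')
out = np.memmap(path, dtype='int16', mode='w+', shape=traces.shape)
for start in range(0, traces.shape[0], chunk_size):
    stop = min(start + chunk_size, traces.shape[0])
    chunk = traces[start:stop]
    # stand-in for the bandpass filter: just scale and cast
    out[start:stop] = (chunk * 100).astype('int16')
out.flush()

print(out.shape)  # (10000, 4)
```

With 10 workers, independent chunks like these are simply dispatched to a process pool, which is where the speedup measured below comes from.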

In [6]:
import spikeextractors as se
import spiketoolkit as st

print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)

# step 1: open
file_path = data_folder / 'raw_awake_01_g0_t0.imec0.ap.bin'
recording = se.SpikeGLXRecordingExtractor(file_path)

# step 2: lazy filter
rec_filtered = st.preprocessing.bandpass_filter(recording, freq_min=300., freq_max=6000.)
print(rec_filtered)

save_folder = base_folder / 'raw_awake_filtered_old'
if save_folder.is_dir():
    shutil.rmtree(save_folder)
save_folder.mkdir()

save_file = save_folder / 'filetred_recording.dat'
dump_file = save_folder / 'filetred_recording.json'

# step 3: cache
t0 = time.perf_counter()
cached = se.CacheRecordingExtractor(rec_filtered, chunk_mb=50, n_jobs=10, 
    save_path=save_file)
t1 = time.perf_counter()
run_time_filter_old = t1-t0
print('Old spikeextractors cache', run_time_filter_old)

# step 4: dump
cached.dump_to_json(dump_file)
spikeextractors version 0.9.5
spiketoolkit version 0.7.4
<spiketoolkit.preprocessing.bandpass_filter.BandpassFilterRecording object at 0x7f648d3ee130>
Old spikeextractors cache 801.9439885600004

Filter with NEW API

Here we:

  1. open the file
  2. lazy filter
  3. save it

The "save" step is in fact the "compute and save" step.

In [7]:
import spikeinterface as si

import spikeinterface.extractors as se
import spikeinterface.toolkit as st
print('spikeinterface version', si.__version__)

# step 1: open
recording = se.SpikeGLXRecordingExtractor(data_folder)
print(recording)

# step 2: lazy filter
rec_filtered = st.bandpass_filter(recording, freq_min=300., freq_max=6000.)
print(rec_filtered)


filter_path = base_folder / 'raw_awake_filtered'
if filter_path.is_dir():
    shutil.rmtree(filter_path)

# step 3 : compute and save with 10 workers
t0 = time.perf_counter()
cached = rec_filtered.save(folder=filter_path,
    format='binary', dtype='int16',
    n_jobs=10,  total_memory="50M", progress_bar=True)
t1 = time.perf_counter()
run_time_filter_new = t1 - t0
print('New spikeinterface filter + save binary', run_time_filter_new)
spikeinterface version 0.90.0
SpikeGLXRecordingExtractor: 385 channels - 1 segments - 30.0kHz
BandpassFilterRecording: 385 channels - 1 segments - 30.0kHz
write_binary_recording with n_jobs 10  chunk_size 3246
write_binary_recording: 100%|██████████| 5546/5546 [00:51<00:00, 108.39it/s]
New spikeinterface filter + save binary 54.79437772196252

Extract waveform with OLD API

Here we use get_unit_waveforms from spiketoolkit.

We do the computation with 1 and then 10 jobs.
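Conceptually, waveform extraction cuts a fixed window of samples around each spike time. A minimal numpy sketch, with illustrative names and toy data rather than the toolkit's actual implementation:

```python
import numpy as np

# Toy traces: 2 s at 30 kHz, 8 channels.
fs = 30_000
traces = np.random.randn(fs * 2, 8).astype('float32')

# Spike times in frames (would come from the sorting in the real workflow).
spike_frames = np.array([1000, 5000, 20000])

# Window around each spike, matching the ms_before/ms_after parameters.
ms_before, ms_after = 3.0, 4.0
nbefore = int(ms_before * fs / 1000)   # 90 samples
nafter = int(ms_after * fs / 1000)     # 120 samples

# One (nbefore + nafter, num_channels) snippet per spike.
waveforms = np.stack([traces[f - nbefore:f + nafter] for f in spike_frames])
print(waveforms.shape)  # (3, 210, 8)
```

The parallel versions split the recording into chunks and extract the snippets falling in each chunk with separate workers.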

In [21]:
from spikeextractors.baseextractor import BaseExtractor
import spikeextractors as se
import spiketoolkit as st
print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)
spikeextractors version 0.9.5
spiketoolkit version 0.7.4
In [24]:
save_folder = base_folder / 'raw_awake_filtered_old'
dump_file = save_folder / 'filetred_recording.json'
recording = BaseExtractor.load_extractor_from_json(dump_file)

sorting_KS3 = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
waveform_folder = base_folder / 'waveforms_extractor_old_1_job'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
waveform_folder.mkdir()
sorting_KS3.set_tmp_folder(waveform_folder)

t0 = time.perf_counter()
wf, indexes, channels = st.postprocessing.get_unit_waveforms(recording, sorting_KS3,
            max_spikes_per_unit=500, return_idxs=True, chunk_mb=50, n_jobs=1,
            memmap=True)
t1 = time.perf_counter()
run_time_waveform_old_1jobs = t1 - t0
print('OLD API get_unit_waveforms 1 jobs', run_time_waveform_old_1jobs)
OLD API get_unit_waveforms 1 jobs 513.5964983040467
In [30]:
save_folder = base_folder / 'raw_awake_filtered_old'
dump_file = save_folder / 'filetred_recording.json'
recording = BaseExtractor.load_extractor_from_json(dump_file)

sorting_KS3_bis = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
waveform_folder = base_folder / 'waveforms_extractor_old_10_jobs_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
waveform_folder.mkdir()
sorting_KS3_bis.set_tmp_folder(waveform_folder)

t0 = time.perf_counter()
wf, indexes, channels = st.postprocessing.get_unit_waveforms(recording, sorting_KS3_bis,
            max_spikes_per_unit=500, return_idxs=True, chunk_mb=500, n_jobs=10,
            memmap=True, verbose=True)
t1 = time.perf_counter()
run_time_waveform_old_10jobs = t1 - t0
print('OLD API get_unit_waveforms 10 jobs', run_time_waveform_old_10jobs)
Number of chunks: 553 - Number of jobs: 10
Impossible to delete temp file: /mnt/data/sam/DataSpikeSorting/eduarda_arthur/waveforms_extractor_old_10_jobs Error [Errno 16] Device or resource busy: '.nfs0000000004ce04d3000007b8'
OLD API get_unit_waveforms 10 jobs 823.8002076600096

Extract waveform with NEW API

The spikeinterface 0.90 API introduces a more flexible object, WaveformExtractor, to do the same job (extract snippets).

Here are some code examples and speed benchmarks.

In [39]:
import spikeinterface.extractors as se
from spikeinterface import WaveformExtractor, load_extractor
print('spikeinterface version', si.__version__)

filter_path = base_folder / 'raw_awake_filtered'
filered_recording = load_extractor(filter_path)

sorting_KS3 = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
print(sorting_KS3)
spikeinterface version 0.90.0
KiloSortSortingExtractor: 184 units - 1 segments - 30.0kHz
In [41]:
# 1 worker
waveform_folder = base_folder / 'waveforms_extractor_1_job_new_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
we = WaveformExtractor.create(filered_recording, sorting_KS3, waveform_folder)

t0 = time.perf_counter()
we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we.run(n_jobs=1, total_memory="50M", progress_bar=True)
t1 = time.perf_counter()
run_time_waveform_new_1jobs = t1 - t0
print('New WaveformExtractor 1 jobs',run_time_waveform_new_1jobs)
100%|##########| 278/278 [01:42<00:00,  2.72it/s]
New WaveformExtractor 1 jobs 115.03656197001692
In [42]:
# 10 workers
waveform_folder = base_folder / 'waveforms_extractor_10_job_new_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
we = WaveformExtractor.create(filered_recording, sorting_KS3, waveform_folder)

t0 = time.perf_counter()
we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we.run(n_jobs=10, total_memory="500M", progress_bar=True)
t1 = time.perf_counter()
run_time_waveform_new_10jobs = t1 - t0
print('New WaveformExtractor 10 jobs', run_time_waveform_new_10jobs)
100%|██████████| 278/278 [00:31<00:00,  8.87it/s]
New WaveformExtractor 10 jobs 48.819815920025576

Conclusion

For filtering with 10 workers, the speedup is x14.

For waveform extraction with 1 worker, the speedup is x4.

For waveform extraction with 10 workers, the speedup is x16.

In [11]:
speedup_filter = run_time_filter_old / run_time_filter_new
print('speedup filter', speedup_filter)
speedup filter 14.635515939778026
In [43]:
speedup_waveform_1jobs = run_time_waveform_old_1jobs / run_time_waveform_new_1jobs
print('speedup waveforms 1 jobs', speedup_waveform_1jobs)

speedup_waveform_10jobs = run_time_waveform_old_10jobs / run_time_waveform_new_10jobs
print('speedup waveforms 10 jobs', speedup_waveform_10jobs)
speedup waveforms 1 jobs 4.464637064152789
speedup waveforms 10 jobs 16.874299751754943
In [ ]: