Compare old vs new spikeinterface API
Compare "old" vs "new" spikeinterface API¶
Author : Samuel Garcia 29 March 2021
In spring 2021, the spikeinterface team plan a "big refactoring" of the spikeinterface tool suite.
Main changes are:
- use neo as much as possible for extractors
- handle multi segment
- improve performance (pre and post processing)
- add a WaveformExtractor class
Here I will benchmark 2 aspects of the "new API":
- filter with 10 workers on a multi core machine
- extract waveforms with 1 worker vs 10 workers
The benchmark is done on a 10 min spikeglx file with 384 channels.
The sorting is done with kilosort3.
My machine is Intel(R) Xeon(R) Silver 4210 CPU @ 2.20GHz 2 CPU with 20 core each.
from pathlib import Path
import shutil
import time
import matplotlib.pyplot as plt
# Root folder holding the raw data and all benchmark outputs.
base_folder = Path('/mnt/data/sam/DataSpikeSorting/eduarda_arthur')
# Folder with the spikeglx recording used for the benchmark
# (a 10 min, 384-channel file, per the introduction above).
data_folder = base_folder / 'raw_awake'
Filter with OLD API¶
Here we :
- open the file
- lazy filter
- cache it
- dump to json
The "cache" step is in fact the "compute and save" step.
# OLD API benchmark: open the spikeglx file, lazily band-pass filter it,
# cache (= compute and save) the result, then dump the lazy chain to json.
import spikeextractors as se
import spiketoolkit as st

print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)

# step 1: open
file_path = data_folder / 'raw_awake_01_g0_t0.imec0.ap.bin'
recording = se.SpikeGLXRecordingExtractor(file_path)

# step 2: lazy filter
# BUG FIX: the original was missing the comma between freq_min and freq_max,
# which is a SyntaxError: bandpass_filter(recording, freq_min=300. freq_max=6000.)
rec_filtered = st.preprocessing.bandpass_filter(recording, freq_min=300., freq_max=6000.)
print(rec_filtered)

# Start from an empty output folder.
save_folder = base_folder / 'raw_awake_filtered_old'
if save_folder.is_dir():
    shutil.rmtree(save_folder)
save_folder.mkdir()
# NOTE: the 'filetred' typo in the file names is kept on purpose —
# the same paths are reused by the waveform benchmark further down.
save_file = save_folder / 'filetred_recording.dat'
dump_file = save_folder / 'filetred_recording.json'

# step 3: cache — this is where the filter is actually computed and written
t0 = time.perf_counter()
cached = se.CacheRecordingExtractor(rec_filtered, chunk_mb=50, n_jobs=10,
                                    save_path=save_file)
t1 = time.perf_counter()
run_time_filter_old = t1 - t0
print('Old spikeextractors cache', run_time_filter_old)

# step 4: dump the extractor chain to json so it can be reloaded later
cached.dump_to_json(dump_file)
Filter with NEW API¶
Here we :
- open the file
- lazy filter
- save it
The "save" step is in fact the "compute and save" step.
# NEW API benchmark: open, lazily band-pass filter, then compute and save
# to binary with 10 workers. The save() step does the actual computation.
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.toolkit as st

print('spikeinterface version', si.__version__)

# step 1: open — note the new reader is given the folder, not the .bin file
recording = se.SpikeGLXRecordingExtractor(data_folder)
print(recording)

# step 2: lazy filter
rec_filtered = st.bandpass_filter(recording, freq_min=300., freq_max=6000.)
print(rec_filtered)

# Make sure the output folder does not exist yet.
filter_path = base_folder / 'raw_awake_filtered'
if filter_path.is_dir():
    shutil.rmtree(filter_path)

# step 3: compute and save with 10 workers
start = time.perf_counter()
cached = rec_filtered.save(
    folder=filter_path,
    format='binary',
    dtype='int16',
    n_jobs=10,
    total_memory="50M",
    progress_bar=True,
)
run_time_filter_new = time.perf_counter() - start
print('New spikeinterface filter + save binary', run_time_filter_new)
Extract waveform with OLD API¶
Here we use get_unit_waveforms from toolkit.
We do the computation with 1 and then 10 jobs.
# OLD API benchmark: st.postprocessing.get_unit_waveforms with 1 then 10 jobs.
# The two runs were copy-pasted in the original; they are factored into one
# helper so the parameters that differ (folder, n_jobs, chunk_mb) are explicit.
from spikeextractors.baseextractor import BaseExtractor
import spikeextractors as se
import spiketoolkit as st

print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)

save_folder = base_folder / 'raw_awake_filtered_old'
dump_file = save_folder / 'filetred_recording.json'


def _benchmark_old_get_unit_waveforms(folder_name, n_jobs, chunk_mb, **kwargs):
    """Run one old-API waveform extraction and return the elapsed time (s).

    folder_name : name of the temp waveform folder created under base_folder
    n_jobs, chunk_mb : parallelism / chunk size passed to get_unit_waveforms
    kwargs : extra options forwarded to get_unit_waveforms (e.g. verbose=True)
    """
    # Reload from the json dump so every run starts from the same lazy chain.
    recording = BaseExtractor.load_extractor_from_json(dump_file)
    sorting = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')

    # Fresh temp folder for the memmap files written during extraction.
    waveform_folder = base_folder / folder_name
    if waveform_folder.is_dir():
        shutil.rmtree(waveform_folder)
    waveform_folder.mkdir()
    sorting.set_tmp_folder(waveform_folder)

    t0 = time.perf_counter()
    st.postprocessing.get_unit_waveforms(
        recording, sorting,
        max_spikes_per_unit=500, return_idxs=True,
        chunk_mb=chunk_mb, n_jobs=n_jobs, memmap=True, **kwargs)
    return time.perf_counter() - t0


run_time_waveform_old_1jobs = _benchmark_old_get_unit_waveforms(
    'waveforms_extractor_old_1_job', n_jobs=1, chunk_mb=50)
print('OLD API get_unit_waveforms 1 jobs', run_time_waveform_old_1jobs)

# 10 jobs get 10x the chunk memory, matching the new-API runs below.
run_time_waveform_old_10jobs = _benchmark_old_get_unit_waveforms(
    'waveforms_extractor_old_10_jobs_', n_jobs=10, chunk_mb=500, verbose=True)
print('OLD API get_unit_waveforms 10 jobs', run_time_waveform_old_10jobs)
Extract waveform with NEW API¶
The spikeinterface 0.9 API introduces a more flexible object, WaveformExtractor, to do the same (extract snippets).
Here some code example and benchmark speed.
# NEW API benchmark: WaveformExtractor with 1 then 10 jobs.
# Fixes the mislabelled comment before the second run (it said "1 worker"
# although n_jobs=10), and factors the duplicated run code into a helper.
import spikeinterface.extractors as se
from spikeinterface import WaveformExtractor, load_extractor

print('spikeinterface version', si.__version__)

filter_path = base_folder / 'raw_awake_filtered'
filered_recording = load_extractor(filter_path)

sorting_KS3 = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
print(sorting_KS3)


def _benchmark_waveform_extractor(folder_name, n_jobs, total_memory):
    """Create a fresh WaveformExtractor in base_folder/folder_name, run it,
    and return the elapsed time in seconds."""
    waveform_folder = base_folder / folder_name
    if waveform_folder.is_dir():
        shutil.rmtree(waveform_folder)
    we = WaveformExtractor.create(filered_recording, sorting_KS3, waveform_folder)
    t0 = time.perf_counter()
    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    we.run(n_jobs=n_jobs, total_memory=total_memory, progress_bar=True)
    return time.perf_counter() - t0


# 1 worker
run_time_waveform_new_1jobs = _benchmark_waveform_extractor(
    'waveforms_extractor_1_job_new_', n_jobs=1, total_memory="50M")
print('New WaveformExtractor 1 jobs', run_time_waveform_new_1jobs)

# 10 workers
run_time_waveform_new_10jobs = _benchmark_waveform_extractor(
    'waveforms_extractor_10_job_new_', n_jobs=10, total_memory="500M")
print('New WaveformExtractor 10 jobs', run_time_waveform_new_10jobs)
Conclusion¶
For filter with 10 workers the speedup is x14.
For waveform extractor with 1 worker the speedup is x4.
For waveform extractor with 10 workers the speedup is x16.
# Speedups of the new API relative to the old one (old time / new time).
speedup_filter = run_time_filter_old / run_time_filter_new
print('speedup filter', speedup_filter)

speedup_waveform_1jobs = run_time_waveform_old_1jobs / run_time_waveform_new_1jobs
print('speedup waveforms 1 jobs', speedup_waveform_1jobs)

speedup_waveform_10jobs = run_time_waveform_old_10jobs / run_time_waveform_new_10jobs
# BUG FIX: the output message read 'speedup waveformd 10jobs' (typo).
print('speedup waveforms 10 jobs', speedup_waveform_10jobs)