Collision paper: generate recordings
Generation of the recordings
In this notebook, we will generate all the recordings with MEArec that are necessary to populate the study and compare the sorters. First, we need to create a function that, given a dictionary of parameters, generates a single recording. The recording parameters can be defined as follows:
import os
import sys
import shutil
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

import neo
import quantities as pq

import MEArec as mr
import spikeinterface.full as si
# collision-detection helper used in the sanity plots below
# (its exact location in the spikeinterface package may vary across versions)
from spikeinterface.comparison.comparisontools import make_matching_events

sys.path.append('../utils/')
from corr_spike_trains import CorrelatedSpikeGenerator
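The sanity plots below also call a small simpleaxis helper that is not imported above (it presumably lives next to corr_spike_trains in ../utils/). A minimal stand-in, assuming it only hides the top and right spines of a matplotlib axis, could be:

def simpleaxis(ax):
    # hedged stand-in for the helper used in the sanity plots:
    # hide the top/right spines and keep ticks on the bottom/left axes only
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()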
generation_params = {
    'probe' : 'Neuronexus-32',      # layout of the probe used
    'duration' : 30*60,             # total duration of the recording (in s)
    'n_cell' : 20,                  # number of cells that will be injected
    'fs' : 30000.,                  # sampling rate (in Hz)
    'lag_time' : 0.002,             # half the refractory period (in s)
    'make_plots' : True,
    'generate_recording' : True,
    'noise_level' : 5,
    'templates_seed' : 42,
    'noise_seed' : 42,
    'global_path' : os.path.abspath('../'),
    'study_number' : 0,
    'save_plots' : True,
    'method' : 'brette',            # 'poisson' | 'brette'
    'corr_level' : 0,
    'rate_level' : 5,               # mean firing rate (in Hz)
    'nb_recordings' : 5
}
With these parameters, we will create 20 neurons, and correlation levels will be generated via the mixture process of [Brette et al, 2009]. The function to generate a single recording is defined as follows. It assumes that a file named ../data/templates/templates_{probe}_100.h5, containing all the pre-generated templates used by MEArec, is already present in your folder.
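If you do not have this template file yet, it can be pre-generated with MEArec. The sketch below is only indicative (the probe name and the number of templates per cell model are assumptions chosen to match the expected filename); see the MEArec documentation for the full list of template parameters:

# hedged sketch: pre-generate the template library expected by generate_single_recording
cell_models_folder = mr.get_default_cell_models_folder()   # cell models shipped with MEArec
templates_params = mr.get_default_templates_params()
templates_params['probe'] = 'Neuronexus-32'                # assumed: must match generation_params['probe']
templates_params['n'] = 100                                # assumed: number of templates per cell model

tempgen = mr.gen_templates(cell_models_folder, params=templates_params, verbose=True)
mr.save_template_generator(tempgen, filename='../data/templates/templates_Neuronexus-32_100.h5')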
def generate_single_recording(params=generation_params):

    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] = os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings')

    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    template_filename = os.path.join(paths['templates'], f'templates_{probe}_100.h5')
    recording_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')
    plot_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.pdf')

    spikerate = params['rate_level']
    n_spike_alone = int(spikerate * params['duration'])

    print('Total target rate:', params['rate_level'], "Hz")
    print('Basal rate:', spikerate, "Hz")

    # collision lag range
    lag_sample = int(params['lag_time'] * params['fs'])
    refractory_period = 2 * params['lag_time']

    spiketimes = []

    if params['method'] == 'poisson':
        print('Spike trains generated as independent Poisson sources')
        for i in range(params['n_cell']):
            n = n_spike_alone
            times = np.random.rand(n) * params['duration']
            times = np.sort(times)
            spiketimes.append(times)

    elif params['method'] == 'brette':
        print('Spike trains generated as compound mixtures')
        C = np.ones((params['n_cell'], params['n_cell']))
        C = params['corr_level'] * np.maximum(C, C.T)
        rates = params['rate_level'] * np.ones(params['n_cell'])
        cor_spk = CorrelatedSpikeGenerator(C, rates, params['n_cell'])
        cor_spk.find_mixture(iter=1e4)
        res = cor_spk.mixture_process(tauc=refractory_period/2, t=params['duration'])

        # extract the spike times of each cell
        for i in range(params['n_cell']):
            mask = res[:, 0] == i
            times = res[mask, 1]
            times = np.sort(times)
            mask = (times > 0) * (times < params['duration'])
            times = times[mask]
            spiketimes.append(times)

    # enforce the refractory period
    for i in range(params['n_cell']):
        times = spiketimes[i]
        ind, = np.nonzero(np.diff(times) < refractory_period)
        ind += 1
        times = np.delete(times, ind)
        assert np.sum(np.diff(times) < refractory_period) == 0
        spiketimes[i] = times

    # make neo spiketrains
    spiketrains = []
    for i in range(params['n_cell']):
        mask = np.where(spiketimes[i] > 0)
        spiketimes[i] = spiketimes[i][mask]
        spiketrain = neo.SpikeTrain(spiketimes[i], units='s', t_start=0*pq.s, t_stop=params['duration']*pq.s)
        spiketrain.annotate(cell_type='E')
        spiketrains.append(spiketrain)

    # sanity plots
    if params['make_plots']:

        fig, axs = plt.subplots(2, 2)

        # number of collisions per pair of cells
        ax = axs[0, 0]
        simpleaxis(ax)
        pairs = []
        collision_count_by_pair = []
        collision_count_by_units = np.zeros(n_cell)
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                pairs.append(f'{i}-{j}')
                collision_count_by_pair.append(matching_event.size)
                collision_count_by_units[i] += matching_event.size
                collision_count_by_units[j] += matching_event.size
        ax.plot(np.arange(len(collision_count_by_pair)), collision_count_by_pair)
        ax.set_xticks(np.arange(len(collision_count_by_pair)))
        ax.set_xticklabels(pairs)
        ax.set_ylim(0, max(collision_count_by_pair) * 1.1)
        ax.set_ylabel('# Collisions')
        ax.set_xlabel('Pairs')

        # number of colliding / non-colliding spikes per unit
        count_total = np.array([st.size for st in spiketrains])
        count_not_collision = count_total - collision_count_by_units
        ax = axs[1, 0]
        simpleaxis(ax)
        ax.bar(np.arange(n_cell).astype(int)+1, count_not_collision, color='g')
        ax.bar(np.arange(n_cell).astype(int)+1, collision_count_by_units, bottom=count_not_collision, color='r')
        ax.set_ylabel('# spikes')
        ax.set_xlabel('Cell id')
        ax.legend(('Not colliding', 'Colliding'), loc='best')

        # cross-correlogram of the collision lags
        ax = axs[0, 1]
        simpleaxis(ax)
        counts = []
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                all_lag = matching_event['delta_frame'] / params['fs']
                count, bins = np.histogram(all_lag, bins=np.arange(-params['lag_time'], params['lag_time'], params['lag_time']/20))
                ax.plot(1000*bins[:-1], count, c='0.5')
                counts += [count]
        counts = np.array(counts)
        counts = np.mean(counts, 0)
        ax.plot(1000*bins[:-1], counts, c='r')
        ax.set_xlabel('Lags [ms]')
        ax.set_ylabel('# Collisions')

        # ratio of spikes over collisions, per cell
        ax = axs[1, 1]
        simpleaxis(ax)
        ratios = []
        for i in range(n_cell):
            nb_spikes = len(spiketrains[i])
            nb_collisions = 0
            times1 = spiketrains[i].rescale('s').magnitude
            for j in list(range(0, i)) + list(range(i+1, n_cell)):
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                nb_collisions += matching_event.size
            if nb_collisions > 0:
                ratios += [nb_spikes / nb_collisions]
            else:
                ratios += [0]
        ax.bar([0], [np.mean(ratios)], yerr=[np.std(ratios)])
        ax.set_ylabel('# spikes / # collisions')

        plt.tight_layout()

        if params['save_plots']:
            plt.savefig(plot_filename)
        else:
            plt.show()
        plt.close()

    if params['generate_recording']:
        spgen = mr.SpikeTrainGenerator(spiketrains=spiketrains)

        rec_params = mr.get_default_recordings_params()
        rec_params['recordings']['fs'] = params['fs']
        rec_params['recordings']['sync_rate'] = None
        rec_params['recordings']['sync_jitter'] = 5
        rec_params['recordings']['noise_level'] = params['noise_level']
        rec_params['recordings']['filter'] = False
        rec_params['spiketrains']['duration'] = params['duration']
        rec_params['spiketrains']['n_exc'] = params['n_cell']
        rec_params['spiketrains']['n_inh'] = 0
        rec_params['recordings']['chunk_duration'] = 10.
        rec_params['templates']['n_overlap_pairs'] = None
        rec_params['templates']['min_dist'] = 0
        rec_params['seeds']['templates'] = params['templates_seed']
        rec_params['seeds']['noise'] = params['noise_seed']

        recgen = mr.gen_recordings(params=rec_params, spgen=spgen, templates=template_filename, verbose=True)
        mr.save_recording_generator(recgen, filename=recording_filename)
Once this function is defined, we can add a second function that generates several recordings with different suffixes and seeds:
def generate_recordings(params=generation_params):
    for i in range(params['nb_recordings']):
        params['study_number'] = i
        params['templates_seed'] = i
        params['noise_seed'] = i
        generate_single_recording(params)
And now we have all the required tools to create our recordings. By default, they will all be saved in the folder ../data/recordings/.
## Provide the different rate and correlation levels you want to generate
rate_levels = [5, 10, 15]
corr_levels = [0, 0.1, 0.2]
generation_params['nb_recordings'] = 5  # number of recordings per condition

for rate_level in rate_levels:
    for corr_level in corr_levels:
        generation_params['rate_level'] = rate_level
        generation_params['corr_level'] = corr_level
        generate_recordings(generation_params)
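Before moving on, it can be useful to check that the recordings were written where we expect them; a minimal sketch, assuming the default ../data/recordings output folder used above:

# hedged sketch: list the MEArec recordings generated above
recordings_folder = Path('../data/recordings')
for f in sorted(recordings_folder.glob('rec*.h5')):
    print(f.name)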
Generation of the study objects
Now that the recordings have been generated, we need to create spikeinterface Study objects and run the sorters on all these recordings. Be careful: by default this can create quite a large amount of data if you use many rate/correlation levels, recordings, and/or sorters. First, we need to tell spikeinterface where to find the sorters:
ironclust_path = '/media/cure/Secondary/pierre/softwares/ironclust'
kilosort1_path = '/media/cure/Secondary/pierre/softwares/Kilosort-1.0'
kilosort2_path = '/media/cure/Secondary/pierre/softwares/Kilosort-2.0'
kilosort3_path = '/media/cure/Secondary/pierre/softwares/Kilosort-3.0'
hdsort_path = '/media/cure/Secondary/pierre/softwares/HDsort'
os.environ["KILOSORT_PATH"] = kilosort1_path
os.environ["KILOSORT2_PATH"] = kilosort2_path
os.environ["KILOSORT3_PATH"] = kilosort3_path
os.environ['IRONCLUST_PATH'] = ironclust_path
os.environ['HDSORT_PATH'] = hdsort_path
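Before launching anything, it is worth checking which sorters spikeinterface can actually find with these settings; a minimal check could be:

# hedged sketch: verify that the sorters we plan to run are visible to spikeinterface
print(si.installed_sorters())   # sorters detected locally
si.print_sorter_versions()      # versions of the locally installed sorters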
Then we need a function that, given a list of recordings, creates a study and runs all the sorters:
def generate_study(params, keep_data=True):

    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] = os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings')
    paths['study'] = os.path.join(paths['data'], 'study')

    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    paths['mearec_filename'] = []

    study_folder = os.path.join(paths['study'], f'{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}')
    study_folder = Path(study_folder)

    if params['reset_study'] and os.path.exists(study_folder):
        shutil.rmtree(study_folder)

    print('Available sorters:')
    si.print_sorter_versions()

    gt_dict = {}

    if not os.path.exists(study_folder):
        for i in range(params['nb_recordings']):
            paths['mearec_filename'] += [os.path.join(paths['recordings'], f'rec{i}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')]

        print('Available recordings:')
        print(paths['mearec_filename'])

        for count, file in enumerate(paths['mearec_filename']):
            rec = si.MEArecRecordingExtractor(file)
            sorting_gt = si.MEArecSortingExtractor(file)
            gt_dict['rec%d' % count] = (rec, sorting_gt)

        study = si.GroundTruthStudy.create(study_folder, gt_dict, n_jobs=-1, chunk_memory='1G', progress_bar=True)
        study.run_sorters(params['sorter_list'], verbose=False, docker_images=params['docker_images'])
        print("Study created!")
    else:
        study = si.GroundTruthStudy(study_folder)
        if params['relaunch'] == 'all':
            if_exist = 'overwrite'
        elif params['relaunch'] == 'some':
            if_exist = 'keep'
        if params['relaunch'] in ['all', 'some']:
            study.run_sorters(params['sorter_list'], verbose=False, mode_if_folder_exists=if_exist, docker_images=params['docker_images'])
        print("Study loaded!")

    study.copy_sortings()

    if not keep_data:
        # remove the (heavy) sorter outputs, keeping only the logs
        for sorter in params['sorter_list']:
            for rec in ['rec%d' % i for i in range(params['nb_recordings'])]:
                sorter_path = os.path.join(study_folder, 'sorter_folders', rec, sorter)
                if os.path.exists(sorter_path):
                    for f in os.listdir(sorter_path):
                        if f != 'spikeinterface_log.json':
                            full_file = os.path.join(sorter_path, f)
                            try:
                                if os.path.isdir(full_file):
                                    shutil.rmtree(full_file)
                                else:
                                    os.remove(full_file)
                            except Exception:
                                pass
        # also remove the original MEArec recordings (they are duplicated inside the study folder)
        for file in paths['mearec_filename']:
            os.remove(file)

    return study
This function takes a dictionary of inputs (the same as for generating the recordings) and, looping over all the recordings of a given condition (probe, rate, correlation level), creates a study in the folder ../data/study/ and runs all the sorters on the recordings. This can take a lot of time, depending on the number of recordings and sorters. Note also that by default the original recordings generated by MEArec are kept, and thus duplicated in the study folder. If you want to delete the original recordings (they are not needed for further analysis), you can set keep_data=False.
study_params = generation_params.copy()
study_params['sorter_list'] = ['yass', 'kilosort', 'kilosort2', 'kilosort3', 'spykingcircus', 'tridesclous', 'ironclust', 'herdingspikes', 'hdsort']
study_params['docker_images'] = {'yass' : 'spikeinterface/yass-base:2.0.0'}  # if some sorters are installed via docker
study_params['relaunch'] = 'all'  # if you want to relaunch the sorters
study_params['reset_study'] = False  # if you want to reset the study (delete everything)

all_studies = {}

for rate_level in rate_levels:
    for corr_level in corr_levels:
        study_params['rate_level'] = rate_level
        study_params['corr_level'] = corr_level
        all_studies[corr_level, rate_level] = generate_study(study_params)
And this is it! You should now have several studies, each containing several recordings that have been analyzed by several sorters, organized in a structured manner (as a function of rate/correlation level).
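As a next step (not covered in this notebook), each study can be loaded back and compared against the ground truth. A minimal sketch using the legacy GroundTruthStudy API assumed throughout this notebook (method names may differ in more recent spikeinterface versions, and the folder name below is just an example following the naming pattern used above):

# hedged sketch: load one study back and compare sorter outputs to the ground truth
study_folder = '../data/study/20cells_5noise_0corr_5rate_Neuronexus-32'  # example condition
study = si.GroundTruthStudy(study_folder)
study.run_comparisons(exhaustive_gt=True)            # ground-truth comparison for every sorter/recording
perf = study.aggregate_performance_by_unit()         # accuracy / precision / recall per ground-truth unit
print(perf.head())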