Collision paper: generate recordings

Generation of the recordings

In this notebook, we will generate with MEArec all the recordings needed to populate the study and compare the sorters. First, we need a function that, given a dictionary of parameters, generates a single recording. The recording parameters can be defined as follows.

In [4]:
import os
import sys
import shutil
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import neo
import quantities as pq

import MEArec as mr
import spikeinterface.full as si

# make_matching_events is used below to detect spike collisions between pairs of
# spike trains; it is assumed here to be importable from spikeinterface's
# comparison tools (the exact module path may differ between versions)
from spikeinterface.comparison.comparisontools import make_matching_events
In [8]:
sys.path.append('../utils/')

from corr_spike_trains import CorrelatedSpikeGenerator
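
The plotting code further below also relies on a small simpleaxis helper (it hides the top and right spines of a matplotlib axis). If it is not already provided by the ../utils folder, a minimal stand-in could look like this (a sketch, not part of the original utilities):

# Minimal stand-in for the simpleaxis helper used in the sanity plots below:
# hide the top/right spines and keep the ticks on the bottom/left axes only.
def simpleaxis(ax):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
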
In [3]:
generation_params = {
    'probe' : 'Neuronexus-32', # layout of the probe used
    'duration' : 30*60, # total duration of the recording, in seconds
    'n_cell' : 20, # number of cells that will be injected
    'fs' : 30000., # sampling rate, in Hz
    'lag_time' : 0.002, # collision lag window (half the refractory period), in seconds
    'make_plots' : True,
    'generate_recording' : True,
    'noise_level' : 5,
    'templates_seed' : 42,
    'noise_seed' : 42,
    'global_path' : os.path.abspath('../'),
    'study_number' : 0,
    'save_plots' : True,
    'method' : 'brette', # 'poisson' | 'brette'
    'corr_level' : 0,
    'rate_level' : 5, #Hz
    'nb_recordings' : 5
}

With these parameters, we will create 20 neurons, and the desired correlation level will be obtained via the mixture process of [Brette et al, 2009]. The function to generate a single recording is defined below. It assumes that a file named ../data/templates/templates_{probe}_100.h5, containing all the pre-generated templates used by MEArec, is present in your data folder.
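
Before launching long generations, it can be worth checking that this template file is actually present and readable. The snippet below is a minimal sketch, assuming a standard MEArec installation (mr.load_templates is part of the MEArec API) and the path convention used in this notebook:

# Sanity check (sketch): verify that the pre-generated templates are available
probe = generation_params['probe']
template_filename = Path('../data/templates') / f'templates_{probe}_100.h5'
if not template_filename.is_file():
    raise FileNotFoundError(f'Missing template file: {template_filename}')
tempgen = mr.load_templates(template_filename)
print('Number of pre-generated templates:', len(tempgen.templates))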

In [5]:
def generate_single_recording(params=generation_params):

    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings') 

    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    template_filename = os.path.join(paths['templates'], f'templates_{probe}_100.h5')
    recording_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')
    plot_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.pdf')

    spikerate = params['rate_level']
    n_spike_alone = int(spikerate * params['duration'])

    print('Total target rate:', params['rate_level'], "Hz")
    print('Basal rate:', spikerate, "Hz")


    # collision lag range
    lag_sample = int(params['lag_time'] * params['fs'])

    refractory_period = 2 * params['lag_time']

    spiketimes = []

    if params['method'] == 'poisson':
        print('Spike trains generated as independent poisson sources')
        
        for i in range(params['n_cell']):
            
            #~ n = n_spike_alone + n_collision_by_pair * (params['n_cell'] - i - 1)
            n = n_spike_alone
            #~ times = np.random.rand(n_spike_alone) * params['duration']
            times = np.random.rand(n) * params['duration']
            
            times = np.sort(times)
            spiketimes.append(times)

    elif params['method'] == 'brette':
        print('Spike trains generated as compound mixtures')
        C = np.ones((params['n_cell'], params['n_cell']))
        C = params['corr_level'] * np.maximum(C, C.T)
        #np.fill_diagonal(C, 0*np.ones(params['n_cell']))

        rates = params['rate_level'] * np.ones(params['n_cell'])

        cor_spk = CorrelatedSpikeGenerator(C, rates, params['n_cell'])
        cor_spk.find_mixture(iter=1e4)
        res = cor_spk.mixture_process(tauc=refractory_period/2, t=params['duration'])
        
        # make neo spiketrains
        for i in range(params['n_cell']):
            #~ print(spiketimes[i])
            mask = res[:, 0] == i
            times = res[mask, 1]
            times = np.sort(times)
            mask = (times > 0) * (times < params['duration'])
            times = times[mask]
            spiketimes.append(times)


    # enforce the refractory period by removing spikes that are too close
    for i in range(params['n_cell']):
        times = spiketimes[i]
        ind, = np.nonzero(np.diff(times) < refractory_period)
        ind += 1
        times = np.delete(times, ind)
        assert np.sum(np.diff(times) < refractory_period) == 0
        spiketimes[i] = times

    # make neo spiketrains
    spiketrains = []
    for i in range(params['n_cell']):
        mask = np.where(spiketimes[i] > 0)
        spiketimes[i] = spiketimes[i][mask] 
        spiketrain = neo.SpikeTrain(spiketimes[i], units='s', t_start=0*pq.s, t_stop=params['duration']*pq.s)
        spiketrain.annotate(cell_type='E')
        spiketrains.append(spiketrain)

    # check with sanity plot here
    if params['make_plots']:
        
        # count the number of spikes per unit
        fig, axs = plt.subplots(2, 2)
        count = [st.size for st in spiketrains]
        ax = axs[0, 0]
        simpleaxis(ax)
        pairs = []
        collision_count_by_pair = []
        collision_count_by_units = np.zeros(n_cell)
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                pairs.append(f'{i}-{j}')
                collision_count_by_pair.append(matching_event.size)
                collision_count_by_units[i] += matching_event.size
                collision_count_by_units[j] += matching_event.size
        ax.plot(np.arange(len(collision_count_by_pair)), collision_count_by_pair)
        ax.set_xticks(np.arange(len(collision_count_by_pair)))
        ax.set_xticklabels(pairs)
        ax.set_ylim(0, max(collision_count_by_pair) * 1.1)
        ax.set_ylabel('# Collisions')
        ax.set_xlabel('Pairs')

        # count the number of colliding vs. non-colliding spikes per unit
        count_total = np.array([st.size for st in spiketrains])
        count_not_collision = count_total - collision_count_by_units

        ax = axs[1, 0]
        simpleaxis(ax)
        ax.bar(np.arange(n_cell) + 1, count_not_collision, color='g')
        ax.bar(np.arange(n_cell) + 1, collision_count_by_units, bottom=count_not_collision, color='r')
        ax.set_ylabel('# spikes')
        ax.set_xlabel('Cell id')
        ax.legend(('Not colliding', 'Colliding'), loc='best')

        # cross-correlogram of the collision lags for each pair
        ax = axs[0, 1]
        simpleaxis(ax)
        counts = []
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                
                #~ ax = axs[i, j]
                all_lag = matching_event['delta_frame']  / params['fs']
                count, bins  = np.histogram(all_lag, bins=np.arange(-params['lag_time'], params['lag_time'], params['lag_time']/20))
                #~ ax.bar(bins[:-1], count, bins[1] - bins[0])
                ax.plot(1000*bins[:-1], count, c='0.5')
                counts += [count]
        counts = np.array(counts)
        counts = np.mean(counts, 0)
        ax.plot(1000*bins[:-1], counts, c='r')
        ax.set_xlabel('Lags [ms]')
        ax.set_ylabel('# Collisions')

        ax = axs[1, 1]
        simpleaxis(ax)
        ratios = []
        for i in range(n_cell):
            nb_spikes = len(spiketrains[i])
            nb_collisions = 0
            times1 = spiketrains[i].rescale('s').magnitude
            for j in list(range(0, i)) + list(range(i+1, n_cell)):
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                nb_collisions += matching_event.size

            if nb_collisions > 0:
                ratios += [nb_spikes / nb_collisions]
            else:
                ratios += [0]

        ax.bar([0], [np.mean(ratios)], yerr=[np.std(ratios)])
        ax.set_ylabel('# spikes / # collisions')
        plt.tight_layout()

        if params['save_plots']:
            plt.savefig(plot_filename)
        else:
            plt.show()
        plt.close()

    if params['generate_recording']:
        spgen = mr.SpikeTrainGenerator(spiketrains=spiketrains)
        rec_params = mr.get_default_recordings_params()
        rec_params['recordings']['fs'] = params['fs']
        rec_params['recordings']['sync_rate'] = None
        rec_params['recordings']['sync_jitter'] = 5
        rec_params['recordings']['noise_level'] = params['noise_level']
        rec_params['recordings']['filter'] = False
        rec_params['spiketrains']['duration'] = params['duration']
        rec_params['spiketrains']['n_exc'] = params['n_cell']
        rec_params['spiketrains']['n_inh'] = 0
        rec_params['recordings']['chunk_duration'] = 10.
        rec_params['templates']['n_overlap_pairs'] = None
        rec_params['templates']['min_dist'] = 0
        rec_params['seeds']['templates'] = params['templates_seed']
        rec_params['seeds']['noise'] = params['noise_seed']
        recgen = mr.gen_recordings(params=rec_params, spgen=spgen, templates=template_filename, verbose=True)
        mr.save_recording_generator(recgen, filename=recording_filename)

Once this function is created, we can write an additional function that generates several recordings with different suffixes and seeds:

In [6]:
def generate_recordings(params=generation_params):
    for i in range(params['nb_recordings']):
        params['study_number'] = i
        params['templates_seed'] = i
        params['noise_seed'] = i
        generate_single_recording(params)

And now we have all the required tools to create our recordings. By default, they will all be saved in the folder ../data/recordings/.

In [7]:
## Provide the different rate and correlation levels you want to generate
rate_levels = [5, 10, 15]
corr_levels = [0, 0.1, 0.2]
generation_params['nb_recordings'] = 5 # number of recordings per condition
In [ ]:
result = {}

for rate_level in rate_levels:
    for corr_level in corr_levels:

        generation_params['rate_level'] = rate_level
        generation_params['corr_level'] = corr_level
        generate_recordings(generation_params)

Generation of the study objects

Now that the recordings have been generated, we need to create Study objects for spikeinterface and run the sorters on all of them. Be careful: by default this can create quite a large amount of data if you use many rate/correlation levels, recordings, and/or sorters. First, we need to tell spikeinterface where to find the sorters.

In [11]:
ironclust_path = '/media/cure/Secondary/pierre/softwares/ironclust'
kilosort1_path = '/media/cure/Secondary/pierre/softwares/Kilosort-1.0'
kilosort2_path = '/media/cure/Secondary/pierre/softwares/Kilosort-2.0'
kilosort3_path = '/media/cure/Secondary/pierre/softwares/Kilosort-3.0'
hdsort_path = '/media/cure/Secondary/pierre/softwares/HDsort'
os.environ["KILOSORT_PATH"] = kilosort1_path
os.environ["KILOSORT2_PATH"] = kilosort2_path
os.environ["KILOSORT3_PATH"] = kilosort3_path
os.environ['IRONCLUST_PATH'] = ironclust_path
os.environ['HDSORT_PATH'] = hdsort_path
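
Before launching anything heavy, it can also help to confirm that spikeinterface actually detects these sorters. A minimal check (assuming a spikeinterface version that exposes installed_sorters(); the helper name may differ between versions):

# List the sorters that spikeinterface can find in the current environment
print(si.installed_sorters())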

Then we need a function that, given a list of recordings, creates a study and runs all the sorters.

In [13]:
def generate_study(params, keep_data=True):
    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings')
    paths['study'] = os.path.join(paths['data'], 'study')
    
    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    paths['mearec_filename'] = []

    study_folder = os.path.join(paths['study'], f'{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}')
    study_folder = Path(study_folder)

    if params['reset_study'] and os.path.exists(study_folder):
        shutil.rmtree(study_folder)

    print('Available sorters:')
    si.print_sorter_versions()

    gt_dict = {}

    if not os.path.exists(study_folder):

        for i in range(params['nb_recordings']):
            paths['mearec_filename'] += [os.path.join(paths['recordings'], f'rec{i}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')]

        print('Available recordings:')
        print(paths['mearec_filename'])

        
        for count, file in enumerate(paths['mearec_filename']):
            rec  = si.MEArecRecordingExtractor(file)
            sorting_gt = si.MEArecSortingExtractor(file)
            gt_dict['rec%d' %count] = (rec, sorting_gt)

        study = si.GroundTruthStudy.create(study_folder, gt_dict, n_jobs=-1, chunk_memory='1G', progress_bar=True)
        study.run_sorters(params['sorter_list'], verbose=False, docker_images=params['docker_images'])
        print("Study created!")
    else:
        study = si.GroundTruthStudy(study_folder)
        if params['relaunch'] == 'all':
            if_exist = 'overwrite'
        elif params['relaunch'] == 'some':
            if_exist = 'keep'

        if params['relaunch'] in ['all', 'some']:
            study.run_sorters(params['sorter_list'], verbose=False, mode_if_folder_exists=if_exist, docker_images=params['docker_images'])
            print("Study loaded!")

    study.copy_sortings()

    if not keep_data:

        for sorter in params['sorter_list']:

            for rec in ['rec%d' %i for i in range(params['nb_recordings'])]:
                sorter_path = os.path.join(study_folder, 'sorter_folders', rec, sorter)
                if os.path.exists(sorter_path):
                    for f in os.listdir(sorter_path):
                        if f != 'spikeinterface_log.json':
                            full_file = os.path.join(sorter_path, f)
                            try:
                                if os.path.isdir(full_file):
                                    shutil.rmtree(full_file)
                                else:
                                    os.remove(full_file)
                            except Exception:
                                pass
        for file in paths['mearec_filename']:
            os.remove(file)

    return study

This function takes a dictionary of inputs (the same as for generating the recordings) and, looping over all the recordings for a given condition (probe, rate, correlation level), creates a study in the path ../data/study/ and runs all the sorters on the recordings. This can take a lot of time, depending on the number of recordings/sorters. Note also that by default the original recordings generated by MEArec are kept, and thus duplicated in the study folder. If you want to delete the original recordings (they are not needed for further analysis), you can set keep_data=False.

In [14]:
study_params = generation_params.copy()
study_params['sorter_list'] = ['yass', 'kilosort', 'kilosort2', 'kilosort3', 'spykingcircus', 'tridesclous', 'ironclust', 'herdingspikes', 'hdsort']
study_params['docker_images'] = {'yass' : 'spikeinterface/yass-base:2.0.0'} #If some sorters are installed via docker
study_params['relaunch'] = 'all' # 'all' to rerun every sorter, 'some' to keep existing results and only run the missing ones
study_params['reset_study'] = False #If you want to reset the study (delete everything)
In [ ]:
all_studies = {}
for rate_level in rate_levels:
    for corr_level in corr_levels:

        study_params['rate_level'] = rate_level
        study_params['corr_level'] = corr_level
        all_studies[corr_level, rate_level] = generate_study(study_params)

And this is it! You should now have several studies, each containing several recordings analyzed by several sorters, organized in a structured manner (as a function of rate/correlation levels).
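
As a next step (not covered in this notebook), the ground-truth comparisons can be computed directly from these study objects. A minimal sketch, assuming the legacy GroundTruthStudy API used above (run_comparisons() and aggregate_dataframes() exist in that API, although their exact signatures may vary between spikeinterface versions):

# Sketch: compare each sorting to the ground truth and collect performance tables
# for one of the studies created above (corr_level=0, rate_level=5 Hz)
study = all_studies[0, 5]
study.run_comparisons(exhaustive_gt=True)
dataframes = study.aggregate_dataframes()
print(dataframes.keys())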