-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
84 lines (68 loc) · 2.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import pandas as pd
import numpy as np
##########
# analysis
def load_exps(l_pqt):
    '''Load simulation results from disk

    Each parquet file is read with `pd.read_parquet`, its 't' column is
    cast to float, and all files are concatenated row-wise (the original
    per-file indices are kept).

    Parameters
    ----------
    l_pqt : iterable of path-like
        Parquet files (or anything `pd.read_parquet` accepts) with
        simulation results

    Returns
    -------
    df : pd.DataFrame
        Data for all experiments, with column 't' as float

    Raises
    ------
    ValueError
        If `l_pqt` contains no files.
    '''
    dfs = []
    for p in l_pqt:
        # read_parquet opens the file itself; no separate file handle needed
        df = pd.read_parquet(p)
        # Replace the whole column rather than assigning via .loc[:, 't']:
        # since pandas 2.0 .loc assignment writes in place and may keep an
        # integer dtype, silently skipping the cast to float.
        df['t'] = df['t'].astype(float)
        dfs.append(df)

    if not dfs:
        raise ValueError('load_exps: no parquet files given')

    return pd.concat(dfs)
def get_rate(df, t_run, n_run, flyid2name=None):
    '''Calculate rate and standard deviation for all experiments
    in df

    For every (experiment, neuron) pair, the number of rows in each trial
    is divided by `t_run` to give that trial's rate. Trials without any
    rows keep a rate of 0. Mean and population standard deviation
    (numpy default, ddof=0) are then taken over all `n_run` trials.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe generated with `load_exps` containing spike times; must
        have the columns 'exp_name', 'flywire_id' and 'trial'. Each row is
        counted as one spike.
    t_run : float
        Trial duration in seconds
    n_run : int
        Number of trials. 'trial' values (cast to int) are used as indices
        into an array of this length.
    flyid2name : dict or None (optional)
        Mapping between flywire IDs and custom names. If given and
        non-empty, a 'name' column is inserted as first column of both
        outputs; IDs missing from the mapping get ''.

    Returns
    -------
    df_rate : pd.DataFrame
        Average firing rates (index: flywire ID, columns: experiment name).
        Neurons that do not appear in an experiment are NaN.
    df_std : pd.DataFrame
        Standard deviation of firing rates, same layout as `df_rate`
    '''
    records = []
    for exp, df_exp in df.groupby('exp_name', sort=False):
        for flyid, df_fly in df_exp.groupby('flywire_id'):
            # zero-initialised so trials without spikes count as rate 0
            r = np.zeros(n_run)
            for trial, n_spikes in df_fly.groupby('trial').size().items():
                r[int(trial)] = n_spikes / t_run
            records.append((r.mean(), r.std(), flyid, exp))

    df_stats = pd.DataFrame(records, columns=['r', 'std', 'flyid', 'exp_name'])
    df_rate = df_stats.pivot_table(columns='exp_name', index='flyid', values='r')
    df_std = df_stats.pivot_table(columns='exp_name', index='flyid', values='std')

    if flyid2name:
        # each table is labelled from its own index
        df_rate.insert(loc=0, column='name',
                       value=df_rate.index.map(flyid2name).fillna(''))
        df_std.insert(loc=0, column='name',
                      value=df_std.index.map(flyid2name).fillna(''))

    return df_rate, df_std