Commit 0bfaccea authored by Ines Filipa Fernandes Ramos's avatar Ines Filipa Fernandes Ramos
Browse files

commit

parent a6425d64
......@@ -343,7 +343,7 @@ def contrast_tuning(model, img, bias, scale, min_contrast=0.01, n=1000, linear=T
return cont, vals, lim_contrast
def MEI_multi_seed(dataset_name, dat, dataloaders, models, n_seeds, MEIParameter, TargetUnit, track):
def MEI_multi_seed(dataset_name, dat, dataloaders, models, n_seeds, MEIParameter, TargetUnit, track=False):
"""
n_seeds : int # number of distinct seeded models used
mei : longblob # most exciting images
......
This diff is collapsed.
This diff is collapsed.
%% Cell type:markdown id: tags:
# Simple RGCs simulation
%% Cell type:code id: tags:
``` python
import numpy as np, array
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy import signal
import matplotlib.image as mpimg
import matplotlib.cm as cm
import math
from mpl_toolkits.mplot3d import axes3d
import torch
from collections import OrderedDict
import neuralpredictors as neur
from neuralpredictors.data.datasets import StaticImageSet, FileTreeDataset
from numpy import save
```
%% Cell type:code id: tags:
``` python
#Save notebook session
#dill.dump_session('Natural_images_dataset_session_09122020.db')
```
%% Cell type:code id: tags:
``` python
#Restore notebook session
#dill.load_session('Natural_images_dataset_session_09122020.db')
```
%% Cell type:markdown id: tags:
##### Simple RGCs simulation:
%% Cell type:code id: tags:
``` python
_default_2Dgaussian_p = (1,1,1,0,0,0,0)
def gaussian_2D(xz, sigma_x, sigma_z, amp, theta, x0, z0, y0):
"""Two dimensional Gaussian function
params:
- xz: meshgrid of x and z coordinates at which to evaluate the points
- sigma_x: width of the gaussian
- sigma_z: height of the gaussian
- amp: amplitude of the gaussian
- theta: angle of the gaussian (in radian)
- x0: shift in x of the gaussian
- z0: shift in z of the gaussian
- y0: shift in y of the gaussian
"""
(x,z) = xz
x0, z0 = float(x0), float(z0)
a = (np.cos(theta)**2)/(2*sigma_x**2) + (np.sin(theta)**2)/(2*sigma_z**2)
b = -(np.sin(2*theta)) /(4*sigma_x**2) + (np.sin(2*theta)) /(4*sigma_z**2)
c = (np.sin(theta)**2)/(2*sigma_x**2) + (np.cos(theta)**2)/(2*sigma_z**2)
g = amp * np.exp( -(a*((x-x0)**2) + 2*b*(x-x0)*(z-z0) + c*((z-z0)**2))) + y0
return g.ravel()
def mexicanHat(xz, sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1,
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0):
"""Sum of two 2D Gaussian function. For the params, see `gaussian_2D`.
However, both share the y0 parameter."""
return (gaussian_2D(xz, sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1, 0)
+ gaussian_2D(xz, sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, 0) + y0)
def ELU(r):
if r>0:
return r+1
else:
return np.exp(r) + 1
def RF(vis_field_width, vis_field_height, x_rf_center, z_rf_center, polarity, plot=False):
x,y = np.meshgrid(np.linspace(0,vis_field_width,vis_field_width),np.linspace(0,vis_field_height,vis_field_height))
if polarity==1:
sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1 = 2, 2, 1, 0, x_rf_center, z_rf_center
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0 = 3, 3, -0.5, 0, x_rf_center, z_rf_center, 0
else:
sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1 = 2, 2, -1, 0, x_rf_center, z_rf_center
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0 = 3, 3, 0.5, 0, x_rf_center, z_rf_center, 0
z = mexicanHat((x,y), sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1,
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0).reshape(vis_field_height,vis_field_width)
if plot==True:
fig = plt.figure(figsize=(5,4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, rstride=3, cstride=3, label=f"x_rf_center={x_rf_center} z_rf_center={z_rf_center} \n amp_center={amp_1} amp_surround={amp_2}")
_ = ax.legend()
return z
def RGC_response(rf, image, plot=False, seed=None):
Img_barHat = image * rf
if plot==True:
fig, ax = plt.subplots(3, figsize=(7,7))
ax[0].imshow(image, cmap='gray')
ax[0].set_title("Image")
ax[1].imshow(rf, vmin=-1, vmax=1, cmap="gray")
ax[1].set_title("RGC RF")
ax[2].imshow(Img_barHat, cmap=cm.Greys_r)
ax[2].set_title("RGC Response")
plt.tight_layout()
if seed is not None:
np.random.seed(seed)
g = ELU(sum(Img_barHat.ravel()))
spikes = np.random.poisson(lam=g, size=None)
return spikes
```
%% Cell type:code id: tags:
``` python
#Generate the receptive field of one RGC
rf = RF(64, 36, 50, 30, -1, plot=True)
```
%% Output
%% Cell type:code id: tags:
``` python
#Generate the response of one RGC
RGC_response(rf=rf, image=ds_imgs, plot=True)
```
%% Output
155
%% Cell type:markdown id: tags:
##### RGCs response generation to natural images dataset:
%% Cell type:code id: tags:
``` python
#Retrieve image sets from evaluation data set of lurz2020 #5993 images randomly selected as train, validation or test
paths = 'D://inception_loop/RGC_sim/data/static27012021/data/images'
paths2 = 'D://inception_loop/RGC_sim/data/static27012021/data/images2'
images = []
for n in range(5993):
x = np.load(paths+'/'+str(n)+'.npy')
x_padded = np.pad(x[0], pad_width=20, mode='constant',
constant_values=0)
np.save(paths2+'/'+str(n)+'.npy', [x_padded])
images.append(x_padded)
#images = np.vstack(images)
```
%% Cell type:code id: tags:
``` python
image = np.load('D://inception_loop/original_code/lurz2020/static20457-5-9-preproc0/data/images/0.npy')
```
%% Cell type:code id: tags:
``` python
#Generate receptive fields of several RGCs #2304 RGCs - haf ON/half OFF
rf_ON = []
rf_ON_center_coord = []
rf_OFF = []
rf_OFF_center_coord = []
i = 0
image = images[0]
image = image[0]
for width_center in range(image.shape[1]):
for height_center in range(image.shape[0]):
if (i % 2) == 0:
rf = RF(image.shape[1], image.shape[0], width_center, height_center, 1, plot=False)
rf_ON.append(rf)
rf_ON_center_coord = np.array([width_center, height_center])
save('D://inception_loop/RGC_sim_data/data/static27012021/RFs/center_coord_'+str(i)+'.npy', rf_ON_center_coord)
else:
rf = RF(image.shape[1], image.shape[0], width_center, height_center, -1, plot=False)
rf_OFF.append(rf)
rf_OFF_center_coord = np.array([width_center, height_center])
save('D://inception_loop/RGC_sim_data/data/static27012021/RFs/center_coord_'+str(i)+'.npy', rf_OFF_center_coord)
i+=1
```
%% Cell type:code id: tags:
``` python
#Generate responses of simulated RGCs to the image set from evaluation data set of lurz2020
import time
start_time = time.time()
i=0
for image in images:
responses = []
for rfon, rfoff in zip(rf_ON, rf_OFF):
rgc_on_response = RGC_response(rf=rfon, image=image, plot=False)
responses.append(rgc_on_response)
rgc_off_response = RGC_response(rf=rfoff, image=image, plot=False)
responses.append(rgc_off_response)
# save numpy array as npy file
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/data/responses/'+str(i)+'.npy', responses)
if (i % 1000) == 0:
print(i)
i+=1
print("--- %s seconds ---" % (time.time() - start_time))
```
%% Output
0
1000
2000
3000
4000
5000
--- 6485.6336970329285 seconds ---
%% Cell type:code id: tags:
``` python
#Generate data - pupil_center npy files
#Array with list of two values- coordinates of pupil center - [759.87785056, 472.71767702]
for j in range(len(images)):
pupil_center = np.array([0.0,0.0])
# save numpy array as npy file
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/data/pupil_center/'+str(j)+'.npy', pupil_center)
```
%% Cell type:code id: tags:
``` python
#Generate data - behavior npy files -> use include_behavior= False in static_loaders
#Array with list of three values - pupil dilation, temporal derivative and absolute running speed - [99.2678426 , 6.66429682, 0. ]
```
%% Cell type:code id: tags:
``` python
#Generate metadata - neurons - unit_ids npy array
unit_ids = np.array(range(1, len(responses)+1))
#Generate metadata - neurons - animal_ids npy array
animal_ids = np.repeat(1, len(responses))
#Generate metadata - neurons - area npy array
area = ['retina']*len(responses)
#Generate metadata - neurons - layer npy array
layer = ['RGC']*len(responses)
#Generate metadata - neurons - scan_idx npy array
scan_idx = np.repeat(14, len(responses))
#Generate metadata - neurons - sessions npy array
sessions = np.repeat(6, len(responses))
# save numpy arrays as npy arrays
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/neurons/unit_ids.npy', unit_ids)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/neurons/animal_ids.npy', animal_ids)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/neurons/area.npy', area)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/neurons/layer.npy', layer)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/neurons/scan_idx.npy', scan_idx)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/neurons/sessions.npy', sessions)
```
%% Cell type:code id: tags:
``` python
#Generate metadata - trials - animal_id npy array
animal_id = np.repeat(1, len(responses))
#Generate metadata - trials - scan_idx npy array
scan_idx = np.repeat(14, len(responses))
#Generate metadata - trials - session npy array
session = np.repeat(6, len(responses))
# save numpy arrays as npy arrays
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/trials/animal_id.npy', animal_id)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/trials/scan_idx.npy', scan_idx)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/trials/session.npy', session)
```
%% Cell type:code id: tags:
``` python
#Generate metadata - statistics - pupil_center - all
pupil_center_max_all = np.array([0.0,0.0])
pupil_center_mean_all = np.array([0.0,0.0])
pupil_center_median_all = np.array([0.0,0.0])
pupil_center_min_all = np.array([0.0,0.0])
pupil_center_std_all = np.array([0.0,0.0])
#Generate metadata - statistics - pupil_center - stimulus_frame
pupil_center_max_sf = np.array([0.0,0.0])
pupil_center_mean_sf = np.array([0.0,0.0])
pupil_center_median_sf = np.array([0.0,0.0])
pupil_center_min_sf = np.array([0.0,0.0])
pupil_center_std_sf = np.array([0.0,0.0])
# save numpy arrays as npy arrays - all
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/all/max.npy', pupil_center_max_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/all/mean.npy', pupil_center_mean_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/all/median.npy', pupil_center_median_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/all/min.npy', pupil_center_min_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/all/std.npy', pupil_center_std_all)
# save numpy arrays as npy arrays - stimulus_frame
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/stimulus_frame/max.npy', pupil_center_max_sf)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/stimulus_frame/mean.npy', pupil_center_mean_sf)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/stimulus_frame/median.npy', pupil_center_median_sf)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/stimulus_frame/min.npy', pupil_center_min_sf)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/pupil_center/stimulus_frame/std.npy', pupil_center_std_sf)
```
%% Cell type:code id: tags:
``` python
#Generate responses of simulated RGCs to the image set from evaluation data set of lurz2020
import time
start_time = time.time()
responses_all = []
for k in range(len(images)):
responses_all.append(np.load('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/data/responses/'+str(k)+'.npy'))
if (k % 1000) == 0:
print(k)
print("--- %s seconds ---" % (time.time() - start_time))
```
%% Output
0
1000
2000
3000
4000
5000
--- 47.54899740219116 seconds ---
%% Cell type:code id: tags:
``` python
#Generate metadata - statistics - responses - all
responses_max_all = np.max(responses_all, axis=0)
responses_mean_all = np.mean(responses_all, axis=0)
responses_median_all = np.median(responses_all, axis=0)
responses_min_all = np.min(responses_all, axis=0)
responses_std_all = np.std(responses_all, axis=0)
#Generate metadata - statistics - responses - stimulus_frame
# save numpy arrays as npy arrays - all
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/all/max.npy', responses_max_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/all/mean.npy', responses_mean_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/all/median.npy', responses_median_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/all/min.npy', responses_min_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/all/std.npy', responses_std_all)
# save numpy arrays as npy arrays - stimulus_frame
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/stimulus_frame/max.npy', responses_max_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/stimulus_frame/mean.npy', responses_mean_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/stimulus_frame/median.npy', responses_median_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/stimulus_frame/min.npy', responses_min_all)
save('./Lurz_2020_code/notebooks/data/RGC_sim/static27012021/meta/statistics/responses/stimulus_frame/std.npy', responses_std_all)
```
%% Cell type:markdown id: tags:
----------------------------------------------------------------
%% Cell type:markdown id: tags:
##### RGCs response generation to random checkerboard stimuli:
%% Cell type:code id: tags:
``` python
#Generate random checkerboard stim dataset #10 993 images selected as train or validation
from numpy import random
paths = 'D://inception_loop/RGC_sim/data/static20022021/data/images'
checkerboard_set = []
for n in range(10993):
checkerboard_image = [random.choice([0.0, 255.0], size=(36, 64))]
checkerboard_set.append(checkerboard_image)
np.save(paths+'/'+str(n)+'.npy', checkerboard_image)
```
%% Cell type:code id: tags:
``` python
#Generate random checkerboard stim dataset #10 993 images selected as train or validation
repeated_set = []
idx = 10993
for n in range(10):
checkerboard_image = [random.choice([0.0, 255.0], size=(36, 64))]
for i in range(100):
checkerboard_set.append(checkerboard_image)
np.save(paths+'/'+str(idx)+'.npy', checkerboard_image)
idx+=1
```
%% Cell type:code id: tags:
``` python
#Add padding for comparison reasons
paths = 'D://inception_loop/RGC_sim/data/static20022021/data/images'
images = []
for n in range(11993):
x = np.load(paths+'/'+str(n)+'.npy')
x_padded = np.pad(x[0], pad_width=20, mode='constant',
constant_values=0)
np.save(paths+'/'+str(n)+'.npy', [x_padded])
images.append(x_padded)
```
%% Cell type:code id: tags:
``` python
#Generate receptive fields of several RGCs #2304 RGCs - haf ON/half OFF to checkerboard set
rf_ON = []
rf_OFF = []
i = 0
image = checkerboard_set[0][0]
for width_center in range(image.shape[1]):
for height_center in range(image.shape[0]):
if (i % 2) == 0:
rf = RF(image.shape[1], image.shape[0], width_center, height_center, 1, plot=False)
rf_ON.append(rf)
else:
rf = RF(image.shape[1], image.shape[0], width_center, height_center, -1, plot=False)
rf_OFF.append(rf)
i+=1
```
%% Cell type:code id: tags:
``` python
#Generate responses of simulated RGCs to the image set from evaluation data set of lurz2020
import time
start_time = time.time()
i=0
for image in checkerboard_set:
responses = []
for rfon, rfoff in zip(rf_ON, rf_OFF):
rgc_on_response = RGC_response(rf=rfon, image=image, plot=False)
responses.append(rgc_on_response)
rgc_off_response = RGC_response(rf=rfoff, image=image, plot=False)
responses.append(rgc_off_response)
# save numpy array as npy file
save('D://inception_loop/RGC_sim/data/static20022021/data/responses/'+str(i)+'.npy', responses)
if (i % 1000) == 0:
print(i)
i+=1
print("--- %s seconds ---" % (time.time() - start_time))
```
%% Output
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
--- 12647.280105352402 seconds ---
%% Cell type:code id: tags:
``` python
#Generate data - pupil_center npy files
#Array with list of two values- coordinates of pupil center - [759.87785056, 472.71767702]
for j in range(len(checkerboard_set)):
pupil_center = np.array([0.0,0.0])
# save numpy array as npy file
save('D://inception_loop/RGC_sim/data/static20022021/data/pupil_center/'+str(j)+'.npy', pupil_center)
```
%% Cell type:code id: tags:
``` python
#Generate metadata - neurons - unit_ids npy array
unit_ids = np.array(range(1, len(responses)+1))
#Generate metadata - neurons - animal_ids npy array
animal_ids = np.repeat(1, len(responses))
#Generate metadata - neurons - area npy array
area = ['retina']*len(responses)
#Generate metadata - neurons - layer npy array
layer = ['RGC']*len(responses)
#Generate metadata - neurons - scan_idx npy array
scan_idx = np.repeat(14, len(responses))
#Generate metadata - neurons - sessions npy array
sessions = np.repeat(6, len(responses))
# save numpy arrays as npy arrays
path = 'D://inception_loop/RGC_sim/data/static20022021/'
save(path+'/meta/neurons/unit_ids.npy', unit_ids)
save(path+'/meta/neurons/animal_ids.npy', animal_ids)
save(path+'/meta/neurons/area.npy', area)
save(path+'/meta/neurons/layer.npy', layer)
save(path+'/meta/neurons/scan_idx.npy', scan_idx)
save(path+'/meta/neurons/sessions.npy', sessions)
```
%% Cell type:code id: tags:
``` python
#Generate metadata - trials - animal_id npy array
animal_id = np.repeat(1, len(responses))
#Generate metadata - trials - condition_hash npy array
condition_hash = np.repeat(" ", len(responses))
#Generate metadata - trials - frame_image_class npy array
frame_image_class = np.repeat("imagenet", len(responses))
#Generate metadata - trials - frame_image_id npy array
frame_image_id = np.arange(0, len(responses), 1)
#Generate metadata - trials - frame_last_flip npy array
frame_last_flip = np.random.randint(11000, 30000, size=(len(responses)))
#Generate metadata - trials - frame_pre_blank_period npy array
frame_pre_blank_period = np.random.uniform(0.3, 0.5, size=(len(responses)))
#Generate metadata - trials - frame_presentation_time npy array
frame_presentation_time = np.repeat(0.5, len(responses))
#Generate metadata - trials - frame_trial_ts npy array
frame_trial_ts = np.repeat("Timestamp('2021-02-23 17:53:43')", len(responses))
#Generate metadata - trials - scan_idx npy array
scan_idx = np.repeat(14, len(responses))
#Generate metadata - trials - tiers npy array
tiers = []
tiers[0:10000] = np.repeat("train", 10000)
tiers[10000:10993] = np.repeat("validation", 993)
tiers[10993:11993] = np.repeat("test", 1000)
np.asarray(tiers)
#Generate metadata - trials - session npy array
session = np.repeat(6, len(responses))
#Generate metadata - trials - trial_idx npy array
trial_idx = np.repeat(0, len(responses))
# save numpy arrays as npy arrays
path = 'D://inception_loop/RGC_sim/data/static20022021/'
save(path+'/meta/trials/animal_id.npy', animal_id)
save(path+'/meta/trials/condition_hash.npy', condition_hash)
save(path+'/meta/trials/frame_image_class.npy', frame_image_class)
save(path+'/meta/trials/frame_image_id.npy', frame_image_id)
save(path+'/meta/trials/frame_last_flip.npy', frame_last_flip)
save(path+'/meta/trials/frame_pre_blank_period.npy', frame_pre_blank_period)
save(path+'/meta/trials/frame_presentation_time.npy', frame_presentation_time)
save(path+'/meta/trials/frame_trial_ts.npy', frame_trial_ts)
save(path+'/meta/trials/scan_idx.npy', scan_idx)
save(path+'/meta/trials/tiers.npy', tiers)
save(path+'/meta/trials/session.npy', session)
save(path+'/meta/trials/trial_idx.npy', trial_idx)
```
%% Cell type:code id: tags:
``` python
#Generate metadata - statistics - pupil_center - all
pupil_center_max_all = np.array([0.0,0.0])
pupil_center_mean_all = np.array([0.0,0.0])
pupil_center_median_all = np.array([0.0,0.0])
pupil_center_min_all = np.array([0.0,0.0])
pupil_center_std_all = np.array([0.0,0.0])
#Generate metadata - statistics - pupil_center - stimulus_frame
pupil_center_max_sf = np.array([0.0,0.0])
pupil_center_mean_sf = np.array([0.0,0.0])
pupil_center_median_sf = np.array([0.0,0.0])
pupil_center_min_sf = np.array([0.0,0.0])
pupil_center_std_sf = np.array([0.0,0.0])
# save numpy arrays as npy arrays - all
path = 'D://inception_loop/RGC_sim/data/static20022021/'
save(path+'/meta/statistics/pupil_center/all/mean.npy', pupil_center_mean_all)
save(path+'/meta/statistics/pupil_center/all/median.npy', pupil_center_median_all)
save(path+'/meta/statistics/pupil_center/all/min.npy', pupil_center_min_all)
save(path+'/meta/statistics/pupil_center/all/std.npy', pupil_center_std_all)
# save numpy arrays as npy arrays - stimulus_frame
save(path+'/meta/statistics/pupil_center/stimulus_frame/max.npy', pupil_center_max_sf)
save(path+'/meta/statistics/pupil_center/stimulus_frame/mean.npy', pupil_center_mean_sf)
save(path+'/meta/statistics/pupil_center/stimulus_frame/median.npy', pupil_center_median_sf)
save(path+'/meta/statistics/pupil_center/stimulus_frame/min.npy', pupil_center_min_sf)
save(path+'/meta/statistics/pupil_center/stimulus_frame/std.npy', pupil_center_std_sf)
```
%% Cell type:code id: tags:
``` python
#Generate metadata - statistics - responses - all and stimulus_frame
responses_max_all = np.max(responses, axis=0)
responses_mean_all = np.mean(responses, axis=0)
responses_median_all = np.median(responses, axis=0)
responses_min_all = np.min(responses, axis=0)
responses_std_all = np.std(responses, axis=0)
# save numpy arrays as npy arrays - all
path = 'D://inception_loop/RGC_sim/data/static20022021/'
save(path+'/meta/statistics/responses/all/max.npy', responses_max_all)
save(path+'/meta/statistics/responses/all/mean.npy', responses_mean_all)
save(path+'/meta/statistics/responses/all/median.npy', responses_median_all)
save(path+'/meta/statistics/responses/all/min.npy', responses_min_all)
save(path+'/meta/statistics/responses/all/std.npy', responses_std_all)
# save numpy arrays as npy arrays - stimulus_frame
save(path+'/meta/statistics/responses/stimulus_frame/max.npy', responses_max_all)
save(path+'/meta/statistics/responses/stimulus_frame/mean.npy', responses_mean_all)
save(path+'/meta/statistics/responses/stimulus_frame/median.npy', responses_median_all)
save(path+'/meta/statistics/responses/stimulus_frame/min.npy', responses_min_all)
save(path+'/meta/statistics/responses/stimulus_frame/std.npy', responses_std_all)
```
%% Cell type:code id: tags:
``` python
images = np.vstack(checkerboard_set)
#Generate metadata - statistics - responses - all
images_max_all = np.max(images)
images_mean_all = np.mean(images)
images_median_all = np.median(images)
images_min_all = np.min(images)
images_std_all = np.std(images)
#Generate metadata - statistics - responses - stimulus_frame
# save numpy arrays as npy arrays - all
path = 'D://inception_loop/RGC_sim/data/static20022021/'
save(path+'/meta/statistics/images/all/mean.npy', images_mean_all)
save(path+'/meta/statistics/images/all/median.npy', images_median_all)
save(path+'/meta/statistics/images/all/min.npy', images_min_all)
save(path+'/meta/statistics/images/all/std.npy', images_std_all)
# save numpy arrays as npy arrays - stimulus_frame
save(path+'/meta/statistics/images/stimulus_frame/max.npy', images_max_all)
save(path+'/meta/statistics/images/stimulus_frame/mean.npy', images_mean_all)
save(path+'/meta/statistics/images/stimulus_frame/median.npy', images_median_all)
save(path+'/meta/statistics/images/stimulus_frame/min.npy', images_min_all)
save(path+'/meta/statistics/images/stimulus_frame/std.npy', images_std_all)
```
%% Cell type:markdown id: tags:
------------------------------------------
......
This diff is collapsed.
%% Cell type:markdown id: tags:
 
# Demo Notebook with simulated RGCs data
 
%% Cell type:code id: tags:
 
``` python
%matplotlib inline
%load_ext autoreload
%autoreload 2
%load_ext memory_profiler
```
 
%% Cell type:code id: tags:
 
``` python
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import neuralpredictors as neur
from neuralpredictors.data.datasets import StaticImageSet, FileTreeDataset
import MEI
import matplotlib as mpl
```
 
%% Output
 
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
 
%% Cell type:markdown id: tags:
 
# Build the dataloaders
 
%% Cell type:markdown id: tags:
 
The dataloaders object is a dictionary of 3 dictionaries: train, validation and test. Each of them contains the respective data from all datasets combined that were specified in paths. Here we only provide one dataset. While the responses are normalized, we exclude the input images from normalization. The following config was used in the paper (all arguments not in the config have the default value of the function).
 
%% Cell type:code id: tags:
 
``` python
#Use dataloaders with generated RGC data
from lurz2020.datasets.mouse_loaders import static_loaders
 
#paths = ['D://inception_loop/RGC_sim/data/static27012021']
paths = ['C://Users/Asus/Desktop/Intership EMBL/Python docs/Inception_loop/Lurz_2020_code/notebooks/data/RGC_sim/static27012021']
 
dataset_config = {'paths': paths,
'batch_size': 64,
'seed': 1,
'cuda': True,
'normalize': True,
'exclude': "images"}
 
dataloaders_RGCs = static_loaders(**dataset_config)
dat = FileTreeDataset('C://Users/Asus/Desktop/Intership EMBL/Python docs/Inception_loop/Lurz_2020_code/notebooks/data/RGC_sim/static27012021', "images", "responses")
```
 
%% Cell type:markdown id: tags:
 
### Look at the data
 
%% Cell type:code id: tags:
 
``` python
tier = 'train'
dataset_name = '27012021'
 
images, responses = [], []
for x, y in dataloaders_RGCs[tier][dataset_name]:
images.append(x.squeeze().cpu().data.numpy())
responses.append(y.squeeze().cpu().data.numpy())
 
images = np.vstack(images)
responses = np.vstack(responses)
 
print('The \"{}\" set of dataset \"{}\" contains the responses of {} RGC neurons to {} images'.format(tier, dataset_name, responses.shape[1], responses.shape[0]))
```
 
%% Output
 
The "train" set of dataset "27012021" contains the responses of 2304 RGC neurons to 4472 images
 
%% Cell type:code id: tags:
 
``` python
# show some example images and the neural responses
n_images = 5
max_response = responses[:n_images].max()
 
for i in range(n_images):
fig, axs = plt.subplots(1, 2, figsize=(15,4))
axs[0].imshow(images[i])
axs[1].plot(responses[i])
axs[1].set_xlabel('neurons')
axs[1].set_ylabel('responses')
axs[1].set_ylim([0, max_response])
plt.show()
```
 
%% Output
 
 
 
 
 
 
%% Cell type:markdown id: tags:
 
# Build the model, transfer core, train and evaluate performance - 4 instances
 
%% Cell type:markdown id: tags:
 
Get 4 instances of the model for MEI generation:
 
%% Cell type:code id: tags:
 
``` python
%%time
%%memit
from lurz2020.models.models import se2d_fullgaussian2d
from lurz2020.training.trainers import standard_trainer as trainer
from lurz2020.utility.measures import get_correlations, get_fraction_oracles
 
#Generate 4 instances of the same model with different seeds, for MEI generation
n_seeds = 4
 
models = []
train_correlation_models = []
validation_correlation_models = []
test_correlation_models = []
fraction_oracle = []
 
#Model config
model_config = {'init_mu_range': 0.55,
'init_sigma': 0.4,
'input_kern': 15,
'hidden_kern': 13,
'gamma_input': 1.0,
'grid_mean_predictor': None,
'gamma_readout': 2.439}
 
 
#Change trainer config to not track and print the training progress
trainer_config = {'track_training': False,
'verbose': None,
'detach_core': True}
 
for i in range(n_seeds):
 
model = se2d_fullgaussian2d(**model_config, dataloaders=dataloaders_RGCs, seed=i)
#Load the weights of the transfer core
transfer_model = torch.load('D://inception_loop/original_code/Lurz_2020_code/notebooks/models/transfer_model.pth.tar')
model.load_state_dict(transfer_model, strict=False)
#Run training
score, output, model_state = trainer(model=model, dataloaders=dataloaders_RGCs, seed=1, **trainer_config)
#Get performance of model
train_correlation_models.append(get_correlations(model, dataloaders_RGCs["train"], device='cuda', as_dict=False, per_neuron=False))
validation_correlation_models.append(get_correlations(model, dataloaders_RGCs["validation"], device='cuda', as_dict=False, per_neuron=False))
test_correlation_models.append(get_correlations(model, dataloaders_RGCs["test"], device='cuda', as_dict=False, per_neuron=False))
 
oracle_dataloader = static_loaders(**dataset_config, return_test_sampler=True, tier='test')
fraction_oracle.append(get_fraction_oracles(model=model, dataloaders=oracle_dataloader, device='cuda')[0])
 
print('-----------------------------------------')
print(f'Model instance #{i}')
print('Correlation (train set): {0:.3f}'.format(train_correlation_models[i]))
print('Correlation (validation set): {0:.3f}'.format(validation_correlation_models[i]))
print('Correlation (test set): {0:.3f}'.format(test_correlation_models[i]))
print('-----------------------------------------')
print('Fraction oracle (test set): {0:.3f}'.format(fraction_oracle[i]))
 
models.append(model)
#Save model state for loading later
torch.save(model_state, 'D://inception_loop/RGC_sim/models/model_padded'+str(i)+'.pth')
```
 
%% Output
 
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.58it/s]
 
[001|00/05] ---> 0.18584898114204407
 
Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.63it/s]
 
[002|00/05] ---> 0.2385803908109665
 
Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.59it/s]
 
[003|00/05] ---> 0.2923159599304199
 
Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.68it/s]
 
[004|00/05] ---> 0.3427446782588959
 
Epoch 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.40it/s]
 
[005|00/05] ---> 0.39037665724754333
 
Epoch 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.50it/s]
 
[006|00/05] ---> 0.43327847123146057
 
Epoch 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.62it/s]
 
[007|00/05] ---> 0.468386709690094
 
Epoch 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.62it/s]
 
[008|00/05] ---> 0.5005220770835876
 
Epoch 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.69it/s]
 
[009|00/05] ---> 0.5247463583946228
 
Epoch 10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.64it/s]
 
[010|00/05] ---> 0.5488583445549011
 
Epoch 11: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.63it/s]
 
[011|00/05] ---> 0.5667374730110168
 
Epoch 12: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.71it/s]
 
[012|00/05] ---> 0.5839897394180298
 
Epoch 13: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.62it/s]
 
[013|00/05] ---> 0.5982062220573425
 
Epoch 14: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.70it/s]
 
[014|00/05] ---> 0.6118215322494507
 
Epoch 15: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.70it/s]
 
[015|00/05] ---> 0.6228028535842896
 
Epoch 16: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.66it/s]
 
[016|00/05] ---> 0.6349698901176453
 
Epoch 17: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.65it/s]
 
[017|00/05] ---> 0.643383800983429
 
Epoch 18: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.69it/s]
 
[018|00/05] ---> 0.6502649188041687
 
Epoch 19: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.64it/s]
 
[019|00/05] ---> 0.6568547487258911
 
Epoch 20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.68it/s]
 
[020|00/05] ---> 0.6671704053878784
 
Epoch 21: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.66it/s]
 
[021|00/05] ---> 0.6736009120941162
 
Epoch 22: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.49it/s]
 
[022|00/05] ---> 0.6785272359848022
 
Epoch 23: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.70it/s]
 
[023|00/05] ---> 0.6855196952819824
 
Epoch 24: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.34it/s]
 
[024|00/05] ---> 0.691754937171936
 
Epoch 25: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.75it/s]
 
[025|00/05] ---> 0.6927850246429443
 
Epoch 26: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.65it/s]
 
[026|00/05] ---> 0.6983302235603333
 
Epoch 27: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[027|00/05] ---> 0.7038701176643372
 
Epoch 28: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.89it/s]
 
[028|00/05] ---> 0.7102596163749695
 
Epoch 29: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.86it/s]
 
[029|00/05] ---> 0.7148491144180298
 
Epoch 30: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[030|00/05] ---> 0.7153313755989075
 
Epoch 31: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.84it/s]
 
[031|00/05] ---> 0.7180318832397461
 
Epoch 32: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.86it/s]
 
[032|00/05] ---> 0.7236976623535156
 
Epoch 33: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.88it/s]
 
[033|00/05] ---> 0.7282322645187378
 
Epoch 34: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.99it/s]
 
[034|00/05] ---> 0.7299407124519348
 
Epoch 35: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.89it/s]
 
[035|00/05] ---> 0.7318837642669678
 
Epoch 36: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.31it/s]
 
[036|00/05] ---> 0.7329056262969971
 
Epoch 37: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.70it/s]
 
[037|00/05] ---> 0.7344937920570374
 
Epoch 38: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.66it/s]
 
[038|00/05] ---> 0.7394781708717346
 
Epoch 39: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.68it/s]
 
[039|00/05] ---> 0.7395546436309814
 
Epoch 40: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.96it/s]
 
[040|00/05] ---> 0.7422395348548889
 
Epoch 41: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.78it/s]
 
[041|00/05] ---> 0.7456504106521606
 
Epoch 42: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.62it/s]
 
[042|01/05] -/-> 0.7443419694900513
 
Epoch 43: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.76it/s]
 
[043|01/05] ---> 0.7480041980743408
 
Epoch 44: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.82it/s]
 
[044|00/05] ---> 0.7486326694488525
 
Epoch 45: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.82it/s]
 
[045|00/05] ---> 0.7494108080863953
 
Epoch 46: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.95it/s]
 
[046|00/05] ---> 0.7513750195503235
 
Epoch 47: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.96it/s]
 
[047|01/05] -/-> 0.7504597902297974
 
Epoch 48: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.96it/s]
 
[048|01/05] ---> 0.7521153688430786
 
Epoch 49: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.94it/s]
 
[049|00/05] ---> 0.756562352180481
 
Epoch 50: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.93it/s]
 
[050|01/05] -/-> 0.7559959888458252
 
Epoch 51: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[051|01/05] ---> 0.7571963667869568
 
Epoch 52: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.95it/s]
 
[052|00/05] ---> 0.7600085139274597
 
Epoch 53: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[053|01/05] -/-> 0.7594912052154541
 
Epoch 54: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[054|01/05] ---> 0.7604233622550964
 
Epoch 55: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.91it/s]
 
[055|00/05] ---> 0.7634674906730652
 
Epoch 56: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.88it/s]
 
[056|01/05] -/-> 0.7628517150878906
 
Epoch 57: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[057|02/05] -/-> 0.7623629570007324
 
Epoch 58: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.86it/s]
 
[058|02/05] ---> 0.7653490900993347
 
Epoch 59: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.89it/s]
 
[059|00/05] ---> 0.7658202052116394
 
Epoch 60: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.72it/s]
 
[060|00/05] ---> 0.7668657302856445
 
Epoch 61: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.88it/s]
 
[061|01/05] -/-> 0.7661678791046143
 
Epoch 62: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[062|02/05] -/-> 0.7667859792709351
 
Epoch 63: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[063|02/05] ---> 0.7697504162788391
 
Epoch 64: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.86it/s]
 
[064|00/05] ---> 0.7713988423347473
 
Epoch 65: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[065|00/05] ---> 0.771432101726532
 
Epoch 66: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.79it/s]
 
[066|01/05] -/-> 0.7688093185424805
 
Epoch 67: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.89it/s]
 
[067|01/05] ---> 0.7738945484161377
 
Epoch 68: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.89it/s]
 
[068|01/05] -/-> 0.7717046737670898
 
Epoch 69: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.75it/s]
 
[069|02/05] -/-> 0.772650957107544
 
Epoch 70: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.93it/s]
 
[070|02/05] ---> 0.774773895740509
 
Epoch 71: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.79it/s]
 
[071|01/05] -/-> 0.7742199897766113
 
Epoch 72: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.79it/s]
 
[072|02/05] -/-> 0.7746914625167847
 
Epoch 73: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.89it/s]
 
[073|03/05] -/-> 0.7741342782974243
 
Epoch 74: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.93it/s]
 
[074|03/05] ---> 0.7771814465522766
 
Epoch 75: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[075|00/05] ---> 0.7776859998703003
 
Epoch 76: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[076|01/05] -/-> 0.773730993270874
 
Epoch 77: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[077|02/05] -/-> 0.7738113403320312
 
Epoch 78: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.85it/s]
 
[078|03/05] -/-> 0.775076150894165
 
Epoch 79: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.88it/s]
 
[079|04/05] -/-> 0.7754709720611572
 
Epoch 80: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.72it/s]
 
[080|04/05] ---> 0.7781058549880981
 
Epoch 81: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.84it/s]
 
[081|01/05] -/-> 0.7777507305145264
 
Epoch 82: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[082|01/05] ---> 0.7801316976547241
 
Epoch 83: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.65it/s]
 
[083|01/05] -/-> 0.7792916893959045
 
Epoch 84: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.91it/s]
 
[084|02/05] -/-> 0.7798691987991333
 
Epoch 85: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.71it/s]
 
[085|02/05] ---> 0.7806061506271362
 
Epoch 86: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.91it/s]
 
[086|00/05] ---> 0.7816150188446045
 
Epoch 87: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.82it/s]
 
[087|01/05] -/-> 0.7792311906814575
 
Epoch 88: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.80it/s]
 
[088|01/05] ---> 0.7820054292678833
 
Epoch 89: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.93it/s]
 
[089|01/05] -/-> 0.7817076444625854
 
Epoch 90: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.86it/s]
 
[090|01/05] ---> 0.7827510833740234
 
Epoch 91: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.84it/s]
 
[091|01/05] -/-> 0.7805787324905396
 
Epoch 92: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.78it/s]
 
[092|02/05] -/-> 0.7811468839645386
 
Epoch 93: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.81it/s]
 
[093|02/05] ---> 0.7829802632331848
 
Epoch 94: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.93it/s]
 
[094|01/05] -/-> 0.7803215980529785
 
Epoch 95: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.74it/s]
 
[095|02/05] -/-> 0.7814676761627197
 
Epoch 96: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[096|03/05] -/-> 0.7824930548667908
 
Epoch 97: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.80it/s]
 
[097|04/05] -/-> 0.781731903553009
 
Epoch 98: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.91it/s]
 
[098|04/05] ---> 0.7839739322662354
 
Epoch 99: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.84it/s]
 
[099|01/05] -/-> 0.7836860418319702
 
Epoch 100: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.85it/s]
 
[100|02/05] -/-> 0.7818054556846619
 
Epoch 101: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[101|02/05] ---> 0.7861367464065552
 
Epoch 102: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.82it/s]
 
[102|00/05] ---> 0.7872404456138611
 
Epoch 103: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[103|01/05] -/-> 0.7854347229003906
 
Epoch 104: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.98it/s]
 
[104|02/05] -/-> 0.7844687104225159
 
Epoch 105: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.46it/s]
 
[105|03/05] -/-> 0.7845489382743835
 
Epoch 106: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.08it/s]
 
[106|04/05] -/-> 0.785740315914154
 
Epoch 107: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.78it/s]
 
[107|05/05] -/-> 0.7863855361938477
 
Epoch 108: 3%|██▉ | 2/70 [00:00<00:05, 11.46it/s]
 
Restoring best model after lr decay! 0.786386 ---> 0.787240
 
Epoch 108: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.80it/s]
 
[108|01/05] -/-> 0.7853833436965942
 
Epoch 109: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.26it/s]
 
[109|01/05] ---> 0.791121780872345
 
Epoch 110: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.67it/s]
 
[110|01/05] -/-> 0.7908748984336853
 
Epoch 111: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.49it/s]
 
[111|02/05] -/-> 0.7898320555686951
 
Epoch 112: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.02it/s]
 
[112|02/05] ---> 0.7914217710494995
 
Epoch 113: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.81it/s]
 
[113|01/05] -/-> 0.7900956273078918
 
Epoch 114: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.57it/s]
 
[114|02/05] -/-> 0.7912313342094421
 
Epoch 115: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.76it/s]
 
[115|02/05] ---> 0.7931042909622192
 
Epoch 116: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.78it/s]
 
[116|01/05] -/-> 0.7930849194526672
 
Epoch 117: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.77it/s]
 
[117|02/05] -/-> 0.7927504777908325
 
Epoch 118: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.75it/s]
 
[118|03/05] -/-> 0.7894423604011536
 
Epoch 119: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.87it/s]
 
[119|04/05] -/-> 0.7912067770957947
 
Epoch 120: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.84it/s]
 
[120|05/05] -/-> 0.7906427383422852
 
Epoch 121: 3%|██▉ | 2/70 [00:00<00:06, 10.94it/s]
 
Restoring best model after lr decay! 0.790643 ---> 0.793104
 
Epoch 121: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.77it/s]
 
[121|01/05] -/-> 0.7923586368560791
 
Epoch 122: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.92it/s]
 
[122|02/05] -/-> 0.7924504280090332
 
Epoch 123: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.94it/s]
 
[123|03/05] -/-> 0.7912676930427551
 
Epoch 124: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.75it/s]
 
[124|03/05] ---> 0.7938411235809326
 
Epoch 125: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.78it/s]
 
[125|01/05] -/-> 0.7937787175178528
 
Epoch 126: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.68it/s]
 
[126|02/05] -/-> 0.7923514246940613
 
Epoch 127: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.76it/s]
 
[127|03/05] -/-> 0.7913436889648438
 
Epoch 128: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.97it/s]
 
[128|03/05] ---> 0.794790506362915
 
Epoch 129: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 8.00it/s]
 
[129|01/05] -/-> 0.7940186262130737
 
Epoch 130: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.93it/s]
 
[130|02/05] -/-> 0.7917271852493286
 
Epoch 131: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.94it/s]
 
[131|03/05] -/-> 0.7902320623397827
 
Epoch 132: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.88it/s]
 
[132|04/05] -/-> 0.7938737869262695
 
Epoch 133: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00, 7.97it/s]
 
[133|05/05] -/-> 0.791316032409668
Restoring best model after lr decay! 0.791316 ---> 0.794791
Restoring best model! 0.794791 ---> 0.794791
Returning only test sampler with repeats...
-----------------------------------------
Model instance #0
Correlation (train set): 0.820
Correlation (validation set): 0.795
Correlation (test set): 0.754
-----------------------------------------
Fraction oracle (test set): 0.798
 
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.49it/s]
 
[001|00/05] ---> 0.18238994479179382
 
Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.62it/s]
 
[002|00/05] ---> 0.2311604768037796
 
Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.64it/s]
 
[003|00/05] ---> 0.2820581793785095
 
Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.73it/s]
 
[004|00/05] ---> 0.33178916573524475
 
Epoch 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00, 7.66it/s]
 
[005|00/05] ---> 0.3803391754627228
 
Epoch 6: 100%|███████████████████████████████████████████