Commit 7535b6d1 authored by Ines Filipa Fernandes Ramos's avatar Ines Filipa Fernandes Ramos
Browse files

notebook names updated

parent d708d573
This diff is collapsed.
%% Cell type:markdown id: tags:
# Demo Notebook on how to load the transfer core and train a model
%% Cell type:code id: tags:
``` python
%matplotlib inline
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import neuralpredictors as neur
from neuralpredictors.data.datasets import StaticImageSet, FileTreeDataset
```
%% Cell type:markdown id: tags:
# Build the dataloaders
%% Cell type:markdown id: tags:
The dataloaders object is a dictionary of 3 dictionaries: train, validation and test. Each of them contains the respective data from all datasets combined that were specified in paths. Here we only provide one dataset. While the responses are normalized, we exclude the input images from normalization. The following config was used in the paper (all arguments not in the config have the default value of the function).
%% Cell type:code id: tags:
``` python
from lurz2020.datasets.mouse_loaders import static_loaders
paths = [r'D:\inception_loop\original_code\lurz2020\static20457-5-9-preproc0']
dataset_config = {'paths': paths,
'batch_size': 64,
'seed': 1,
'cuda': True,
'normalize': False,
'include_eye_position': True,
'exclude': "images"}
dataloaders = static_loaders(**dataset_config)
dat = FileTreeDataset(r'D:\inception_loop\original_code\lurz2020\static20457-5-9-preproc0', "images", "responses")
```
%% Cell type:markdown id: tags:
### Look at the data
%% Cell type:code id: tags:
``` python
tier = 'train'
dataset_name = '20457-5-9-0'
images, responses, eye = [], [], []
for data_names in dataloaders[tier][dataset_name]:
images.append(data_names[0].squeeze().cpu().data.numpy())
responses.append(data_names[1].squeeze().cpu().data.numpy())
images = np.vstack(images)
responses = np.vstack(responses)
print('The \"{}\" set of dataset \"{}\" contains the responses of {} neurons to {} images'.format(tier, dataset_name, responses.shape[1], responses.shape[0]))
```
%%%% Output: stream
The "train" set of dataset "20457-5-9-0" contains the responses of 5335 neurons to 4472 images
%% Cell type:code id: tags:
``` python
# show some example images and the neural responses
n_images = 5
max_response = responses[:n_images].max()
for i in range(n_images):
fig, axs = plt.subplots(1, 2, figsize=(15,4))
axs[0].imshow(images[i])
axs[1].plot(responses[i])
axs[1].set_xlabel('neurons')
axs[1].set_ylabel('responses')
axs[1].set_ylim([0, max_response])
plt.show()
```
%%%% Output: display_data
![]()
%%%% Output: display_data
![]()
%%%% Output: display_data
![]()
%%%% Output: display_data
![]()
%%%% Output: display_data
![]()
%% Cell type:markdown id: tags:
# Build the model
%% Cell type:markdown id: tags:
If you want to load the transfer core later on, the arguments in the model config that concern the architecture of the model can not be changed. The following config was used in the paper (all arguments not in the config have the default value of the function).
%% Cell type:code id: tags:
``` python
from lurz2020.models.models import se2d_fullgaussian2d
model_config = {'init_mu_range': 0.55,
'init_sigma': 0.4,
'input_kern': 15,
'hidden_kern': 13,
'gamma_input': 1.0,
'grid_mean_predictor': {'type': 'cortex',
'input_dimensions': 2,
'hidden_layers': 0,
'hidden_features': 0,
'final_tanh': False},
'gamma_readout': 2.439}
model = se2d_fullgaussian2d(**model_config, dataloaders=dataloaders, seed=1)
```
%% Cell type:markdown id: tags:
## Load the weights of the transfer core
%% Cell type:markdown id: tags:
This will load the weights of the transfer core onto the model that you built above. The argument `strict=False` ensures that only matching keys are loaded. The readout keys are thus discarded.
%% Cell type:code id: tags:
``` python
transfer_model = torch.load('D://inception_loop/original_code/Lurz_2020_code/notebooks/models/transfer_model.pth.tar')
model.load_state_dict(transfer_model, strict=False)
```
%%%% Output: execute_result
_IncompatibleKeys(missing_keys=['shifter.mlp.0.weight', 'shifter.mlp.0.bias', 'shifter.mlp.2.weight', 'shifter.mlp.2.bias', 'shifter.mlp.4.weight', 'shifter.mlp.4.bias', 'readout.20457-5-9-0.sigma', 'readout.20457-5-9-0._features', 'readout.20457-5-9-0.bias', 'readout.20457-5-9-0.source_grid', 'readout.20457-5-9-0.mu_transform.0.weight', 'readout.20457-5-9-0.mu_transform.0.bias'], unexpected_keys=['readout.22564-2-12-0.sigma', 'readout.22564-2-12-0._features', 'readout.22564-2-12-0.bias', 'readout.22564-2-12-0.source_grid', 'readout.22564-2-12-0.mu_transform.0.weight', 'readout.22564-2-12-0.mu_transform.0.bias', 'readout.22564-2-13-0.sigma', 'readout.22564-2-13-0._features', 'readout.22564-2-13-0.bias', 'readout.22564-2-13-0.source_grid', 'readout.22564-2-13-0.mu_transform.0.weight', 'readout.22564-2-13-0.mu_transform.0.bias', 'readout.22564-3-8-0.sigma', 'readout.22564-3-8-0._features', 'readout.22564-3-8-0.bias', 'readout.22564-3-8-0.source_grid', 'readout.22564-3-8-0.mu_transform.0.weight', 'readout.22564-3-8-0.mu_transform.0.bias', 'readout.22564-3-12-0.sigma', 'readout.22564-3-12-0._features', 'readout.22564-3-12-0.bias', 'readout.22564-3-12-0.source_grid', 'readout.22564-3-12-0.mu_transform.0.weight', 'readout.22564-3-12-0.mu_transform.0.bias', 'readout.22846-2-19-0.sigma', 'readout.22846-2-19-0._features', 'readout.22846-2-19-0.bias', 'readout.22846-2-19-0.source_grid', 'readout.22846-2-19-0.mu_transform.0.weight', 'readout.22846-2-19-0.mu_transform.0.bias', 'readout.22846-2-21-0.sigma', 'readout.22846-2-21-0._features', 'readout.22846-2-21-0.bias', 'readout.22846-2-21-0.source_grid', 'readout.22846-2-21-0.mu_transform.0.weight', 'readout.22846-2-21-0.mu_transform.0.bias', 'readout.22846-10-16-0.sigma', 'readout.22846-10-16-0._features', 'readout.22846-10-16-0.bias', 'readout.22846-10-16-0.source_grid', 'readout.22846-10-16-0.mu_transform.0.weight', 'readout.22846-10-16-0.mu_transform.0.bias', 'readout.23343-5-17-0.sigma', 'readout.23343-5-17-0._features', 'readout.23343-5-17-0.bias', 'readout.23343-5-17-0.source_grid', 'readout.23343-5-17-0.mu_transform.0.weight', 'readout.23343-5-17-0.mu_transform.0.bias', 'readout.23555-4-20-0.sigma', 'readout.23555-4-20-0._features', 'readout.23555-4-20-0.bias', 'readout.23555-4-20-0.source_grid', 'readout.23555-4-20-0.mu_transform.0.weight', 'readout.23555-4-20-0.mu_transform.0.bias', 'readout.23555-5-12-0.sigma', 'readout.23555-5-12-0._features', 'readout.23555-5-12-0.bias', 'readout.23555-5-12-0.source_grid', 'readout.23555-5-12-0.mu_transform.0.weight', 'readout.23555-5-12-0.mu_transform.0.bias', 'readout.23656-14-22-0.sigma', 'readout.23656-14-22-0._features', 'readout.23656-14-22-0.bias', 'readout.23656-14-22-0.source_grid', 'readout.23656-14-22-0.mu_transform.0.weight', 'readout.23656-14-22-0.mu_transform.0.bias'])
%% Cell type:markdown id: tags:
# Build the trainer
%% Cell type:code id: tags:
``` python
from lurz2020.training.trainers import standard_trainer as trainer
# If you want to allow fine tuning of the core, set detach_core to False
detach_core=True
if detach_core:
print('Core is fixed and will not be fine-tuned')
else:
print('Core will be fine-tuned')
trainer_config = {'track_training': True,
'detach_core': detach_core}
```
%%%% Output: stream
Core is fixed and will not be fine-tuned
%% Cell type:markdown id: tags:
# Run training
%% Cell type:code id: tags:
``` python
model_state_before = model.state_dict()
```
%% Cell type:code id: tags:
``` python
score, output, model_state = trainer(model=model, dataloaders=dataloaders, seed=40, **trainer_config)
```
%%%% Output: stream
=======================================
correlation -0.0008181966
poisson_loss 2718038.5
%%%% Output: stream
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:05<00:00, 13.85it/s]
%%%% Output: stream
[001|00/05] ---> 0.19714295864105225
=======================================
correlation 0.19714296
poisson_loss 1220324.9
%%%% Output: stream
Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.14it/s]
%%%% Output: stream
[002|00/05] ---> 0.23619388043880463
=======================================
correlation 0.23619388
poisson_loss 919846.5
%%%% Output: stream
Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.36it/s]
%%%% Output: stream
[003|00/05] ---> 0.26307186484336853
=======================================
correlation 0.26307186
poisson_loss 727433.5
%%%% Output: stream
Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.30it/s]
%%%% Output: stream
[004|00/05] ---> 0.2799747884273529
=======================================
correlation 0.2799748
poisson_loss 608246.9
%%%% Output: stream
Epoch 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.56it/s]
%%%% Output: stream
[005|00/05] ---> 0.2896345555782318
=======================================
correlation 0.28963456
poisson_loss 535506.56
%%%% Output: stream
Epoch 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.44it/s]
%%%% Output: stream
[006|00/05] ---> 0.2954740822315216
=======================================
correlation 0.29547408
poisson_loss 492124.38
%%%% Output: stream
Epoch 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.49it/s]
%%%% Output: stream
[007|00/05] ---> 0.2983520030975342
=======================================
correlation 0.298352
poisson_loss 468442.75
%%%% Output: stream
Epoch 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.75it/s]
%%%% Output: stream
[008|00/05] ---> 0.2986760437488556
=======================================
correlation 0.29867604
poisson_loss 460164.12
%%%% Output: stream
Epoch 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.32it/s]
%%%% Output: stream
[009|00/05] ---> 0.2992229461669922
=======================================
correlation 0.29922295
poisson_loss 449152.44
%%%% Output: stream
Epoch 10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.55it/s]
%%%% Output: stream
[010|00/05] ---> 0.30142906308174133
=======================================
correlation 0.30142906
poisson_loss 434541.6
%%%% Output: stream
Epoch 11: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.32it/s]
%%%% Output: stream
[011|00/05] ---> 0.30265089869499207
=======================================
correlation 0.3026509
poisson_loss 422120.84
%%%% Output: stream
Epoch 12: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.58it/s]
%%%% Output: stream
[012|00/05] ---> 0.30281803011894226
=======================================
correlation 0.30281803
poisson_loss 416183.1
%%%% Output: stream
Epoch 13: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.37it/s]
%%%% Output: stream
[013|00/05] ---> 0.3038332462310791
=======================================
correlation 0.30383325
poisson_loss 412887.44
%%%% Output: stream
Epoch 14: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.37it/s]
%%%% Output: stream
[014|01/05] -/-> 0.30254703760147095
=======================================
correlation 0.30254704
poisson_loss 414130.44
%%%% Output: stream
Epoch 15: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.33it/s]
%%%% Output: stream
[015|02/05] -/-> 0.30169227719306946
=======================================
correlation 0.30169228
poisson_loss 418484.7
%%%% Output: stream
Epoch 16: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.32it/s]
%%%% Output: stream
[016|03/05] -/-> 0.303792268037796
=======================================
correlation 0.30379227
poisson_loss 406196.94
%%%% Output: stream
Epoch 17: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.50it/s]
%%%% Output: stream
[017|04/05] -/-> 0.30363255739212036
=======================================
correlation 0.30363256
poisson_loss 403754.62
%%%% Output: stream
Epoch 18: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.50it/s]
%%%% Output: stream
[018|05/05] -/-> 0.30293262004852295
Restoring best model after lr decay! 0.302933 ---> 0.303833
=======================================
correlation 0.30383325
poisson_loss 412887.44
%%%% Output: stream
Epoch 19: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.45it/s]
%%%% Output: stream
Epoch 19: reducing learning rate of group 0 to 1.5000e-03.
[019|01/05] -/-> 0.30327093601226807
=======================================
correlation 0.30327094
poisson_loss 419500.56
%%%% Output: stream
Epoch 20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.07it/s]
%%%% Output: stream
[020|02/05] -/-> 0.30352452397346497
=======================================
correlation 0.30352452
poisson_loss 408720.97
%%%% Output: stream
Epoch 21: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.37it/s]
%%%% Output: stream
[021|02/05] ---> 0.3041764199733734
=======================================
correlation 0.30417642
poisson_loss 403521.94
%%%% Output: stream
Epoch 22: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.49it/s]
%%%% Output: stream
[022|01/05] -/-> 0.3041190505027771
=======================================
correlation 0.30411905
poisson_loss 410190.44
%%%% Output: stream
Epoch 23: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.50it/s]
%%%% Output: stream
[023|02/05] -/-> 0.30390989780426025
=======================================
correlation 0.3039099
poisson_loss 403184.97
%%%% Output: stream
Epoch 24: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.38it/s]
%%%% Output: stream
[024|02/05] ---> 0.30429205298423767
=======================================
correlation 0.30429205
poisson_loss 407375.0
%%%% Output: stream
Epoch 25: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.44it/s]
%%%% Output: stream
[025|00/05] ---> 0.30436471104621887
=======================================
correlation 0.3043647
poisson_loss 402028.6
%%%% Output: stream
Epoch 26: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.14it/s]
%%%% Output: stream
[026|00/05] ---> 0.3048033118247986
=======================================
correlation 0.3048033
poisson_loss 402894.66
%%%% Output: stream
Epoch 27: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.61it/s]
%%%% Output: stream
[027|01/05] -/-> 0.30474141240119934
=======================================
correlation 0.3047414
poisson_loss 404821.2
%%%% Output: stream
Epoch 28: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.43it/s]
%%%% Output: stream
[028|02/05] -/-> 0.3047993779182434
=======================================
correlation 0.30479938
poisson_loss 399178.6
%%%% Output: stream
Epoch 29: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.53it/s]
%%%% Output: stream
[029|03/05] -/-> 0.30338940024375916
=======================================
correlation 0.3033894
poisson_loss 407109.62
%%%% Output: stream
Epoch 30: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.42it/s]
%%%% Output: stream
[030|04/05] -/-> 0.30467188358306885
=======================================
correlation 0.30467188
poisson_loss 406258.94
%%%% Output: stream
Epoch 31: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.43it/s]