Commit b987a0c7 authored by Ines Filipa Fernandes Ramos's avatar Ines Filipa Fernandes Ramos
Browse files

commit

parents 12b54861 f38591a3
This source diff could not be displayed because it is too large. You can view the blob instead.
# inception_loop_asari_lab
Repository of code developed to implement inception loops ([Walker et al. 2019 Nature Neuro](https://www.nature.com/articles/s41593-019-0517-x)) and generate MEIs used by the Asari Lab.
%% Cell type:markdown id: tags:
# Simple RGCs simulation
%% Cell type:code id: tags:
``` python
import numpy as np, array
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy import signal
import matplotlib.image as mpimg
import matplotlib.cm as cm
import math
from mpl_toolkits.mplot3d import axes3d
import torch
from collections import OrderedDict
import neuralpredictors as neur
from neuralpredictors.data.datasets import StaticImageSet, FileTreeDataset
from numpy import save
```
%% Cell type:code id: tags:
``` python
#Save notebook session
#dill.dump_session('Natural_images_dataset_session_09122020.db')
```
%% Cell type:code id: tags:
``` python
#Restore notebook session
#dill.load_session('Natural_images_dataset_session_09122020.db')
```
%% Cell type:markdown id: tags:
##### Simple RGCs simulation:
%% Cell type:code id: tags:
``` python
_default_2Dgaussian_p = (1,1,1,0,0,0,0)
def gaussian_2D(xz, sigma_x, sigma_z, amp, theta, x0, z0, y0):
"""Two dimensional Gaussian function
params:
- xz: meshgrid of x and z coordinates at which to evaluate the points
- sigma_x: width of the gaussian
- sigma_z: height of the gaussian
- amp: amplitude of the gaussian
- theta: angle of the gaussian (in radian)
- x0: shift in x of the gaussian
- z0: shift in z of the gaussian
- y0: shift in y of the gaussian
"""
(x,z) = xz
x0, z0 = float(x0), float(z0)
a = (np.cos(theta)**2)/(2*sigma_x**2) + (np.sin(theta)**2)/(2*sigma_z**2)
b = -(np.sin(2*theta)) /(4*sigma_x**2) + (np.sin(2*theta)) /(4*sigma_z**2)
c = (np.sin(theta)**2)/(2*sigma_x**2) + (np.cos(theta)**2)/(2*sigma_z**2)
g = amp * np.exp( -(a*((x-x0)**2) + 2*b*(x-x0)*(z-z0) + c*((z-z0)**2))) + y0
return g.ravel()
def mexicanHat(xz, sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1,
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0):
"""Sum of two 2D Gaussian function. For the params, see `gaussian_2D`.
However, both share the y0 parameter."""
return (gaussian_2D(xz, sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1, 0)
+ gaussian_2D(xz, sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, 0) + y0)
def ELU(r):
if r>0:
return r+1
else:
return np.exp(r) + 1
def RF(vis_field_width, vis_field_height, x_rf_center, z_rf_center, polarity, plot=False):
def RF(vis_field_width, vis_field_height, rf_theta, x_rf_center, z_rf_center, polarity, plot=False):
x,y = np.meshgrid(np.linspace(0,vis_field_width,vis_field_width),np.linspace(0,vis_field_height,vis_field_height))
if polarity==1:
sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1 = 2, 2, 1, 0, x_rf_center, z_rf_center
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0 = 3, 3, -0.5, 0, x_rf_center, z_rf_center, 0
sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1 = 1, 2, 1, rf_theta, x_rf_center, z_rf_center
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0 = 2, 3, -0.5, rf_theta, x_rf_center, z_rf_center, 0
else:
sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1 = 2, 2, -1, 0, x_rf_center, z_rf_center
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0 = 3, 3, 0.5, 0, x_rf_center, z_rf_center, 0
sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1 = 1, 2, -1, rf_theta, x_rf_center, z_rf_center
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0 = 2, 3, 0.5, rf_theta, x_rf_center, z_rf_center, 0
z = mexicanHat((x,y), sigma_x_1, sigma_z_1, amp_1, theta_1, x0_1, z0_1,
sigma_x_2, sigma_z_2, amp_2, theta_2, x0_2, z0_2, y0).reshape(vis_field_height,vis_field_width)
if plot==True:
fig = plt.figure(figsize=(5,4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, rstride=3, cstride=3, label=f"x_rf_center={x_rf_center} z_rf_center={z_rf_center} \n amp_center={amp_1} amp_surround={amp_2}")
_ = ax.legend()
return z
def RGC_response(rf, image, plot=False, seed=None):
Img_barHat = image * rf
if plot==True:
fig, ax = plt.subplots(3, figsize=(7,7))
ax[0].imshow(image, cmap='gray')
ax[0].imshow(image)
ax[0].set_title("Image")
ax[1].imshow(rf, vmin=-1, vmax=1, cmap="gray")
ax[1].imshow(rf)
ax[1].set_title("RGC RF")
ax[2].imshow(Img_barHat, cmap=cm.Greys_r)
ax[2].imshow(Img_barHat)
ax[2].set_title("RGC Response")
plt.tight_layout()
if seed is not None:
np.random.seed(seed)
g = ELU(sum(Img_barHat.ravel()))
spikes = np.random.poisson(lam=g, size=None)
return spikes
```
%% Cell type:code id: tags:
``` python
#Generate the receptive field of one RGC
rf = RF(64, 36, 50, 30, -1, plot=True)
rf = RF(vis_field_width = 64, vis_field_height = 36, rf_theta = 1.2, x_rf_center = 30, z_rf_center = 30, polarity =1, plot=True)
```
%% Output