Training EIANN networks

Note

🚧 Work in Progress: This documentation site is currently under construction. Content may change frequently.

Once a network has been constructed with EIANN, training it only requires loading a dataset and specifying a few additional parameters. For this example, we will train a conventional feedforward neural network to classify MNIST handwritten digits.

import torch
import torchvision
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt

import EIANN.EIANN as eiann
from EIANN import utils as ut
eiann.plot.update_plot_defaults()


1. Loading the dataset

We start by loading the dataset into a PyTorch “DataLoader” object. To facilitate both analysis and reproducibility, EIANN's training function expects each data sample to carry a unique index in addition to its features and label. In other words, each sample is a tuple of the form (index, input data, output target). Before creating the DataLoader objects, we repackage the MNIST dataset into this format, converting each label into a one-hot target vector. We use the first 50,000 of the 60,000 MNIST training images for training and hold out the remaining 10,000 for validation:

# Load dataset
tensor_flatten = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])
root_dir = ut.get_project_root()
MNIST_train_dataset = torchvision.datasets.MNIST(root=f"{root_dir}/EIANN/data/datasets/", train=True, download=True, transform=tensor_flatten)
MNIST_test_dataset = torchvision.datasets.MNIST(root=f"{root_dir}/EIANN/data/datasets/", train=False, download=True, transform=tensor_flatten)

# Add index to train & test data, converting each label to a one-hot target vector
one_hot = torch.eye(len(MNIST_train_dataset.classes))  # row k is the one-hot vector for class k

MNIST_train = []
for idx, (data, target) in enumerate(MNIST_train_dataset):
    MNIST_train.append((idx, data, one_hot[target]))

MNIST_test = []
for idx, (data, target) in enumerate(MNIST_test_dataset):
    MNIST_test.append((idx, data, one_hot[target]))

# Put data in DataLoader
data_generator = torch.Generator()
train_dataloader = torch.utils.data.DataLoader(MNIST_train[0:50000], batch_size=1, shuffle=True, generator=data_generator)
val_dataloader = torch.utils.data.DataLoader(MNIST_train[-10000:], batch_size=10000, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(MNIST_test, batch_size=10000, shuffle=False)
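Building the Python lists above materializes every sample in memory before training starts. A lazier alternative, sketched here under the assumption that EIANN only needs each item to be an `(index, data, target)` tuple, is a small map-style `torch.utils.data.Dataset` wrapper. The class name `IndexedOneHotDataset` is hypothetical and not part of EIANN; synthetic tensors stand in for MNIST so the sketch runs without downloading anything.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

class IndexedOneHotDataset(Dataset):
    """Hypothetical wrapper that yields (index, data, one_hot_target) tuples.

    Assumes the wrapped dataset returns (data, integer_label) pairs,
    as torchvision.datasets.MNIST does.
    """

    def __init__(self, base_dataset, num_classes):
        self.base_dataset = base_dataset
        # Build the identity matrix once; row k is the one-hot vector for class k
        self.one_hot = torch.eye(num_classes)

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        data, label = self.base_dataset[idx]
        return idx, data, self.one_hot[int(label)]

# Synthetic stand-in data: 4 flattened 28x28 "images" with 10 possible classes
images = torch.rand(4, 784)
labels = torch.tensor([3, 0, 9, 3])
indexed = IndexedOneHotDataset(TensorDataset(images, labels), num_classes=10)

# The default collate function stacks each tuple field into a batch tensor
loader = DataLoader(indexed, batch_size=2, shuffle=False)
batch_idx, batch_data, batch_target = next(iter(loader))
print(batch_idx, batch_data.shape, batch_target.shape)
```

To use it with MNIST, one could wrap the full torchvision dataset first and then slice it with `torch.utils.data.Subset`, for example `Subset(IndexedOneHotDataset(MNIST_train_dataset, 10), range(50000))`, so that every sample keeps its index in the full training set, just as in the list-based version.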

2. Build and train a simple feedforward neural network

Following the programmatic approach described on the previous page, we first set up the architecture of a simple artificial neural network (ANN) with two hidden layers.

network_config = eiann.NetworkBuilder()

# Define layers and populations
network_config.layer('Input').population('E', size=784)
network_config.layer('H1').population('E', size=500, activation='relu', bias=True)
network_config.layer('H2').population('E', size=500, activation='relu', bias=True)
network_config.layer('Output').population('E', size=10, activation='softmax', bias=True)

# Create connections between populations
network_config.connect(source='Input.E', target='H1.E')
network_config.connect(source='H1.E', target='H2.E')
network_config.connect(source='H2.E', target='Output.E')

# Set global learning rule
network_config.set_learning_rule('Backprop')

# Set training parameters
network_config.training(optimizer='Adam', learning_rate=0.0001)

network_config.print_architecture()

# Build the network
network_seed = 42 # Random seed for network initialization (for reproducibility)
network = network_config.build(seed=network_seed)
==================================================
Network Architecture:
Input.E (784) -> H1.E (500): Backprop
H1.E (500) -> H2.E (500): Backprop
H2.E (500) -> Output.E (10): Backprop
==================================================

Network(
  (criterion): MSELoss()
  (module_dict): ModuleDict(
    (H1E_InputE): Projection(in_features=784, out_features=500, bias=False)
    (H2E_H1E): Projection(in_features=500, out_features=500, bias=False)
    (OutputE_H2E): Projection(in_features=500, out_features=10, bias=False)
  )
  (parameter_dict): ParameterDict(
      (H1E_bias): Parameter containing: [torch.FloatTensor of size 500]
      (H2E_bias): Parameter containing: [torch.FloatTensor of size 500]
      (OutputE_bias): Parameter containing: [torch.FloatTensor of size 10]
  )
)
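The printout shows that this network pairs a softmax output layer with a mean-squared-error criterion (`MSELoss`) applied against the one-hot targets built in step 1. To make that loss concrete, here is a plain-Python sketch for a single sample with made-up output pre-activations. It assumes the criterion uses PyTorch's default `reduction='mean'`, i.e. it averages the squared error over the 10 output units; the helper names are illustrative only.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mse(prediction, target):
    # Mean over elements, as torch.nn.MSELoss does with reduction='mean'
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)

num_classes = 10
label = 7
# One-hot target, equivalent to torch.eye(10)[7]
target = [1.0 if k == label else 0.0 for k in range(num_classes)]

logits = [0.1, -0.3, 0.0, 0.2, 0.5, -1.0, 0.3, 2.0, 0.0, -0.5]  # made-up values
probs = softmax(logits)
loss = mse(probs, target)
print(round(sum(probs), 6), probs.index(max(probs)), round(loss, 4))
```

A perfect prediction (probability 1 on the correct class) gives a loss of 0. Because the squared distance between two probability vectors is at most 2, this per-sample loss can never exceed 2/10 = 0.2.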

Now we just need to specify a few training parameters and call the network's train method:

# Training parameters
# --------------------------
epochs = 1
train_steps = 20_000

# Train the network
# --------------------------
data_seed = 123 # Random seed for reproducibility. Ensures that the data is sampled in the same order each time.
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader, 
              epochs=epochs,
              samples_per_epoch=train_steps, 
              val_interval=(0, -1, 100), # Validation interval: (start, end, interval); Determines when to measure validation loss.
              status_bar=True)

# Plot training results
# --------------------------
eiann.plot.plot_loss_history(network)
eiann.plot.plot_accuracy_history(network)
eiann.plot.plot_error_history(network)
[Figures: loss history, accuracy history, and error history of the feedforward network during training]
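The val_interval argument is described only as a (start, end, interval) tuple. The helper below is a hypothetical sketch of how such a tuple could be expanded into the concrete training steps at which validation runs. It assumes Python-style negative indexing (so an end of -1 means the last training step) and an inclusive end; EIANN's exact convention may differ, and `validation_steps` is not an EIANN function.

```python
def validation_steps(val_interval, total_steps):
    """Hypothetical: expand (start, end, interval) into the list of validated steps.

    Assumes negative start/end values count back from the end of training
    (Python-style) and that the end step itself is included.
    """
    start, end, interval = val_interval
    if start < 0:
        start += total_steps
    if end < 0:
        end += total_steps
    return list(range(start, end + 1, interval))

steps = validation_steps((0, -1, 100), total_steps=20_000)
print(len(steps), steps[:3], steps[-1])  # 200 [0, 100, 200] 19900
```

Under these assumptions, validating every 100 steps over 20,000 training steps yields 200 points on the validation curves, while a setting such as (100, -1, 200) would yield 100.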

We can save the trained network to a pickle file for later analysis or to checkpoint training:

network_name = "example_feedforward_ANN"
ut.save_network(network, path= f"{root_dir}/EIANN/saved_networks/mnist/{network_name}.pkl", overwrite=True)
WARNING: File '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/saved_networks/mnist/example_feedforward_ANN.pkl' already exists. Overwriting...
Saved network to '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/saved_networks/mnist/example_feedforward_ANN.pkl'

At a later time, we can then load our trained model and either continue training it or use it to analyze the learned representations and internal structure of the network.

# Load network object from pickle file
network = ut.load_network(path= f"{root_dir}/EIANN/saved_networks/mnist/{network_name}.pkl")
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/saved_networks/mnist/example_feedforward_ANN.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/saved_networks/mnist/example_feedforward_ANN.pkl'
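The internals of ut.save_network and ut.load_network are not shown in this tutorial. For checkpointing other objects the same way, here is a generic sketch that uses only the standard library's `pickle` module and mimics the overwrite warning seen in the log above. The helper names `save_pickle` and `load_pickle` are hypothetical, not EIANN functions.

```python
import os
import pickle
import tempfile

def save_pickle(obj, path, overwrite=False):
    """Hypothetical helper: pickle `obj` to `path`, refusing to overwrite unless asked."""
    if os.path.exists(path):
        if not overwrite:
            raise FileExistsError(f"File '{path}' already exists")
        print(f"WARNING: File '{path}' already exists. Overwriting...")
    # Create parent directories (e.g. saved_networks/mnist/) if they are missing
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def load_pickle(path):
    """Hypothetical helper: load an object previously saved with save_pickle."""
    with open(path, "rb") as f:
        return pickle.load(f)

# Round trip in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mnist", "checkpoint.pkl")
    save_pickle({"step": 20_000, "lr": 1e-4}, path)
    save_pickle({"step": 20_000, "lr": 1e-4}, path, overwrite=True)  # prints the warning
    restored = load_pickle(path)
print(restored)
```

Unpickling can execute arbitrary code embedded in the file, so only load saved networks from sources you trust.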

3. Build and train a recurrent network with separate E and I cells

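The same workflow also applies to biologically constrained networks that obey Dale’s law (each population is either purely excitatory or purely inhibitory) and have recurrent dynamics between their excitatory (E) and inhibitory (I) populations. Here, each layer has an excitatory population E and an inhibitory population SomaI, which receives excitation from E and inhibits it in turn.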

network_config = eiann.NetworkBuilder()

# Define layers and populations
network_config.layer('Input').population('E', 784)
network_config.layer('H1').population('E', 500, 'relu').population('SomaI', 50, 'relu', tau=2)
network_config.layer('H2').population('E', 500, 'relu').population('SomaI', 50, 'relu', tau=2)
network_config.layer('Output').population('E', 10, 'relu').population('SomaI', 10, 'relu', tau=2)

# Create connections between populations
network_config.connect(source='Input.E', target='H1.E').type('Exc')
network_config.connect(source='H1.E', target='H1.SomaI').type('Exc',init_scale=0.2).direction('F') 
network_config.connect(source='H1.SomaI', target='H1.E').type('Inh').direction('R')

network_config.connect(source='H1.E', target='H2.E').type('Exc')
network_config.connect(source='H2.E', target='H2.SomaI').type('Exc',init_scale=0.1).direction('F')
network_config.connect(source='H2.SomaI', target='H2.E').type('Inh').direction('R')

network_config.connect(source='H2.E', target='Output.E').type('Exc') 
network_config.connect(source='Output.E', target='Output.SomaI').type('Exc', init_scale=0.08).direction('F')
network_config.connect(source='Output.SomaI', target='Output.E').type('Inh').direction('R')

# Set global learning rule
network_config.set_learning_rule('Backprop')

# Set training parameters
network_config.training(optimizer='Adam',
                        tau=3,
                        forward_steps=18,
                        backward_steps=3,
                        learning_rate=0.0001)

network_config.print_architecture()

# Build the network
network_seed = 42 # Random seed for network initialization (for reproducibility)
network = network_config.build(seed=network_seed)
==================================================
Network Architecture:
Input.E (784) -> H1.E (500) [Exc]: Backprop
H1.SomaI (50) -> H1.E (500) [Inh]: Backprop
H1.E (500) -> H1.SomaI (50) [Exc]: Backprop
H1.E (500) -> H2.E (500) [Exc]: Backprop
H2.SomaI (50) -> H2.E (500) [Inh]: Backprop
H2.E (500) -> H2.SomaI (50) [Exc]: Backprop
H2.E (500) -> Output.E (10) [Exc]: Backprop
Output.SomaI (10) -> Output.E (10) [Inh]: Backprop
Output.E (10) -> Output.SomaI (10) [Exc]: Backprop
==================================================

Network(
  (criterion): MSELoss()
  (module_dict): ModuleDict(
    (H1E_InputE): Projection(in_features=784, out_features=500, bias=False)
    (H1E_H1SomaI): Projection(in_features=50, out_features=500, bias=False)
    (H1SomaI_H1E): Projection(in_features=500, out_features=50, bias=False)
    (H2E_H1E): Projection(in_features=500, out_features=500, bias=False)
    (H2E_H2SomaI): Projection(in_features=50, out_features=500, bias=False)
    (H2SomaI_H2E): Projection(in_features=500, out_features=50, bias=False)
    (OutputE_H2E): Projection(in_features=500, out_features=10, bias=False)
    (OutputE_OutputSomaI): Projection(in_features=10, out_features=10, bias=False)
    (OutputSomaI_OutputE): Projection(in_features=10, out_features=10, bias=False)
  )
  (parameter_dict): ParameterDict(
      (H1E_bias): Parameter containing: [torch.FloatTensor of size 500]
      (H1SomaI_bias): Parameter containing: [torch.FloatTensor of size 50]
      (H2E_bias): Parameter containing: [torch.FloatTensor of size 500]
      (H2SomaI_bias): Parameter containing: [torch.FloatTensor of size 50]
      (OutputE_bias): Parameter containing: [torch.FloatTensor of size 10]
      (OutputSomaI_bias): Parameter containing: [torch.FloatTensor of size 10]
  )
)
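The tau, forward_steps, and Exc/Inh settings above describe a network whose activity evolves over several time steps rather than in a single pass. This tutorial does not spell out EIANN's update equations, so the following is only a generic illustration of one common formulation, not EIANN's implementation: firing rates follow Euler-discretized leaky dynamics, r ← r + (ReLU(input) − r)/τ, and Dale's law is enforced by storing nonnegative weight magnitudes and fixing the sign per projection type. The time constants (3 and 2) and the 18 steps merely echo the config values; how EIANN interprets them is an assumption here, and layer sizes are shrunk for readability.

```python
import random

def relu(v):
    return max(0.0, v)

def matvec(weights, rates):
    """Multiply an (n_out x n_in) nested list by a length-n_in vector."""
    return [sum(w * r for w, r in zip(row, rates)) for row in weights]

def random_magnitudes(n_out, n_in, scale, rng):
    # Dale's law: all weights in a projection share one sign, so we store
    # nonnegative magnitudes and apply the sign when computing the input.
    return [[scale * rng.random() for _ in range(n_in)] for _ in range(n_out)]

def simulate_ei_layer(x, w_in, w_ie, w_ei, tau_e=3.0, tau_i=2.0, steps=18):
    """Hypothetical sketch: one E/I layer settling over `steps` Euler updates.

    w_in: Input -> E (Exc), w_ie: E -> I (Exc), w_ei: I -> E (Inh).
    """
    rates_e = [0.0] * len(w_in)
    rates_i = [0.0] * len(w_ie)
    for _ in range(steps):
        # Excitatory input adds; the inhibitory projection subtracts
        drive_e = [a - b for a, b in zip(matvec(w_in, x), matvec(w_ei, rates_i))]
        drive_i = matvec(w_ie, rates_e)
        # Leaky integration toward the rectified drive; both use the previous state
        rates_e, rates_i = (
            [r + (relu(d) - r) / tau_e for r, d in zip(rates_e, drive_e)],
            [r + (relu(d) - r) / tau_i for r, d in zip(rates_i, drive_i)],
        )
    return rates_e, rates_i

rng = random.Random(0)
x = [rng.random() for _ in range(8)]         # 8 input units
w_in = random_magnitudes(6, 8, 0.3, rng)     # Input.E -> E (6 E units)
w_ie = random_magnitudes(2, 6, 0.2, rng)     # E -> SomaI (2 I units)
w_ei = random_magnitudes(6, 2, 0.5, rng)     # SomaI -> E, applied with a minus sign
e_rates, i_rates = simulate_ei_layer(x, w_in, w_ie, w_ei)
print([round(r, 3) for r in e_rates], [round(r, 3) for r in i_rates])
```

Because the inhibitory term can only subtract, the E rates in this sketch are never higher than they would be with the SomaI population silenced, and the rectification keeps all rates nonnegative.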

We then train the E/I network in the same way as before:

# Training parameters
# --------------------------
epochs = 1
train_steps = 20_000

# Train the network
# --------------------------
data_seed = 123 # Random seed for reproducibility. Ensures that the data is sampled in the same order each time.
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader, 
              epochs=epochs,
              samples_per_epoch=train_steps, 
              val_interval=(100, -1, 200), # Validation interval: (start, end, interval); Determines when to measure validation loss.
              store_params=True, # This takes a lot of memory and is only needed for analyzing weight dynamics/loss landscape during training
              status_bar=True)

# Plot training results
# --------------------------
eiann.plot.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot.plot_accuracy_history(network)
eiann.plot.plot_error_history(network)
[Figures: loss history, accuracy history, and error history of the E/I network during training]

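The comment on store_params=True warns that storing parameters takes a lot of memory. A back-of-the-envelope estimate makes this concrete. Using the projection and bias shapes printed for the E/I network, and assuming each parameter is a 32-bit float (the printout lists torch.FloatTensor), one full snapshot of the parameters takes about 3 MB. How many snapshots EIANN actually keeps is not stated here, so the sketch computes the total for two hypothetical counts: 100 snapshots, and one per training step (20,000).

```python
# Weight shapes (in_features, out_features) from the E/I network printout
projections = [
    (784, 500), (50, 500), (500, 50),   # H1: Input->E, SomaI->E, E->SomaI
    (500, 500), (50, 500), (500, 50),   # H2
    (500, 10), (10, 10), (10, 10),      # Output
]
# Bias sizes from the ParameterDict
bias_sizes = [500, 50, 500, 50, 10, 10]

n_params = sum(a * b for a, b in projections) + sum(bias_sizes)
bytes_per_snapshot = 4 * n_params  # float32 = 4 bytes per parameter

# Hypothetical snapshot counts (not EIANN's documented behavior)
for n_snapshots in (100, 20_000):
    total_gb = n_snapshots * bytes_per_snapshot / 1e9
    print(f"{n_snapshots:>6} snapshots: {total_gb:.2f} GB")
print(f"{n_params} parameters per snapshot")
```

At 748,320 parameters, 100 snapshots come to roughly 0.3 GB, while one snapshot per step would need roughly 60 GB, which is why store_params is best left off unless weight dynamics are being analyzed.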
As before, we save the trained network to a pickle file:

network_name = "example_EI_network"
ut.save_network(network, path= f"{root_dir}/EIANN/saved_networks/mnist/{network_name}.pkl", overwrite=True)
Saved network to '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/saved_networks/mnist/example_EI_network.pkl'

In the next tutorial, we will take a more in-depth look at how to analyze the internal structure of a trained EIANN network.