EI network - Backprop (Dale’s Law, fixed inhibition)

Here we show the analysis of a single example seed of a network described in Galloni et al. 2025.
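The title names two constraints. The sketch below shows the general idea under one common scheme. It is an assumption for illustration, not EIANN's actual implementation. Synaptic weights are stored as non-negative magnitudes, and inhibitory inputs enter the net input with a negative sign. After each gradient step, excitatory weights are projected back onto w >= 0. "Fixed inhibition" then means the inhibitory weights are never updated. The helpers `sgd_step_dale` and `net_input` are hypothetical.

```python
def sgd_step_dale(w_exc, w_inh, grad_exc, lr):
    """Hypothetical helper: one SGD step under Dale's law with fixed inhibition.

    w_exc: magnitudes of weights from excitatory inputs (kept >= 0)
    w_inh: magnitudes of weights from inhibitory inputs (never updated)
    """
    # Gradient step on the excitatory weights, then project back onto w >= 0
    new_exc = [max(0.0, w - lr * g) for w, g in zip(w_exc, grad_exc)]
    # Inhibitory weights are returned unchanged ("fixed inhibition")
    return new_exc, list(w_inh)


def net_input(x_exc, x_inh, w_exc, w_inh):
    """Excitatory inputs add to the net input; inhibitory inputs subtract from it."""
    excitation = sum(w * x for w, x in zip(w_exc, x_exc))
    inhibition = sum(w * x for w, x in zip(w_inh, x_inh))
    return excitation - inhibition


w_exc, w_inh = [0.5, 0.1], [0.3]
w_exc, w_inh = sgd_step_dale(w_exc, w_inh, grad_exc=[1.0, 2.0], lr=0.1)
# The second excitatory weight is clipped at 0 instead of turning negative
print(w_exc, w_inh)
print(net_input([1.0, 1.0], [1.0], w_exc, w_inh))
```

Projection after each update is only one way to enforce sign constraints. Other schemes, for example reparameterizing weights through a non-negative function, are also used in the literature.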

import EIANN as eiann
from EIANN import utils as ut
eiann.update_plot_defaults()
root_dir = ut.get_project_root()

1. Load MNIST data

train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()

2. Load optimized pre-trained EIANN model: EI network trained with Backprop (fixed Soma-I weights)

network_name = "20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized"   
network_seed = 66049
data_seed = 257
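The two seeds play separate roles in the training code below. `network_seed` is passed to `ut.build_EIANN_from_config`, and `data_seed` seeds `data_generator` before training. The sketch below uses Python's standard `random` module as a stand-in, with hypothetical functions `init_weights` and `sample_order`. It shows why keeping the seeds separate is useful: the same `data_seed` reproduces the same sample order regardless of `network_seed`, and vice versa.

```python
import random


def init_weights(network_seed, n=5):
    # Hypothetical stand-in for weight initialization controlled by network_seed
    rng = random.Random(network_seed)
    return [rng.uniform(-1.0, 1.0) for _ in range(n)]


def sample_order(data_seed, n_samples=10):
    # Hypothetical stand-in for the training-sample order controlled by data_seed
    rng = random.Random(data_seed)
    order = list(range(n_samples))
    rng.shuffle(order)
    return order


network_seed, data_seed = 66049, 257
# Same pattern as the pickle file name and network.seed used in this notebook
run_id = f"{network_seed}_{data_seed}"

# Same seed -> identical result; each seed only controls its own generator
assert init_weights(network_seed) == init_weights(network_seed)
assert sample_order(data_seed) == sample_order(data_seed)
print(run_id, sample_order(data_seed))
```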

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object
config_file_path = f"{root_dir}/EIANN/network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader,
              epochs=1,
              samples_per_epoch=20_000,
              val_interval=(0, -1, 100),
              store_history=True,
              store_history_interval=(0, -1, 100),
              store_dynamics=False,
              store_params=True,
              status_bar=True)

# Optional: Save network object to pickle file
saved_network_path = f"{root_dir}/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)
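The interval arguments `val_interval` and `store_history_interval` are passed as 3-tuples. One plausible reading is `(start, stop, step)` over training steps, with `-1` standing for the final step. This is an assumption, not documented behavior. The hypothetical helper `interval_to_steps` expands such a tuple under that reading. It also assumes one training step per sample, giving 20,000 steps for this run.

```python
def interval_to_steps(interval, n_steps):
    """Hypothetical helper: expand a (start, stop, step) tuple into step indices.

    Assumes stop=-1 means 'the final step' and that the final step is always
    included; this is an illustrative convention, not documented EIANN behavior.
    """
    start, stop, step = interval
    if stop < 0:
        stop = n_steps + stop  # e.g. -1 -> n_steps - 1
    steps = list(range(start, stop + 1, step))
    if steps[-1] != stop:
        steps.append(stop)  # always include the final step
    return steps


# Assuming one training step per sample: 1 epoch x 20_000 samples
steps = interval_to_steps((0, -1, 100), n_steps=20_000)
print(len(steps), steps[:3], steps[-2:])
```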

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

saved_network_path = f"{root_dir}/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
network.name = network_name
network.seed = f"{network_seed}_{data_seed}"
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized_66049_257.pkl'
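Because the trained model is stored as a pickle file, saving and loading amounts to serializing the whole network object. The round trip below is a minimal sketch using the standard `pickle` module, with a `types.SimpleNamespace` standing in for the network. It assumes, without confirming, that `ut.save_network` and `ut.load_network` work along these lines. The local `save_network` and `load_network` defined here are hypothetical.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


def save_network(network, path):
    # Serialize the whole object, including weights and any other attributes
    with open(path, "wb") as f:
        pickle.dump(network, f)


def load_network(path):
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical stand-in for a trained network object
net = SimpleNamespace(weights=[[0.1, 0.2], [0.3, 0.4]])

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example_66049_257.pkl")
    save_network(net, path)
    restored = load_network(path)

# Attributes can be (re)assigned after loading, as in the cell above
restored.name = "example"
restored.seed = "66049_257"
print(restored.weights == net.weights)
```

`pickle.load` can execute arbitrary code, so only load pickle files from sources you trust.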

3. Training results

eiann.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot_accuracy_history(network)
eiann.plot_error_history(network)
eiann.plot_loss_landscape(test_dataloader, network, num_points=20, extension=0.9, vmax_scale=1.2, scale='log')
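A common way to draw a 2D loss landscape is to evaluate the loss on a grid of parameter vectors θ* + a·d1 + b·d2. Here θ* is the trained parameter vector, and d1 and d2 are two chosen directions. The sketch below does this for a toy two-parameter quadratic loss. It assumes `num_points` is the grid resolution and `extension` is the half-width of the grid along each direction. That is a guess about `eiann.plot_loss_landscape`, whose choice of directions is not shown in this notebook. `loss_landscape_2d` and `toy_loss` are hypothetical.

```python
def loss_landscape_2d(loss_fn, theta_star, d1, d2, num_points=20, extension=0.9):
    """Hypothetical sketch: evaluate loss_fn on a num_points x num_points grid of
    parameter vectors theta_star + a * d1 + b * d2, with a, b in [-extension, extension]."""
    coords = [-extension + 2 * extension * i / (num_points - 1) for i in range(num_points)]
    grid = []
    for b in coords:
        row = []
        for a in coords:
            theta = [t + a * u + b * v for t, u, v in zip(theta_star, d1, d2)]
            row.append(loss_fn(theta))
        grid.append(row)
    return coords, grid


def toy_loss(theta):
    # Elongated quadratic bowl with its minimum at (1, -2)
    return (theta[0] - 1.0) ** 2 + 3.0 * (theta[1] + 2.0) ** 2


coords, grid = loss_landscape_2d(toy_loss, theta_star=[1.0, -2.0], d1=[1.0, 0.0], d2=[0.0, 1.0])
print(len(grid), len(grid[0]), min(min(row) for row in grid))
```

With `scale='log'`, such a grid would presumably be drawn on a logarithmic color scale. That keeps the low-loss basin visible next to much larger losses at the grid edges.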

4. Analyze population activities

eiann.plot_batch_accuracy(network, test_dataloader, population='E')
Batch accuracy = 95.36000061035156%
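Batch accuracy can presumably be read as the percentage of test images for which the most active output unit matches the label. This is an assumption about how `eiann.plot_batch_accuracy` computes it. The hypothetical `batch_accuracy` below implements that definition for a list of output vectors. As a side note, the trailing digits in `95.36000061035156%` match the single-precision (float32) representation of 95.36. On the 10,000-image MNIST test set, that would correspond to 9,536 correctly classified images.

```python
def batch_accuracy(outputs, labels):
    """Percentage of samples whose most active output unit matches the label."""
    correct = 0
    for out, label in zip(outputs, labels):
        predicted = max(range(len(out)), key=out.__getitem__)  # argmax over output units
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)


# Toy example: 4 samples, 3 output units (values are made up)
outputs = [[0.1, 0.8, 0.1],
           [0.7, 0.2, 0.1],
           [0.2, 0.3, 0.5],
           [0.6, 0.3, 0.1]]
labels = [1, 0, 2, 1]
print(batch_accuracy(outputs, labels))
```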
# Evaluate the network dynamics by presenting the test dataset and recording the activity of every neuron
pop_dynamics_dict = ut.compute_test_activity_dynamics(network, test_dataloader, plot=True, normalize=True)
# Because the dynamics are stored here, each population's activity has shape (timesteps, data samples, neurons)
print(pop_dynamics_dict['H1E'].shape)
torch.Size([15, 10000, 500])
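Each entry of `pop_dynamics_dict` is indexed as (timesteps, data samples, neurons), so summaries come from reducing over the appropriate axes. The sketch below uses a random NumPy array with smaller dimensions as a stand-in for the real (15, 10000, 500) H1E tensor. A PyTorch tensor can be reduced the same way with `.mean(dim=...)`, or converted with `.numpy()` first.

```python
import numpy as np

# Stand-in for one population's dynamics, shaped (timesteps, data samples, neurons)
rng = np.random.default_rng(0)
dynamics = rng.random((15, 100, 50))

# Population-average activity at each timestep, shape (15,)
mean_trace = dynamics.mean(axis=(1, 2))
# Activity at the last timestep, shape (100, 50)
final_activity = dynamics[-1]
# Average over neurons for each sample, shape (15, 100)
per_sample = dynamics.mean(axis=2)
# Largest change between consecutive timesteps, shape (14,);
# small values near the end would suggest the dynamics have settled
step_change = np.abs(np.diff(dynamics, axis=0)).max(axis=(1, 2))

print(mean_trace.shape, final_activity.shape, per_sample.shape, step_change.shape)
```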

5. Analyze learned representations

pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=False, sort=True)
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = ut.compute_representational_similarity_matrix(pop_activity_dict, pattern_labels, unit_labels_dict, population='E', plot=True)
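`compute_representational_similarity_matrix` returns two kinds of matrices. A pattern similarity matrix compares the population activity vectors evoked by different input patterns. A neuron similarity matrix compares neurons by their responses across patterns. The sketch below computes both with cosine similarity. That is an assumption, since the library might use correlation or another measure. `cosine_similarity_matrix` is hypothetical. With `sort=True`, patterns are presumably ordered by class label, so same-class blocks appear along the diagonal.

```python
import numpy as np


def cosine_similarity_matrix(x):
    """Pairwise cosine similarity between the rows of x (zero-norm rows get similarity 0)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    unit = np.divide(x, norms, out=np.zeros_like(x, dtype=float), where=norms > 0)
    return unit @ unit.T


# Toy activity matrix, shape (patterns, neurons); values are made up
activity = np.array([[1.0, 0.0, 2.0],
                     [2.0, 0.0, 4.0],
                     [0.0, 3.0, 0.0]])
pattern_sim = cosine_similarity_matrix(activity)   # (patterns, patterns)
neuron_sim = cosine_similarity_matrix(activity.T)  # (neurons, neurons)
print(np.round(pattern_sim, 3))
```

Silent ReLU units have zero-norm activity vectors. Assigning them similarity 0 avoids division by zero.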

5.1 Hidden Layer 1

receptive_fields_H1 = ut.compute_maxact_receptive_fields(network.H1.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_H1, sort=True)
Optimizing receptive field images H1E...
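The log line "Optimizing receptive field images" suggests that receptive fields are obtained by optimizing input images. A generic version of this idea, activation maximization, is sketched below for a single ReLU unit with weights `w` and bias `b`. It runs gradient ascent on the input with a small L2 penalty and clips pixels to [0, 1]. The objective, regularizer, and optimizer that `ut.compute_maxact_receptive_fields` actually uses are not shown in this notebook. Treat this as an illustration under those assumptions; `maximize_activation` is hypothetical.

```python
import numpy as np


def maximize_activation(w, b, steps=200, lr=0.1, l2=0.01, x0=None):
    """Hypothetical sketch: gradient ascent on the input x to maximize
    relu(w @ x + b) - l2 * ||x||^2, keeping pixel values in [0, 1]."""
    x = np.full_like(w, 0.5) if x0 is None else x0.astype(float)
    for _ in range(steps):
        active = (w @ x + b) > 0
        # ReLU gradient is w when the unit is active, 0 otherwise; plus L2 penalty gradient
        grad = (w if active else np.zeros_like(w)) - 2.0 * l2 * x
        x = np.clip(x + lr * grad, 0.0, 1.0)
    return x


w = np.array([1.0, -1.0, 0.5, 0.0])  # toy weights for a 4-pixel "image"
x_opt = maximize_activation(w, b=0.0)
print(np.round(x_opt, 3))
```

Pixels with positive weight saturate at 1, and pixels with negative weight are driven to 0. Pixels that do not affect the unit slowly decay under the L2 penalty. For a hidden unit deep in a network, the input gradient would come from backpropagation through the earlier layers rather than from `w` directly.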
metrics_dict = ut.compute_representation_metrics(population=network.H1.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H1, plot=True)

5.2 Hidden Layer 2

receptive_fields_H2 = ut.compute_maxact_receptive_fields(network.H2.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_H2, sort=True)
Optimizing receptive field images H2E...
metrics_dict = ut.compute_representation_metrics(population=network.H2.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H2, plot=True)

5.3 Output layer

receptive_fields_output = ut.compute_maxact_receptive_fields(network.Output.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_output, sort=False)
Optimizing receptive field images OutputE...