EI network - Backprop (Dale’s Law, learned inhibition)

Here we show the analysis for a single example seed of the network described in Galloni et al. (2025).
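This network obeys Dale's law: each unit is either excitatory or inhibitory, so all of its outgoing weights share one sign, while the inhibitory weights themselves are still learned. One common way to enforce this during gradient descent is to store every weight as a non-negative magnitude, apply inhibitory inputs with a minus sign, and clip any weight that an update would push below zero. The plain-Python sketch below illustrates that general technique under those assumptions. It is not a description of how EIANN implements its Dale's-law constraint, and `sgd_step_with_dale` and `net_input` are hypothetical names.

```python
def sgd_step_with_dale(weights, grads, lr):
    """One SGD step on non-negative weight magnitudes, clipping at zero so no weight flips sign."""
    return [
        [max(0.0, w - lr * g) for w, g in zip(w_row, g_row)]
        for w_row, g_row in zip(weights, grads)
    ]

def net_input(x_exc, w_exc, x_inh, w_inh):
    """Postsynaptic drive: excitatory input minus inhibitory input (all weights >= 0)."""
    excitation = sum(w * x for w, x in zip(w_exc, x_exc))
    inhibition = sum(w * x for w, x in zip(w_inh, x_inh))
    return excitation - inhibition

# Toy 2x2 weight matrix and gradient: one update would push a weight below zero.
w = [[0.2, 0.05], [0.1, 0.3]]
g = [[0.5, 1.0], [-0.2, 0.1]]
print(sgd_step_with_dale(w, g, lr=0.1))  # the 0.05 weight is clipped to 0.0 instead of going negative
```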

import EIANN as eiann
from EIANN import utils as ut
eiann.update_plot_defaults()
root_dir = ut.get_project_root()

1. Load MNIST data

train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()

2. Load optimized pre-trained EIANN model: EI network trained with Backprop

network_name = "20240419_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_F_complete_optimized"
network_seed = 66049
data_seed = 257
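The two seeds play different roles. `network_seed` is passed to `build_EIANN_from_config`, which presumably uses it for weight initialization. `data_seed` reseeds `data_generator` before training, which presumably fixes the order in which training samples are drawn. Keeping them as separate random streams lets you vary one without changing the other. The stdlib sketch below illustrates the idea with two independent `random.Random` instances. `init_weights` and `shuffled_order` are hypothetical helpers, not part of EIANN.

```python
import random

network_seed = 66049
data_seed = 257

def init_weights(n_out, n_in, seed):
    """Hypothetical: draw an (n_out x n_in) weight matrix from its own seeded RNG."""
    rng = random.Random(seed)
    return [[rng.uniform(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]

def shuffled_order(n_samples, seed):
    """Hypothetical: a reproducible permutation of sample indices."""
    rng = random.Random(seed)
    order = list(range(n_samples))
    rng.shuffle(order)
    return order

# Re-running with the same seeds reproduces both the initial weights and the data order.
print(init_weights(2, 3, network_seed) == init_weights(2, 3, network_seed))  # True
print(shuffled_order(10, data_seed) == shuffled_order(10, data_seed))        # True
# Changing data_seed alone changes the sample order but leaves the weights untouched.
```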

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object
config_file_path = f"{root_dir}/EIANN/network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader,
              epochs=1,
              samples_per_epoch=20_000,
              val_interval=(0, -1, 100),
              store_history=True,
              store_history_interval=(0, -1, 100),
              store_dynamics=False,
              store_params=True,
              status_bar=True)

# Optional: Save network object to pickle file
saved_network_path = f"{root_dir}/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)
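The `val_interval` and `store_history_interval` arguments are both given as `(0, -1, 100)`. One plausible reading, which we have not verified against EIANN's source, is a `(start, stop, step)` triple over training samples, with `-1` standing for the last sample. Under that assumption, the hypothetical helper below lists the sample indices at which validation (or history storage) would happen for `samples_per_epoch = 20_000`. This works out to roughly 200 evaluation points.

```python
def interval_indices(interval, num_samples):
    """Hypothetical helper: expand a (start, stop, step) tuple into sample indices.

    Assumes a negative stop counts from the end (so -1 is the last sample)
    and that the last sample is always included.
    """
    start, stop, step = interval
    if stop < 0:
        stop = num_samples + stop  # -1 -> index of the last sample
    indices = list(range(start, stop + 1, step))
    if indices[-1] != stop:
        indices.append(stop)
    return indices

idx = interval_indices((0, -1, 100), num_samples=20_000)
print(len(idx), idx[:3], idx[-2:])  # 201 [0, 100, 200] [19900, 19999]
```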

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

saved_network_path = f"{root_dir}/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20240419_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_F_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20240419_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_F_complete_optimized_66049_257.pkl'
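The log lines above suggest that `ut.save_network` and `ut.load_network` write and read the whole network object as a pickle. Assuming they are thin wrappers around Python's `pickle` module (we have not checked what else they do), the round trip looks roughly like the sketch below. A plain dictionary stands in for the network object, and `save_object` and `load_object` are hypothetical names. Unpickling can execute arbitrary code, so only load `.pkl` files from sources you trust.

```python
import os
import pickle
import tempfile

def save_object(obj, path):
    """Serialize obj to path with pickle (hypothetical stand-in for ut.save_network)."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def load_object(path):
    """Deserialize an object from path (hypothetical stand-in for ut.load_network)."""
    with open(path, "rb") as f:
        return pickle.load(f)

# A plain dict stands in for the trained network object.
toy_network = {"name": "toy_EIANN", "seed": 66049, "weights": [[0.1, 0.2], [0.3, 0.4]]}
path = os.path.join(tempfile.mkdtemp(), "toy_EIANN_66049_257.pkl")
save_object(toy_network, path)
restored = load_object(path)
print(restored == toy_network)  # True
```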

3. Training results

eiann.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot_accuracy_history(network)
eiann.plot_error_history(network)
[Figures: loss history, accuracy history, error history]
eiann.plot_loss_landscape(test_dataloader, network, num_points=20, extension=0.9, vmax_scale=1.2, scale='log')
[Figure: loss landscape]
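`plot_loss_landscape` visualizes how the test loss varies around the trained parameters. A common way to build such a plot is to pick two directions `d1`, `d2` in parameter space and evaluate the loss on a 2D grid of points `theta + a*d1 + b*d2`. The sketch below assumes that `num_points` is the number of grid points per axis and that `extension` sets the grid range `[-extension, extension]`. Both readings, the toy quadratic loss, and the random choice of directions are assumptions for illustration. EIANN may, for example, derive its directions from the stored parameter history instead.

```python
import random

def linspace(lo, hi, n):
    """n evenly spaced values from lo to hi (inclusive)."""
    return [lo + (hi - lo) * i / (n - 1) for i in range(n)]

def loss_landscape(loss_fn, theta, d1, d2, num_points=20, extension=0.9):
    """Evaluate loss_fn at theta + a*d1 + b*d2 for a, b on a num_points x num_points grid."""
    coords = linspace(-extension, extension, num_points)
    grid = []
    for b in coords:
        row = []
        for a in coords:
            point = [t + a * u + b * v for t, u, v in zip(theta, d1, d2)]
            row.append(loss_fn(point))
        grid.append(row)
    return coords, grid

# Toy quadratic loss whose minimum sits at the "trained" parameters theta_star.
theta_star = [1.0, -2.0, 0.5]

def quadratic_loss(params):
    return sum((p - q) ** 2 for p, q in zip(params, theta_star))

rng = random.Random(0)
d1 = [rng.gauss(0.0, 1.0) for _ in theta_star]
d2 = [rng.gauss(0.0, 1.0) for _ in theta_star]
coords, grid = loss_landscape(quadratic_loss, theta_star, d1, d2)
print(len(grid), len(grid[0]))  # 20 20
```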

4. Analyze population activities

eiann.plot_batch_accuracy(network, test_dataloader, population='E')
Batch accuracy = 95.12000274658203%
[Figures: batch accuracy plots (three images)]
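`plot_batch_accuracy` reports the percentage of test samples classified correctly. Assuming the predicted class is the output unit with the highest activity (our reading, not confirmed by the source), the computation reduces to the hypothetical `batch_accuracy` function below. The long tail of digits in the printed value (95.12000274658203%) is consistent with the accuracy being computed in 32-bit floating point. 95.12 has no exact binary representation, and its nearest float32 value prints as 95.12000274658203 once converted to a Python float, which the last lines check with `struct`.

```python
import struct

def batch_accuracy(outputs, labels):
    """Percent of samples whose most active output unit matches the label."""
    correct = 0
    for activity, label in zip(outputs, labels):
        predicted = max(range(len(activity)), key=lambda i: activity[i])
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)

# Toy batch: 4 samples, 3 output units; the last sample is misclassified.
outputs = [[0.1, 0.8, 0.1], [0.6, 0.3, 0.1], [0.2, 0.2, 0.6], [0.5, 0.4, 0.1]]
labels = [1, 0, 2, 1]
print(f"Batch accuracy = {batch_accuracy(outputs, labels)}%")  # Batch accuracy = 75.0%

def to_float32(x):
    """Round a Python float to the nearest 32-bit float."""
    return struct.unpack("f", struct.pack("f", x))[0]

print(to_float32(95.12))  # 95.12000274658203
```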
pop_dynamics_dict = ut.compute_test_activity_dynamics(network, test_dataloader, plot=True, normalize=True)  # Present the test dataset and record neuron activity at every timestep
print(pop_dynamics_dict['H1E'].shape)  # Because the dynamics are stored, each population's activity has shape (timesteps, data samples, neurons)
torch.Size([15, 10000, 500])
[Figure: test activity dynamics]
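Each entry of `pop_dynamics_dict` is indexed as (timesteps, samples, neurons), so summarizing it is a matter of reducing over the right axes. With the torch tensor from the notebook, `pop_dynamics_dict['H1E'][-1]` gives the activity at the final timestep, and `pop_dynamics_dict['H1E'].mean(dim=(1, 2))` gives a population-average trace over time. The plain-Python sketch below does the same on a toy nested list and adds a simple convergence check (the largest change between the last two timesteps). The helper names are hypothetical.

```python
def population_mean_trace(dynamics):
    """Mean activity over samples and neurons at each timestep."""
    trace = []
    for step in dynamics:  # step: samples x neurons
        values = [a for sample in step for a in sample]
        trace.append(sum(values) / len(values))
    return trace

def max_change_last_step(dynamics):
    """Largest absolute change between the last two timesteps (small means settled)."""
    prev, last = dynamics[-2], dynamics[-1]
    return max(abs(a - b) for p_row, l_row in zip(prev, last) for a, b in zip(p_row, l_row))

# Toy dynamics: 3 timesteps x 2 samples x 2 neurons.
dynamics = [
    [[0.0, 0.0], [0.0, 0.0]],
    [[0.5, 1.0], [0.2, 0.4]],
    [[0.6, 1.0], [0.2, 0.5]],
]
final_activity = dynamics[-1]  # shape (samples, neurons)
print(population_mean_trace(dynamics))
print(max_change_last_step(dynamics))
```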

5. Analyze learned representations

pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=False, sort=True)
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = ut.compute_representational_similarity_matrix(pop_activity_dict, pattern_labels, unit_labels_dict, population='E', plot=True)
[Figures: representational similarity plots (three images)]
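`compute_representational_similarity_matrix` returns two kinds of matrix: pattern-by-pattern similarity (how alike the population responses to two test images are) and neuron-by-neuron similarity (how alike two units' response profiles are across images). The sketch below computes both on a toy patterns x neurons activity table. Cosine similarity is our assumption; EIANN may use a different measure such as Pearson correlation.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between vectors u and v (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)

def similarity_matrix(rows):
    """Pairwise cosine similarity between all rows."""
    return [[cosine_similarity(a, b) for b in rows] for a in rows]

# Toy activity: 3 patterns (rows) x 3 neurons (columns).
activity = [
    [1.0, 0.0, 1.0],  # pattern 0
    [1.0, 0.0, 0.9],  # pattern 1, similar to pattern 0
    [0.0, 1.0, 0.0],  # pattern 2, non-overlapping
]
pattern_similarity = similarity_matrix(activity)
# Transpose so each row is one neuron's responses across patterns.
neuron_similarity = similarity_matrix([list(col) for col in zip(*activity)])
```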

5.1 Hidden Layer 1

receptive_fields_H1 = ut.compute_maxact_receptive_fields(network.H1.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_H1, sort=True)
Optimizing receptive field images H1E...
[Figure: H1E receptive fields]
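The log line "Optimizing receptive field images H1E..." suggests that `compute_maxact_receptive_fields` searches, for each unit, for an input image that maximally activates it. The sketch below shows the basic idea for a single ReLU unit: projected gradient ascent on the input, with pixels kept in [0, 1] and a small L2 penalty that stops unhelpful pixels from drifting upward. The unit model, the penalty, the step size, and the name `maximize_activation` are all assumptions for illustration. EIANN's actual objective and optimizer are not described here.

```python
def maximize_activation(w, b=0.0, lr=0.1, l2=0.01, steps=200):
    """Projected gradient ascent on x to maximize relu(w.x + b) - l2 * ||x||^2, with x in [0, 1]."""
    x = [0.5] * len(w)  # start from a uniform gray image
    for _ in range(steps):
        drive = sum(wi * xi for wi, xi in zip(w, x)) + b
        relu_grad = 1.0 if drive > 0 else 0.0  # derivative of ReLU at the current drive
        # Gradient step on each pixel, then clip back into the valid pixel range.
        x = [min(1.0, max(0.0, xi + lr * (relu_grad * wi - 2 * l2 * xi))) for wi, xi in zip(w, x)]
    return x

# Toy unit: excited by pixels 0 and 2, inhibited by pixel 1, indifferent to pixel 3.
weights = [1.0, -1.0, 0.5, 0.0]
rf = maximize_activation(weights)
# Pixels 0 and 2 saturate at 1.0, pixel 1 goes to 0.0, pixel 3 decays slowly under the L2 penalty.
print([round(v, 3) for v in rf])
```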
metrics_dict = ut.compute_representation_metrics(population=network.H1.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H1, plot=True)
[Figure: H1E representation metrics]

5.2 Hidden Layer 2

receptive_fields_H2 = ut.compute_maxact_receptive_fields(network.H2.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_H2, sort=True)
Optimizing receptive field images H2E...
[Figure: H2E receptive fields]
metrics_dict = ut.compute_representation_metrics(population=network.H2.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H2, plot=True)
[Figure: representation metrics]

5.3 Output layer

receptive_fields_output = ut.compute_maxact_receptive_fields(network.Output.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_output, sort=False)
Optimizing receptive field images OutputE...
[Figure: OutputE receptive fields]