EI network - Backprop (Dale’s Law, learned inhibition)

Here we will show the analysis for a single example seed of a network described in Galloni et al. 2025.
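The title refers to Dale's law: each neuron is either excitatory or inhibitory, so all of its outgoing weights share one sign. "Learned inhibition" indicates that the inhibitory connections are trained rather than fixed. The sketch below shows one way a sign constraint can coexist with gradient descent. It keeps non-negative weight magnitudes, takes a plain SGD step, clips any magnitude that would turn negative back to zero, and then applies the sign of the presynaptic population. This scheme and the helper names are assumptions for illustration, not necessarily how EIANN enforces Dale's law.

```python
# Illustrative sketch of a Dale's-law-constrained update (assumed scheme, hypothetical helpers).
# Magnitudes stay >= 0; the sign comes from the presynaptic population (+1 for E, -1 for I).

def effective_weights(magnitudes, presyn_sign):
    """Apply the presynaptic sign (+1 excitatory, -1 inhibitory) to non-negative magnitudes."""
    return [[presyn_sign * w for w in row] for row in magnitudes]

def sgd_step_with_dale(magnitudes, grads, lr):
    """Plain SGD on the magnitudes, then clip at zero so no connection can flip sign."""
    return [
        [max(0.0, w - lr * g) for w, g in zip(w_row, g_row)]
        for w_row, g_row in zip(magnitudes, grads)
    ]

# Toy example: 2 postsynaptic units x 3 presynaptic inhibitory units
mags = [[0.5, 0.1, 0.0], [0.2, 0.3, 0.4]]
grads = [[1.0, 2.0, -1.0], [0.1, -0.5, 0.0]]
new_mags = sgd_step_with_dale(mags, grads, lr=0.1)
w_inh = effective_weights(new_mags, presyn_sign=-1)  # every inhibitory weight is <= 0
print(new_mags)
```

In this example the second weight of the first row would have gone to -0.1. The clip sets it to zero, which silences that connection instead of turning it excitatory.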

import EIANN.EIANN as eiann
from EIANN import utils as ut
eiann.plot.update_plot_defaults()
root_dir = ut.get_project_root()


1. Load MNIST data

train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()

2. Load optimized pre-trained EIANN model

EI network trained with Backprop

network_name = "20240419_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_F_complete_optimized"
network_seed = 66049
data_seed = 257
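The two seeds play different roles. network_seed is passed to ut.build_EIANN_from_config and presumably controls the network's random initialization. data_seed seeds data_generator before training, which we assume controls the order in which training samples are drawn. Keeping the two sources of randomness in separate generators lets you change one without affecting the other. A minimal sketch of this idea, with Python's random.Random standing in for the PyTorch generators and hypothetical helper names:

```python
import random

def init_weights(num_weights, network_seed):
    """Hypothetical helper: draw initial weights from a generator seeded only by network_seed."""
    rng = random.Random(network_seed)
    return [rng.uniform(-0.1, 0.1) for _ in range(num_weights)]

def sample_order(num_samples, data_seed):
    """Hypothetical helper: shuffle sample indices with a generator seeded only by data_seed."""
    rng = random.Random(data_seed)
    order = list(range(num_samples))
    rng.shuffle(order)
    return order

network_seed, data_seed = 66049, 257

# Same network_seed -> identical initial weights, independent of data_seed
w_a = init_weights(5, network_seed)
w_b = init_weights(5, network_seed)

# Same data_seed -> identical presentation order, independent of network_seed
order_a = sample_order(10, data_seed)
```

Rerunning with a different data_seed would reshuffle the training order while leaving the initial weights unchanged. This separation is what makes a single (network_seed, data_seed) pair a reproducible "example seed".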

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object from its .yaml configuration file
config_file_path = f"../network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network (seed the data generator returned by get_MNIST_dataloaders for reproducibility)
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader,
              epochs=1,
              samples_per_epoch=20_000,
              val_interval=(0, -1, 100),
              store_history=True,
              store_history_interval=(0, -1, 100),
              store_dynamics=False,
              store_params=True,
              status_bar=True)

# Optional: Save the trained network object to a pickle file
saved_network_path = root_dir + f"/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

saved_network_path = root_dir + f"/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20240419_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_F_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20240419_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_F_complete_optimized_66049_257.pkl'
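The .pkl file is a Python pickle of the entire trained network object, so a single call restores it. The sketch below shows the same save/load round trip with the standard pickle module, using a plain dict in place of the network. ut.save_network and ut.load_network may do more, for example printing the status messages shown above, and the helper names here are hypothetical. Loading a pickle can execute arbitrary code, so only load files you trust.

```python
import os
import pickle
import tempfile

def save_object(obj, path):
    """Hypothetical helper: pickle obj to path, creating parent directories if needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def load_object(path):
    """Hypothetical helper: load a pickled object (only from trusted sources)."""
    with open(path, "rb") as f:
        return pickle.load(f)

# A plain dict stands in for the trained network object
toy_network = {"name": "toy_net", "network_seed": 66049, "data_seed": 257,
               "weights": [0.1, -0.2, 0.3]}
path = os.path.join(tempfile.mkdtemp(), "mnist", "toy_net_66049_257.pkl")
save_object(toy_network, path)
restored = load_object(path)
```

Encoding both seeds in the file name, as the notebook does, lets each saved pickle be traced back to the exact run that produced it.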

3. Training results

eiann.plot.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot.plot_accuracy_history(network)
eiann.plot.plot_error_history(network)
eiann.plot.plot_loss_landscape(test_dataloader, network, num_points=20, extension=0.9, vmax_scale=1.2, scale='log')
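A loss landscape plot of this kind typically evaluates the loss on a 2D grid around the trained parameters θ*, that is L(θ* + α·d1 + β·d2) for two directions d1 and d2 in parameter space. The sketch below uses a toy quadratic loss and random directions. How plot_loss_landscape actually chooses its directions is an assumption here; it could, for instance, use principal components of the stored parameter history. We also assume that num_points and extension set the number of grid points per axis and the half-width of the α, β range.

```python
import random

def toy_loss(params):
    """Toy quadratic loss with its minimum at (1.0, -2.0, 0.5)."""
    target = (1.0, -2.0, 0.5)
    return sum((p - t) ** 2 for p, t in zip(params, target))

def loss_grid(loss_fn, center, d1, d2, num_points, extension):
    """Evaluate loss_fn(center + a*d1 + b*d2) for a, b on an evenly spaced grid in [-extension, extension]."""
    coords = [-extension + 2 * extension * i / (num_points - 1) for i in range(num_points)]
    grid = []
    for a in coords:
        row = []
        for b in coords:
            params = [c + a * x + b * y for c, x, y in zip(center, d1, d2)]
            row.append(loss_fn(params))
        grid.append(row)
    return coords, grid

rng = random.Random(0)
theta_star = [1.0, -2.0, 0.5]  # pretend these are the trained parameters
d1 = [rng.gauss(0.0, 1.0) for _ in theta_star]
d2 = [rng.gauss(0.0, 1.0) for _ in theta_star]
coords, grid = loss_grid(toy_loss, theta_star, d1, d2, num_points=21, extension=0.9)
```

Here we use an odd num_points (21) so that the grid contains the trained parameters exactly at its center. A log color scale, as with scale='log' above, helps when the loss spans several orders of magnitude across the grid.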

4. Analyze population activities

eiann.plot.plot_batch_accuracy(network, test_dataloader, population='E')
Batch accuracy = 95.12000274658203%
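Batch accuracy is presumably the percentage of test samples whose most active output unit matches the true label. A minimal sketch of that computation on toy output activities, with plain lists standing in for the output population's activity tensor:

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def batch_accuracy(output_activity, labels):
    """Percent of samples whose most active output unit equals the label."""
    correct = sum(argmax(row) == label for row, label in zip(output_activity, labels))
    return 100.0 * correct / len(labels)

# 4 samples x 3 output units
outputs = [
    [0.1, 0.8, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.9],
    [0.5, 0.4, 0.1],  # most active unit is 0, but the label is 1
]
labels = [1, 0, 2, 1]
acc = batch_accuracy(outputs, labels)  # 75.0
```

The long decimal in the printed 95.12000274658203% comes from single-precision (float32) arithmetic. The underlying value is 95.12%.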
# Evaluate the network dynamics by presenting the test dataset and recording neuron activity at each timestep
pop_dynamics_dict = ut.compute_test_activity_dynamics(network, test_dataloader, plot=True, normalize=True)
# Because we store the dynamics here, each population's activity has shape (timesteps, data samples, neurons)
print(pop_dynamics_dict['H1E'].shape)
torch.Size([15, 10000, 500])
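Each tensor in pop_dynamics_dict is indexed as (timestep, sample, neuron), so a natural first check is whether population activity has settled by the final timestep. A minimal sketch, using nested lists in place of the tensor. It averages over samples and neurons at each timestep, then compares the last two timesteps against a tolerance. The tolerance and the helper names are illustrative assumptions.

```python
def mean_trajectory(dynamics):
    """Average activity over samples and neurons at each timestep.

    dynamics is indexed as dynamics[t][sample][neuron].
    """
    traj = []
    for frame in dynamics:
        total = sum(sum(neurons) for neurons in frame)
        count = sum(len(neurons) for neurons in frame)
        traj.append(total / count)
    return traj

def has_converged(traj, tol=1e-3):
    """True if the mean activity changed by less than tol on the final timestep."""
    return abs(traj[-1] - traj[-2]) < tol

# Toy dynamics: 4 timesteps x 2 samples x 3 neurons, relaxing toward ~1.0
dynamics = [
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
    [[0.999, 0.999, 0.999], [0.999, 0.999, 0.999]],
    [[0.9995, 0.9995, 0.9995], [0.9995, 0.9995, 0.9995]],
]
traj = mean_trajectory(dynamics)
```

For the real data the same reduction applies with 15 timesteps, 10,000 samples, and 500 neurons in H1E.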

5. Analyze learned representations

pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=False, sort=True)
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = ut.compute_representational_similarity_matrix(pop_activity_dict, pattern_labels, unit_labels_dict, population='E', plot=True)
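The pattern similarity matrix compares test samples with each other, using the rows of the (sample x neuron) activity matrix. The neuron similarity matrix compares units with each other, using its columns. A minimal sketch with cosine similarity. Whether compute_representational_similarity_matrix uses cosine similarity, Pearson correlation, or another measure is an assumption here.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors; 0.0 if either is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu > 0 and nv > 0 else 0.0

def similarity_matrix(vectors):
    """All pairwise cosine similarities between the given vectors."""
    return [[cosine(u, v) for v in vectors] for u in vectors]

# activity[sample][neuron] for 3 samples x 4 neurons
activity = [
    [1.0, 0.0, 0.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
    [0.0, 1.0, 1.0, 0.0],
]
pattern_rsm = similarity_matrix(activity)        # 3 x 3: samples vs samples
neurons = [list(col) for col in zip(*activity)]  # transpose to neuron x sample
neuron_rsm = similarity_matrix(neurons)          # 4 x 4: neurons vs neurons
```

Sorting samples by class label before building the matrix, as sort=True does above, makes class structure appear as bright blocks along the diagonal.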
(within_class_pattern_similarity_dict, between_class_pattern_similarity_dict,
 within_class_unit_similarity_dict, between_class_unit_similarity_dict) = ut.compute_within_class_representational_similarity(
    network, test_dataloader, population='E', plot=True)
Plotting pattern similarity for population: H1E
Plotting unit similarity for population: H1E
Plotting pattern similarity for population: H2E
Plotting unit similarity for population: H2E
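Within-class similarity averages the pattern similarity over pairs of samples that share a label. Between-class similarity averages it over pairs with different labels. A layer that separates digit classes well should show the first clearly above the second. The sketch below assumes a precomputed pattern similarity matrix and excludes each sample's similarity with itself. compute_within_class_representational_similarity works directly from the network and dataloader, and its exact averaging may differ.

```python
def within_between_similarity(sim, labels):
    """Mean off-diagonal similarity for same-label pairs and for different-label pairs."""
    within, between = [], []
    n = len(labels)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # skip self-similarity on the diagonal
            (within if labels[i] == labels[j] else between).append(sim[i][j])
    return sum(within) / len(within), sum(between) / len(between)

sim = [
    [1.0, 0.8, 0.1, 0.2],
    [0.8, 1.0, 0.3, 0.0],
    [0.1, 0.3, 1.0, 0.6],
    [0.2, 0.0, 0.6, 1.0],
]
labels = [0, 0, 1, 1]
w, b = within_between_similarity(sim, labels)  # w is about 0.7, b is about 0.15
```

The same function applies to the unit similarity matrix if each unit is assigned a label, for example the class that drives it most strongly.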

5.1 Hidden Layer 1

receptive_fields_H1 = ut.compute_maxact_receptive_fields(network.H1.E, test_dataloader=test_dataloader)
eiann.plot.plot_receptive_fields(receptive_fields_H1, sort=True)
Optimizing receptive field images...
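The "Optimizing receptive field images..." message indicates that these receptive fields come from optimization. Gradient ascent adjusts an input image to maximize a unit's activity. The sketch below does this for a single ReLU unit with weights w. It assumes the input norm is held fixed by projecting back onto the unit sphere after each step; the regularization and stopping rule in compute_maxact_receptive_fields may differ. For a unit this simple, the optimal input is the one aligned with w, which makes the result easy to check. Deeper units need the optimization because their response is nonlinear in the input.

```python
import math
import random

def relu_unit(x, w):
    """A single ReLU unit: relu(w . x)."""
    return max(0.0, sum(a * b for a, b in zip(x, w)))

def normalize(v):
    """Rescale v to unit Euclidean norm."""
    n = math.sqrt(sum(a * a for a in v))
    return [a / n for a in v]

def maxact_input(w, steps=200, lr=0.1, seed=0):
    """Projected gradient ascent on the input x to maximize relu(w . x) subject to ||x|| = 1."""
    rng = random.Random(seed)
    x = normalize([rng.gauss(0.0, 1.0) for _ in w])
    if relu_unit(x, w) == 0.0:
        x = [-a for a in x]  # a silent ReLU has zero gradient, so start on the active side
    for _ in range(steps):
        active = sum(a * b for a, b in zip(x, w)) > 0
        grad = w if active else [0.0] * len(w)  # d relu(w . x) / dx
        x = normalize([a + lr * g for a, g in zip(x, grad)])  # ascend, then project to the sphere
    return x

w = [0.5, -1.0, 2.0, 0.0, 1.5]  # toy weights for one unit
x_opt = maxact_input(w)
```

At convergence the activity equals the norm of w, the largest value any unit-norm input can produce.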
metrics_dict = ut.compute_representation_metrics(population=network.H1.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H1, plot=True)
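compute_representation_metrics summarizes a layer with scalar metrics, but exactly which ones it computes is not spelled out here. As one example of this kind of metric, the sketch below computes population sparsity from the Treves-Rolls measure. It is rescaled so that 0 means all units are equally active and 1 means exactly one unit is active. This metric is an illustrative choice, not necessarily one that EIANN reports.

```python
def population_sparsity(rates):
    """Normalized Treves-Rolls sparsity of one non-negative activity vector (needs >= 2 units).

    Returns 0.0 when all units are equally active and 1.0 when exactly one unit is active.
    """
    n = len(rates)
    mean_r = sum(rates) / n
    mean_r2 = sum(r * r for r in rates) / n
    if mean_r2 == 0:
        return 0.0  # silent population: sparsity defined as 0 by convention (an arbitrary choice)
    a = mean_r ** 2 / mean_r2  # Treves-Rolls activity ratio, between 1/n and 1
    return (1 - a) / (1 - 1 / n)

one_hot = population_sparsity([1.0, 0.0, 0.0, 0.0])  # 1.0
uniform = population_sparsity([2.0, 2.0, 2.0])       # 0.0
```

For a whole layer, you would compute this for each test sample's activity vector and average across samples.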

5.2 Hidden Layer 2

receptive_fields_H2 = ut.compute_maxact_receptive_fields(network.H2.E, test_dataloader=test_dataloader)
eiann.plot.plot_receptive_fields(receptive_fields_H2, sort=True)
Optimizing receptive field images...
metrics_dict = ut.compute_representation_metrics(population=network.H2.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H2, plot=True)

5.3 Output layer

receptive_fields_output = ut.compute_maxact_receptive_fields(network.Output.E, test_dataloader=test_dataloader)
eiann.plot.plot_receptive_fields(receptive_fields_output, sort=False)
Optimizing receptive field images...