EI network - Backprop (Dale’s Law, fixed inhibition)#

Here we will show the analysis for a single example seed of a network described in Galloni et al. 2025.

import EIANN.EIANN as eiann
from EIANN import utils as ut
# Set up plotting defaults for the figures in this notebook
# (presumably adjusts matplotlib settings — confirm in eiann.plot).
eiann.plot.update_plot_defaults()
# Project root path; later cells concatenate it with "/EIANN/data/..." to locate data files.
root_dir = ut.get_project_root()
%load_ext autoreload
%autoreload 2

Hide code cell output

1. Load MNIST data#

# The loader helper returns four objects: the train / validation / test
# dataloaders plus the data generator that is re-seeded before training.
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    data_generator,
) = ut.get_MNIST_dataloaders()

2. Load optimized pre-trained EIANN model:#

EI network trained with Backprop#

# Identifier shared by the YAML config file name and the saved pickle / figure-data file names
network_name = "20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized"
# Seeds for this example: one is passed when building the network,
# the other seeds the data generator before training.
network_seed, data_seed = 66049, 257

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Instantiate the network from its .yaml configuration file
config_file_path = f"../network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(
    config_file_path,
    network_seed=network_seed,
)

# Seed the data generator first, then run a single training epoch
# while recording history and parameters.
data_generator.manual_seed(data_seed)
network.train(
    train_dataloader,
    val_dataloader,
    epochs=1,
    samples_per_epoch=20_000,
    val_interval=(0, -1, 100),
    store_history=True,
    store_history_interval=(0, -1, 100),
    store_dynamics=False,
    store_params=True,
    status_bar=True,
)

# Optional: save the trained network object to a pickle file.
# NOTE(review): this writes under data/mnist/, whereas the loading step in this
# notebook reads from data/saved_network_pickles/mnist/ — confirm the intended directory.
saved_network_path = root_dir + f"/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
# Save the object built and trained in this notebook (`network`). The name
# `spiral_net` is never defined here and would raise a NameError.
ut.save_network(network, path=saved_network_path)

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

# Pickle file holding the pre-trained network for this (network seed, data seed) pair
saved_network_path = (
    root_dir
    + f"/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
)
network = ut.load_network(saved_network_path)

# Tag the loaded object with its configuration name and combined seed string
network.name, network.seed = network_name, f"{network_seed}_{data_seed}"
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized_66049_257.pkl'

3. Training results#

# Training history plots: loss (with a fixed y-range), accuracy and error
eiann.plot.plot_loss_history(
    network,
    ylim=(-0.01, 0.2),
)
eiann.plot.plot_accuracy_history(network)
eiann.plot.plot_error_history(network)
../_images/6ef8b1e1a2c8925ea554fd7cb8c0a9a0178d498630cc0e32bd7662d0d50c12fa.png ../_images/cba4a3561d4413b1708f94ecb9dd975704f5e04fb0aab1e2bc10ec4fc8606465.png ../_images/f77a612a2369c4df86faad35d4304bcac55e639ada95dfa06b79899b283cae2c.png
# Plot the loss landscape using the test dataloader
eiann.plot.plot_loss_landscape(
    test_dataloader,
    network,
    num_points=20,
    extension=0.9,
    vmax_scale=1.2,
    scale='log',
)
../_images/aa3feb775db11ebb77c82ae486e9a3c770760bb95b4857db0bd4da2bd466d96c.png

4. Analyze population activities#

# Batch accuracy on the test set, restricted to the 'E' populations
eiann.plot.plot_batch_accuracy(
    network,
    test_dataloader,
    population='E',
)
Batch accuracy = 77.0199966430664%
../_images/b47692fe60f011ad1bf9963b1543d0c70497affa8d0a79da97edb7af6e0b9e2d.png ../_images/26f6ffd7ffcc12d1d39614c0040c11300b73fbdf451dc0c37c0e5594515e1a17.png ../_images/5de006ddf3e6622208280165716d970ae9b87cfc77eda1398ed5bff4f69fe472.png
# Evaluate the network dynamics by presenting the test dataset and
# recording the neuron activity.
pop_dynamics_dict = ut.compute_test_activity_dynamics(
    network,
    test_dataloader,
    plot=True,
    normalize=True,
)

# Because the dynamics are stored here, each population's activity should
# have shape (timesteps, data samples, neurons).
print(pop_dynamics_dict['H1E'].shape)
torch.Size([15, 10000, 500])
../_images/8177c277291365ca2d77e1698b34ed8abd7ce4b9df859a66859c894dbaea2630.png

5. Analyze learned representations#

# Per-sample (not class-averaged) test activity, sorted
pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=False,
    sort=True,
)

# Pattern-by-pattern and neuron-by-neuron similarity matrices for the 'E' populations
(
    pattern_similarity_matrix_dict,
    neuron_similarity_matrix_dict,
) = ut.compute_representational_similarity_matrix(
    pop_activity_dict,
    pattern_labels,
    unit_labels_dict,
    population='E',
    plot=True,
)
../_images/ce1dad5484fee082222b6fced565ee0b86c81c52b297a58adc79b133406abd4e.png ../_images/1c65c8993344546c0b5ff72e3bf6882c03f5cb5d9cd1adfa6857e51671d05caa.png ../_images/3db6588008ea65b44f26ce1c2208a2eb7119f2afaa1ba0b21a6a3980e5140c84.png
# Within- vs. between-class similarity, for both patterns and units ('E' populations)
(
    within_class_pattern_similarity_dict,
    between_class_pattern_similarity_dict,
    within_class_unit_similarity_dict,
    between_class_unit_similarity_dict,
) = ut.compute_within_class_representational_similarity(
    network,
    test_dataloader,
    population='E',
    plot=True,
)
Plotting pattern similarity for population: H1E
Plotting unit similarity for population: H1E
Plotting pattern similarity for population: H2E
Plotting unit similarity for population: H2E
../_images/360f5950263a11ed3a89768425507f73755829644dbed0b8f2b2a5f893c635cb.png

5.3 Output Layer#

# Max-activation receptive fields for the Output E population; export=True
# points the helper at the shared HDF5 figure-data file for this network.
receptive_fields_output = ut.compute_maxact_receptive_fields(
    network.Output.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=root_dir + f"/EIANN/data/Figure_data_hdf5/plot_data_{network_name}.h5",
)
# Output units are plotted unsorted
eiann.plot.plot_receptive_fields(receptive_fields_output, sort=False)
Loading maxact_receptive_fields_OutputE from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/Figure_data_hdf5/plot_data_20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized.h5
../_images/2167bb836e4d5f6cd93f1b45a7e92b1158a20564907316d4e7e1ee4f9087143c.png

5.2 Hidden Layer 2#

# Max-activation receptive fields for the H2 E population; export=True
# points the helper at the shared HDF5 figure-data file for this network.
receptive_fields_H2 = ut.compute_maxact_receptive_fields(
    network.H2.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=root_dir + f"/EIANN/data/Figure_data_hdf5/plot_data_{network_name}.h5",
)
eiann.plot.plot_receptive_fields(receptive_fields_H2, sort=True)
Loading maxact_receptive_fields_H2E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/Figure_data_hdf5/plot_data_20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized.h5
../_images/a1b02d746cc4e571edc1a6abe060113fd152fa4db12a52f87783821624b96062.png
# Compare H2E receptive fields against the class-averaged test activity
population = 'H2E'
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)
# NOTE(review): the other plotting helpers in this notebook are reached via
# eiann.plot.* — confirm this function is also exported at the top level.
eiann.plot_receptive_field_similarity(
    receptive_fields_H2,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/14634f6f73dcd640dd1b7ebc2e10dbb991ef690974f87b1d61ad8fcc269b02cd.png
# Representation metrics for Hidden Layer 2. The population must match the
# layer whose receptive fields are supplied (H2); the previous call passed
# network.H1.E together with the H2 receptive fields.
metrics_dict = ut.compute_representation_metrics(
    population=network.H2.E,
    dataloader=test_dataloader,
    receptive_fields=receptive_fields_H2,
    plot=True,
)
../_images/e10e14cd52558ab1eb8eb8574dbe2e22274377b57926a5cb445962e21a8ecce8.png

5.1 Hidden Layer 1#

# Max-activation receptive fields for the H1 E population; export=True
# points the helper at the shared HDF5 figure-data file for this network.
receptive_fields_H1 = ut.compute_maxact_receptive_fields(
    network.H1.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=root_dir + f"/EIANN/data/Figure_data_hdf5/plot_data_{network_name}.h5",
)
eiann.plot.plot_receptive_fields(receptive_fields_H1, sort=True)
Loading maxact_receptive_fields_H1E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/Figure_data_hdf5/plot_data_20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized.h5
maxact_receptive_fields_H1E loaded from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/Figure_data_hdf5/plot_data_20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized.h5
../_images/2dc63fe2ca07d6b1c19cf05d06c21afd5472feb70aa96df2d2ce38e9707ab76b.png
# Representation metrics for Hidden Layer 1 (H1 E population with its own receptive fields)
metrics_dict = ut.compute_representation_metrics(
    population=network.H1.E,
    dataloader=test_dataloader,
    receptive_fields=receptive_fields_H1,
    plot=True,
)
../_images/0bc1030392d76c4af3b0199c8585ef06a91f75214926c084e9ec8c68ae733eb6.png
# Compare H1E receptive fields against the class-averaged test activity
population = 'H1E'
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)
# NOTE(review): the other plotting helpers in this notebook are reached via
# eiann.plot.* — confirm this function is also exported at the top level.
eiann.plot_receptive_field_similarity(
    receptive_fields_H1,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/8aa464eb2bcdb20bc52b97726441f1420ab04136350a1986556973bf2371c69c.png