EI network - Hebbian (learned inh.)#

Here we will show the analysis for a single example seed of a network described in Galloni et al. 2025.

import EIANN as eiann
from EIANN import utils as ut
# Apply the EIANN package's plotting defaults before any figures are drawn
# (presumably matplotlib settings — confirm in EIANN).
eiann.update_plot_defaults()
# Project root directory; used below as the base for config, pickle and HDF5 paths.
root_dir = ut.get_project_root()

Load MNIST data#

# Fetch the MNIST train/validation/test dataloaders together with the data
# generator object that is seeded later, right before training.
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    data_generator,
) = ut.get_MNIST_dataloaders()

Load optimized pre-trained EIANN model#

# Name of the network configuration; it is used below to build the .yaml config
# path, the pickle path and the HDF5 plot-data path.
network_name = "20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized"   
# Seed passed to the network builder (network_seed=...).
network_seed = 66049
# Seed applied to the data generator before training.
data_seed = 257

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object from the .yaml configuration file for this network name
config_file_path = f"{root_dir}/EIANN/network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network
# Seed the data generator with data_seed before training starts.
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader, 
              epochs = 1,
              samples_per_epoch = 20_000, 
              val_interval = (0, -1, 100), 
              store_history = True,
              store_history_interval = (0, -1, 100), 
              store_dynamics = False, 
              store_params = True,
              status_bar = True)

# Optional: Save network object to pickle file
# NOTE(review): the pre-trained pickle loaded later in this notebook lives under
# "EIANN/data/saved_network_pickles/mnist/"; confirm whether this save path is
# intentionally different.
saved_network_path = f"{root_dir}/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
# Save the network that was just trained (the original passed an undefined name, `spiral_net`).
ut.save_network(network, path=saved_network_path)

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

# Path of the pre-trained network pickle for this (network_seed, data_seed) pair.
saved_network_path = f"{root_dir}/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
# NOTE: the saved network is a pickle (.pkl) file; unpickling can execute arbitrary
# code, so only load network files that come from a trusted source.
network = ut.load_network(saved_network_path)
# Tag the loaded object with its name and combined seed string
# (the export log output refers to "seed 66049_257 of network <name>").
network.name = network_name
network.seed = f"{network_seed}_{data_seed}"
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized_66049_257.pkl'

3. Training results#

# Plot the training history stored on the network object: loss (y-axis limited
# to (-0.01, 0.2)), accuracy and error. Each call produces one figure.
eiann.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot_accuracy_history(network)
eiann.plot_error_history(network)
../_images/8ecf8169d7ef35214cff203eee98e77a14e45330930648ed39eab724f2d3b817.png ../_images/c100c2ad2e9aa8f2ba97fdbe991dcd6e9ac6c5e92ce0634d288e4b354d8b902b.png ../_images/c1641f30caca0f35e911718e63cbbacf00f1c609cab0eac573136f56a2eaf533.png
# Plot the loss landscape of the trained network, evaluated on the test set.
eiann.plot_loss_landscape(
    test_dataloader,
    network,
    num_points=30,
    extension=0.8,
    vmax_scale=1.2,
    scale='log',
)
../_images/ba6873b330fb3c96ac92acf91818e79934dfecb5d754a69135fc7fa9f3e331c3.png

4. Analyze population activities#

# Report and plot the accuracy on the test set for the 'E' population
# (prints "Batch accuracy = ...%" and draws the accompanying figures).
eiann.plot_batch_accuracy(network, test_dataloader, population='E')
Batch accuracy = 73.91000366210938%
../_images/06d90e04cf4df267994dae3e3f00f6271b9e0a9fa4a67dc6988f4577af3e7368.png ../_images/f18d032b9e403077922a5b270b2911ef31181cef4b93c75471d8515407e38fdb.png ../_images/24ecd57bd2cd0795726eb5589228650fe075210b9351a85d85f82a4d989264a7.png
# Evaluate the network dynamics by presenting the test dataset and recording the neuron activity.
pop_dynamics_dict = ut.compute_test_activity_dynamics(network, test_dataloader, plot=True, normalize=True)
# Since the dynamics are stored here, each population's activity has shape
# (timesteps, data samples, neurons).
print(pop_dynamics_dict['H1E'].shape)
torch.Size([15, 10000, 500])
../_images/549c936f4940b89cf86bd35427658ed2f145ed32fd6702ff6b2df139c86be489.png

5. Analyze learned representations#

# Record per-sample (not class-averaged) test activity, sorted, for every population.
(
    pop_activity_dict,
    pattern_labels,
    unit_labels_dict,
) = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=False,
    sort=True,
)

# Build and plot the pattern-wise and neuron-wise similarity matrices for the 'E' populations.
(
    pattern_similarity_matrix_dict,
    neuron_similarity_matrix_dict,
) = ut.compute_representational_similarity_matrix(
    pop_activity_dict,
    pattern_labels,
    unit_labels_dict,
    population='E',
    plot=True,
)
../_images/71d42cc8006abb671a3e4e04350f7f15550a261c611504e481cec6d3325bb422.png ../_images/025a5c9c22718f1317c95c2149f4858a2fa93192b1c639f0aee1c8a9056e5c64.png ../_images/dc47d745593c30abf856fea18f6470ac080506d8a93b1f235cc8fe8f5f6b2e7d.png

5.3 Output layer#

# Max-activation receptive fields for the output-layer E population. With
# export=True the result is read from the HDF5 plot-data file when the key is
# already present there, and computed and saved to that file otherwise.
receptive_fields_output = ut.compute_maxact_receptive_fields(
    network.Output.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=root_dir+f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5",
)
# Show the output receptive fields without sorting.
eiann.plot_receptive_fields(
    receptive_fields_output,
    sort=False,
)
Loading maxact_receptive_fields_OutputE from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.h5
../_images/f127e7a9e9c2280bc5c0ed336894440d138245bc02c8e2298cb41764f275b344.png

5.2 Hidden Layer 2#

# Max-activation receptive fields for the H2 E population. With export=True the
# result is read from the HDF5 plot-data file when the key is already present
# there, and optimized and saved to that file otherwise.
receptive_fields_H2 = ut.compute_maxact_receptive_fields(
    network.H2.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=root_dir+f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5",
)
# Show the H2 receptive fields, sorted.
eiann.plot_receptive_fields(
    receptive_fields_H2,
    sort=True,
)
Loading maxact_receptive_fields_H2E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.h5
Data key maxact_receptive_fields_H2E not found in seed 66049_257 of network 20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized in file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.h5
Optimizing receptive field images H2E...
maxact_receptive_fields_H2E saved to file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.h5
../_images/cc2cface7c27015c6d03c5149545e8fcf45c404e45342370510cc6dff4f7a6cc.png
# Key of the population analysed in this cell.
population = 'H2E'

# Class-averaged, sorted test activity for every population.
(
    average_pop_activity_dict,
    pattern_labels,
    unit_labels_dict,
) = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)

# Plot receptive-field similarity for H2E using its class-averaged activity and unit labels.
eiann.plot_receptive_field_similarity(
    receptive_fields_H2,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/80212cd6663fe07af41726d3920e3066fe695558fb375591334ded78b083fbc1.png
# Representation metrics for Hidden Layer 2: the population must be H2.E so that it
# matches receptive_fields_H2 (the original passed network.H1.E, a copy-paste slip
# from the Hidden Layer 1 cell).
metrics_dict = ut.compute_representation_metrics(population=network.H2.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H2, plot=True)
../_images/19c3f7ae125ee39b59c743204db24ae1bec52f2cb35b91518e05e2fefea89f39.png

5.1 Hidden Layer 1#

# Max-activation receptive fields for the H1 E population. With export=True the
# result is read from the HDF5 plot-data file when the key is already present
# there, and optimized and saved to that file otherwise.
receptive_fields_H1 = ut.compute_maxact_receptive_fields(
    network.H1.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=root_dir+f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5",
)
# Show the H1 receptive fields, sorted.
eiann.plot_receptive_fields(
    receptive_fields_H1,
    sort=True,
)
Loading maxact_receptive_fields_H1E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.h5
Data key maxact_receptive_fields_H1E not found in seed 66049_257 of network 20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized in file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.h5
Optimizing receptive field images H1E...
maxact_receptive_fields_H1E saved to file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.h5
../_images/e3e45224efc3d952f34dbbcc411cdbfaeab90e5ed1ec9db04be6b18db2478f75.png
# Representation metrics for the H1 E population, computed on the test set with
# its receptive fields; plot=True also draws the summary figure.
metrics_dict = ut.compute_representation_metrics(
    population=network.H1.E,
    dataloader=test_dataloader,
    receptive_fields=receptive_fields_H1,
    plot=True,
)
../_images/e3853c164f4a70ee60ec5ba123fbf9dfc421c4117533350e52d8365470a24995.png
# Key of the population analysed in this cell.
population = 'H1E'

# Class-averaged, sorted test activity for every population.
(
    average_pop_activity_dict,
    pattern_labels,
    unit_labels_dict,
) = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)

# Plot receptive-field similarity for H1E using its class-averaged activity and unit labels.
eiann.plot_receptive_field_similarity(
    receptive_fields_H1,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/1a5d6fc58c1bd8aba33e6371c490c79d16d26c5718b7a9591e8031d093ac8b62.png