EI network - BCM

Here we show the analysis for a single example seed of the network described in Galloni et al. (2025).

import EIANN as eiann
from EIANN import utils as ut
# Apply EIANN's default plotting settings for all figures in this notebook.
eiann.update_plot_defaults()
# Repository root; every config, pickle, and HDF5 path below is built from it.
root_dir = ut.get_project_root()

1. Load MNIST data

# Train/validation/test dataloaders, plus `data_generator`, which is re-seeded with
# `manual_seed(data_seed)` before training (presumably it drives dataloader
# shuffling so the sample order is reproducible — confirm in get_MNIST_dataloaders).
train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()

2. Load optimized pre-trained EIANN model

# Identifiers for the pre-trained model. network_name is the stem of the YAML
# config under EIANN/network_config/mnist/, and both seeds are appended to the
# pickle filename as "{network_name}_{network_seed}_{data_seed}.pkl".
network_name = "20240723_EIANN_2_hidden_mnist_Supervised_BCM_config_4_complete_optimized" # Supervised_BCM_WT_hebbdend
network_seed = 66049  # passed to build_EIANN_from_config(network_seed=...)
data_seed = 257  # passed to data_generator.manual_seed(...) before training

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object
# Build a freshly initialized network from its YAML configuration. network_seed
# makes the construction reproducible (presumably it seeds weight initialization — confirm).
config_file_path = f"{root_dir}/EIANN/network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network
# Seed the data generator first so this run uses the same data seed that is
# encoded in the saved pickle filename.
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader, 
              epochs = 1,
              samples_per_epoch = 20_000, 
              val_interval = (0, -1, 100),  # NOTE(review): presumably (start, stop, step) in training steps — confirm
              store_history = True,  # history is presumably what the plot_*_history calls below read
              store_history_interval = (0, -1, 100),  # same (start, stop, step) convention as val_interval (assumed)
              store_dynamics = False, 
              store_params = True,
              status_bar = True)

# Optional: Save network object to pickle file
# Saves the network trained in the cell above (the object bound to `network`).
# NOTE: this writes to EIANN/data/mnist/, whereas the load cell below reads from
# EIANN/data/saved_network_pickles/mnist/. Change one of the two paths if you
# want the load step to pick up a network you retrained here.
saved_network_path = f"{root_dir}/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

# Load the pre-trained network from the repository's pickle store.
# NOTE: unpickling executes code; only load pickles from trusted sources.
saved_network_path = f"{root_dir}/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
# Re-attach this notebook's identifiers to the loaded object
# (presumably consumed by downstream plotting/export helpers — confirm).
network.name = network_name
network.seed = f"{network_seed}_{data_seed}"
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20240723_EIANN_2_hidden_mnist_Supervised_BCM_config_4_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20240723_EIANN_2_hidden_mnist_Supervised_BCM_config_4_complete_optimized_66049_257.pkl'

3. Training results

# Plot the loss, accuracy, and error curves stored on the network during training.
eiann.plot_loss_history(network, ylim=(0.05, 0.11))  # fixed y-range for the loss plot only
eiann.plot_accuracy_history(network)
eiann.plot_error_history(network)
../_images/efc9cd27bdde2e7d35248781004b667aa58cdd032ab1520dc0a80cbebac4cccb.png ../_images/88b3cc9b3587439aeec509e148ff8ca6f14b2883dc8f1a6abb67207d8e8612ec.png ../_images/f5364ea5077a291eae60fe0e65bec0d9fc8a04ec7a7fe6bfe312f8bb6ab0ebcf.png
# Loss landscape evaluated on the test set. num_points sets the grid resolution
# (assumed: extension is the span sampled around the trained parameters and
# vmax_scale scales the colormap maximum — confirm in plot_loss_landscape).
eiann.plot_loss_landscape(test_dataloader, network, num_points=30, extension=0.8, vmax_scale=1.1, scale='linear')
../_images/43e6a48f3f3cb2a12de5e505fadc97e9cdc71bb0418cff86ecf88d1cd82b595b.png

4. Analyze population activities

# Print the test accuracy ("Batch accuracy = ...") and plot the excitatory ('E') populations' responses.
eiann.plot_batch_accuracy(network, test_dataloader, population='E')
Batch accuracy = 79.69000244140625%
../_images/a1374b9794c7e6c1e947ea853b71d8218dcbf5b7a7a2fdd79f65a1a3e985d8a4.png ../_images/d7ec602427f2afbb65fd9a724970d1d5f19352ada22c16445455f6e94942dbba.png ../_images/a7e530cd180fe06ef7bcc5f87c9a48a72bcd7a3fa91ff36cb618b988cef6cbd9.png
# Returns a dict keyed by population name (e.g. 'H1E'); normalize=True presumably
# rescales activities before plotting — confirm in compute_test_activity_dynamics.
pop_dynamics_dict = ut.compute_test_activity_dynamics(network, test_dataloader, plot=True, normalize=True) # We will evaluate the network dynamics by presenting the test dataset and recording the neuron activity.
print(pop_dynamics_dict['H1E'].shape) # Since we are storing the dynamics here, each population should have activity of shape (timesteps, data samples, neurons)
torch.Size([15, 10000, 500])
../_images/8110e50bed4bc7ccbdc9667879696585cfd0561438cb258ee006c773eb52ff1a.png

5. Analyze learned representations

# Per-sample (class_average=False) test activity for every population, with
# patterns/units sorted (presumably by label / preferred class — confirm).
pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=False, sort=True)
# Pattern-by-pattern and neuron-by-neuron similarity matrices for the E populations.
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = ut.compute_representational_similarity_matrix(pop_activity_dict, pattern_labels, unit_labels_dict, population='E', plot=True)
../_images/1bb15244c00dab6eb6c8c27efffcf591162b1cb8f554e64aabd60efc5dcfe567.png ../_images/88a4d6bb4587de8d7c641c6675c53de482aa372cade6b10de180542d3baf83be.png ../_images/49b4e4902ed26a517166974771e5d91379e2613e46fc13a24179898783f6da9b.png

5.3 Output layer

# Max-activation receptive fields of the output E units. With export=True the
# result is tied to this network's HDF5 plot-data file; when it is already there,
# it is loaded instead of recomputed (see the "Loading ... from file" log line).
receptive_fields_output = ut.compute_maxact_receptive_fields(network.Output.E, test_dataloader=test_dataloader, export=True, export_path=root_dir+f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5")
eiann.plot_receptive_fields(receptive_fields_output, sort=False)  # keep output units in their original order
Loading maxact_receptive_fields_OutputE from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20240723_EIANN_2_hidden_mnist_Supervised_BCM_config_4_complete_optimized.h5
../_images/33f59b5ce00a163acccaa13ae9491cf74760de240fd423583955a4a7829404e3.png

5.2 Hidden Layer 2

# Max-activation receptive fields of hidden layer 2's E units, using the same HDF5
# export/cache file as the output layer (loaded from file when already present).
receptive_fields_H2 = ut.compute_maxact_receptive_fields(network.H2.E, test_dataloader=test_dataloader, export=True, export_path=root_dir+f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5")
eiann.plot_receptive_fields(receptive_fields_H2, sort=True)
Loading maxact_receptive_fields_H2E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20240723_EIANN_2_hidden_mnist_Supervised_BCM_config_4_complete_optimized.h5
../_images/7a5271fc815709fc40bb1e95e1daa462125f4f0c04eac9658a1ec146484f48cc.png
population = 'H2E'  # key into the activity/label dicts returned by compute_test_activity
# Class-averaged test activity (one row per class, assumed), sorted like the plots above.
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=True, sort=True)
# Relate the similarity of H2E receptive fields to the units' class-averaged activity.
eiann.plot_receptive_field_similarity(receptive_fields_H2, average_pop_activity_dict[population], unit_labels_dict[population])
../_images/799fb10df7c4d47ef8805db821c99ecba6679ea10b9568c8f8d84b8d6cfb45d2.png
# Representation metrics for hidden layer 2. The population and the receptive
# fields passed in must come from the same layer (H2 here).
metrics_dict = ut.compute_representation_metrics(population=network.H2.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H2, plot=True)
../_images/a484fe5c51658ffc2e966cd53c20e44ca957f941553f30eaecb1aeea3f418162.png

5.1 Hidden Layer 1

# Max-activation receptive fields of hidden layer 1's E units, using the same HDF5
# export/cache file as the other layers (loaded from file when already present).
receptive_fields_H1 = ut.compute_maxact_receptive_fields(network.H1.E, test_dataloader=test_dataloader, export=True, export_path=root_dir+f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5")
# Use the same top-level plotting entry point as the output and H2 cells.
eiann.plot_receptive_fields(receptive_fields_H1, sort=True)
Loading maxact_receptive_fields_H1E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20240723_EIANN_2_hidden_mnist_Supervised_BCM_config_4_complete_optimized.h5
../_images/8e0100b431819b07ced829ca9127423f3875c8efcb90a1fcbb6d19d6892bfac8.png
# Representation metrics for hidden layer 1; the population and receptive fields both come from H1.
metrics_dict = ut.compute_representation_metrics(population=network.H1.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H1, plot=True)
../_images/2c4f4d076bf78614fb1172f24ed5f888b846358dd5ed8df1d2bea20c9018a1f0.png
population = 'H1E'  # key into the activity/label dicts returned by compute_test_activity
# Same call as in the Hidden Layer 2 section; repeated so this cell does not
# depend on that one having been run.
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=True, sort=True)
# Relate the similarity of H1E receptive fields to the units' class-averaged activity.
eiann.plot_receptive_field_similarity(receptive_fields_H1, average_pop_activity_dict[population], unit_labels_dict[population])
../_images/b44a682da2d24930e584d8ed59a551defbdecbdef0f77c422b592e9a14de3445.png