EI network - Dendritic Target Propagation (BTSP)#

  • BTSP synaptic weight update

  • Top-down weight symmetry

Here we will show the analysis for a single example seed of a network described in Galloni et al. 2025.

import EIANN.EIANN as eiann
from EIANN import utils as ut
# Apply the EIANN package's default plot settings for all figures in this notebook.
eiann.plot.update_plot_defaults()
# Project root directory; the data/plot-data file paths used later are built relative to it.
root_dir = ut.get_project_root()
# IPython magics (notebook-only, not plain Python): automatically reload edited
# modules before each cell runs, so changes to the EIANN source take effect without restarting.
%load_ext autoreload
%autoreload 2

Hide code cell output

1. Load MNIST data#

# Build the MNIST train / validation / test dataloaders. The fourth return value is
# the data generator that is re-seeded (via `manual_seed`) before training below.
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    data_generator,
) = ut.get_MNIST_dataloaders()

2. Load optimized pre-trained EIANN model:#

EI network trained with Dendritic Target Propagation (BTSP)#

# Identifiers of the pre-trained model analysed in this notebook. Together they
# select the .yaml config file and the saved pickle file used below.
network_name = "20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized"
network_seed = 66049  # seed passed when building the network
data_seed = 257       # seed applied to the data generator before training

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create the network object directly from its .yaml configuration file.
config_file_path = f"../network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train the network. The data generator is seeded first so that the run is
# repeatable for the chosen data_seed (presumably it controls sample order).
data_generator.manual_seed(data_seed)
network.train(
    train_dataloader,
    val_dataloader,
    epochs=1,
    samples_per_epoch=20_000,
    val_interval=(0, -1, 100),
    store_history=True,
    store_history_interval=(0, -1, 100),
    store_dynamics=False,
    store_params=True,
    status_bar=True,
)

# Optional: save the network trained above to a pickle file. The path is the
# same one the loading cell reads from, so a freshly trained network can be
# reloaded there without editing paths.
saved_network_path = root_dir + f"/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

# Load the already-trained network from its pickle file.
# Only unpickle files you trust: unpickling can execute arbitrary code.
pickle_file_name = f"{network_name}_{network_seed}_{data_seed}.pkl"
saved_network_path = root_dir + f"/EIANN/data/saved_network_pickles/mnist/{pickle_file_name}"
network = ut.load_network(saved_network_path)

# Attach the model identifiers to the loaded object ("<network_seed>_<data_seed>").
network.name = network_name
network.seed = "_".join(str(seed) for seed in (network_seed, data_seed))
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized_66049_257.pkl'

3. Training results#

# Loss history, with an explicit y-axis range.
eiann.plot.plot_loss_history(network, ylim=(-0.01, 0.2))
# Accuracy and error histories use the plotting defaults.
for plot_history in (eiann.plot.plot_accuracy_history, eiann.plot.plot_error_history):
    plot_history(network)
../_images/384c402eb65aa52d6cfd849fcfdb6e7cc198872b78fafca83c95c0a7239b0606.png ../_images/fc3e0e408850f1777658ba01cb8f460ac9e3c995bc93f393d3a1373ff76b0306.png ../_images/d5adea7b8ddf8a93c40b828e502c7e99d64b53a24826af38a382aebc9d0343bb.png
# Plot the loss landscape of the trained network, evaluated on the test set.
eiann.plot.plot_loss_landscape(
    test_dataloader,
    network,
    num_points=20,
    extension=1,
    vmax_scale=1.2,
    scale='log',
)
../_images/3e4a7cb772348b5c58972abecafe1a42f89e6199f047185220e7b44ead5cddf8.png

4. Analyze population activities#

# Test-set accuracy and the associated plots for the excitatory ('E') populations.
eiann.plot.plot_batch_accuracy(
    network,
    test_dataloader,
    population='E',
)
Batch accuracy = 93.45999908447266%
../_images/79f05e23273995c02a3bb459e1d0f10a97d24cdb63194f20be30d8441bd14d12.png ../_images/86164c7f8e43b3c2c32b5cf53c653e4b6b9868781a5f09a74d8f0d54eea26f6c.png ../_images/490a37d1736d7e6325ad597d08d4d77c6a3fffb179a71d9df477e0e0f385e7a5.png
# Present the test dataset and record each population's activity over time
# (normalized, with plots) to evaluate the network dynamics.
pop_dynamics_dict = ut.compute_test_activity_dynamics(
    network,
    test_dataloader,
    plot=True,
    normalize=True,
)
# Because the dynamics are stored here, each population's activity has shape
# (timesteps, data samples, neurons).
print(pop_dynamics_dict['H1E'].shape)
torch.Size([15, 10000, 500])
../_images/5272138b83e516847da60bd9d595a1c9bd8dee475ec8ed4064532ec867f04e18.png

5. Analyze learned representations#

# Per-sample (not class-averaged), sorted test-set activity for every population.
pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=False,
    sort=True,
)
# Pattern-by-pattern and neuron-by-neuron similarity matrices for the 'E' populations.
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = (
    ut.compute_representational_similarity_matrix(
        pop_activity_dict,
        pattern_labels,
        unit_labels_dict,
        population='E',
        plot=True,
    )
)
../_images/6e1e5c1434e46223a97a63ad67e8b5a3c71abb7e8997a9b2509f838a238a2ef8.png ../_images/0217e6e2693ec1adacf97c4ac9f9b9ef888ee93b7b7df447daa25c90cda7fb09.png ../_images/59cab1b48b1b2be7713fa7db838a2a6f738027ac807fc20adadb07e607fb5a22.png
# Within- vs between-class similarity of patterns and of units for each 'E' population.
(
    within_class_pattern_similarity_dict,
    between_class_pattern_similarity_dict,
    within_class_unit_similarity_dict,
    between_class_unit_similarity_dict,
) = ut.compute_within_class_representational_similarity(
    network,
    test_dataloader,
    population='E',
    plot=True,
)
Computing representational similarity for population: OutputE
Computing representational similarity for population: H2E
Computing representational similarity for population: H1E
../_images/33ec307d28fb0cd3cda76f23d600a77a96793c3e0c4db1e46599eee0f5e9f699.png

5.3 Output layer#

# HDF5 plot-data file for this model. With export=True, the result appears to be
# cached there; the cell output shows it being loaded from this file.
plot_data_path = root_dir + f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5"

# Maximally-activating receptive fields of the output excitatory population (left unsorted).
receptive_fields_output = ut.compute_maxact_receptive_fields(
    network.Output.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=plot_data_path,
)
eiann.plot.plot_receptive_fields(receptive_fields_output, sort=False)
Loading maxact_receptive_fields_OutputE from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
../_images/44f791ce576226c943bbb6fbc960330ba2fcf1214a59890062a8ccdf2a04d7c0.png

5.2 Hidden Layer 2#

# HDF5 plot-data file for this model (results cached/loaded when export=True).
plot_data_path = root_dir + f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5"

# Maximally-activating receptive fields of the H2 excitatory population, plotted sorted.
receptive_fields_H2 = ut.compute_maxact_receptive_fields(
    network.H2.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=plot_data_path,
)
eiann.plot.plot_receptive_fields(receptive_fields_H2, sort=True)
Loading maxact_receptive_fields_H2E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
../_images/dd7c1b38cfe2c8c0af75028934115e53f5de1fa6c5100420c6e34e9f0687f57a.png
# Compare H2E receptive fields with the class-averaged H2E test activity.
population = 'H2E'
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)
eiann.plot_receptive_field_similarity(
    receptive_fields_H2,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/529331d573ea865198b5dea01f9407153815e5a270a99af3b6f848e15c93ee87.png
# Representation metrics for the H2 excitatory population, computed with the H2
# receptive fields from above. The population must match the receptive fields
# (the H1 section below pairs network.H1.E with receptive_fields_H1).
metrics_dict = ut.compute_representation_metrics(
    population=network.H2.E,
    dataloader=test_dataloader,
    receptive_fields=receptive_fields_H2,
    plot=True,
)
../_images/ba5578dfab7edd2b6090d8d5be3a2b472608f4a9eec45038984e368f58141e49.png

5.1 Hidden Layer 1#

# HDF5 plot-data file for this model (results cached/loaded when export=True).
plot_data_path = root_dir + f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5"

# Maximally-activating receptive fields of the H1 excitatory population, plotted sorted.
receptive_fields_H1 = ut.compute_maxact_receptive_fields(
    network.H1.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=plot_data_path,
)
eiann.plot.plot_receptive_fields(receptive_fields_H1, sort=True)
Loading maxact_receptive_fields_H1E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
../_images/fd463432315a580905e59e4939ea3c7fca1f9c9bd65604a8a136265b75c512b2.png
# Representation metrics for the H1 excitatory population, using its receptive fields.
metrics_dict = ut.compute_representation_metrics(
    population=network.H1.E,
    dataloader=test_dataloader,
    receptive_fields=receptive_fields_H1,
    plot=True,
)
../_images/34d3c68e5bf2a826245a8b149b518a15a83c7c9ea910dad5dac4fbc7535c28cc.png
# Compare H1E receptive fields with the class-averaged H1E test activity.
population = 'H1E'
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)
eiann.plot_receptive_field_similarity(
    receptive_fields_H1,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/4e96ecee2cf48862808a240152ae7e6f4133fbb799a61d6349f8b3520c46578e.png