EI network - Dend Target Prop (BTSP)#

  • BTSP synaptic weight update

  • Top-down weight symmetry

Here we show the analysis of a single example seed of the network described in Galloni et al. (2025).

import EIANN as eiann
from EIANN import utils as ut
# Update EIANN's default plot settings before any figures are drawn
# (NOTE(review): presumably matplotlib style defaults — see eiann.update_plot_defaults).
eiann.update_plot_defaults()
# Repository root; every config, pickle and HDF5 path below is built from it.
root_dir = ut.get_project_root()

1. Load MNIST data#

# Train/validation/test dataloaders plus a data generator that is re-seeded with
# data_seed before training (NOTE(review): presumably the generator controlling shuffling).
train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()

2. Load optimized pre-trained EIANN model#

# Base name shared by the .yaml config, the saved .pkl and the HDF5 plot-data file.
network_name = "20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized"   
# Passed to build_EIANN_from_config (presumably seeds network initialization).
network_seed = 66049
# Applied to data_generator.manual_seed before training.
data_seed = 257

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object
# The .yaml config lives under EIANN/network_config/mnist/ and is named after network_name.
config_file_path = f"{root_dir}/EIANN/network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network
# Re-seed the data generator with data_seed before training
# (NOTE(review): presumably makes the training sample order reproducible).
data_generator.manual_seed(data_seed)
network.train(
    train_dataloader,
    val_dataloader,
    epochs=1,
    samples_per_epoch=20_000,
    # NOTE(review): the interval tuples look like (start, stop, step) — confirm in network.train.
    val_interval=(0, -1, 100),
    store_history=True,
    store_history_interval=(0, -1, 100),
    store_dynamics=False,
    store_params=True,
    status_bar=True,
)

# Optional: Save network object to pickle file
# NOTE(review): this directory (EIANN/data/mnist/) differs from the one the load cell
# reads from (EIANN/data/saved_network_pickles/mnist/); confirm which is intended.
saved_network_path = f"{root_dir}/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
# Save the `network` object that was built and trained in the cells above.
ut.save_network(network, path=saved_network_path)

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

# Pre-trained pickle named <network_name>_<network_seed>_<data_seed>.pkl.
saved_network_path = f"{root_dir}/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
# Tag the loaded object with its name and combined seed string (e.g. "66049_257");
# these appear to be the identifiers used when reading/writing the HDF5 plot data below.
network.name = network_name
network.seed = f"{network_seed}_{data_seed}"
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized_66049_257.pkl'

3. Training results#

# Training curves (presumably from the history stored during training).
eiann.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot_accuracy_history(network)
eiann.plot_error_history(network)
../_images/3e56e46eb77af32c4aed36a4aeae0b2f2d252364d549047165b377742b06190e.png ../_images/7e4e19381997ace5236c71e4b70ab1a132097b0fb19a41784b026df7165a1aa1.png ../_images/8bda41c91a3d346bb8fe74cb6b2459e4ba89e15c02b553de04e59b2de84628b3.png
# Loss landscape of `network` evaluated on the test set, drawn with a log scale
# (NOTE(review): num_points / extension / vmax_scale semantics — see plot_loss_landscape).
eiann.plot_loss_landscape(test_dataloader, network, num_points=20, extension=1, vmax_scale=1.2, scale='log')
../_images/2396f320ab52346adda32d9d2d2b98ee00ed614c6e8b37222a2654407924593a.png

4. Analyze population activities#

# Test-set accuracy for the E populations: prints the batch accuracy and draws plots.
eiann.plot_batch_accuracy(network, test_dataloader, population='E')
Batch accuracy = 93.45999908447266%
../_images/af6a4a812f70d7c9e40eac09db361046b285f5847d678fd1a13d50120725dc78.png ../_images/228d81af0fbf8d21fcde70b694c1dba60abc6e53c066c72548cbc05dc04ab3ba.png ../_images/4ddac754108325cafaad0816c59a594d6d29fbca959c14521433926617e87b50.png
# Evaluate the network dynamics by presenting the test dataset and recording
# each population's activity (plot=True also draws the dynamics).
pop_dynamics_dict = ut.compute_test_activity_dynamics(network, test_dataloader, plot=True, normalize=True)
# Since we store the dynamics here, each population's activity has shape
# (timesteps, data samples, neurons).
print(pop_dynamics_dict['H1E'].shape)
torch.Size([15, 10000, 500])
../_images/03c12735bffb5783de5d4b2c388b53231390caa2d7716dcac57e6dd27c8a69f2.png

5. Analyze learned representations#

# Per-sample test activity (no class averaging), sorted; dicts are keyed by
# population name such as 'H1E'.
pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network, test_dataloader, class_average=False, sort=True
)
# Pattern- and neuron-similarity matrices for the E populations, with plots.
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = (
    ut.compute_representational_similarity_matrix(
        pop_activity_dict,
        pattern_labels,
        unit_labels_dict,
        population='E',
        plot=True,
    )
)
../_images/c390b522e8f4b4c2a1d76c8dc0e4eb421d857ee34b012294586285df3269dff8.png ../_images/dd919f615f7ad4dd0d1c0782650bf921795e8fe2c35dd85ba6a1736970bd5904.png ../_images/5e1d24e943afb10bc2af82470c57c70f09c3a8f32590f32b14d233a9a987e409.png

5.3 Output layer#

# Max-activation receptive fields of the Output E units. With export=True they are
# looked up in / saved to the network's HDF5 plot-data file (see the log below).
receptive_fields_output = ut.compute_maxact_receptive_fields(
    network.Output.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=f"{root_dir}/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5",
)
eiann.plot_receptive_fields(receptive_fields_output, sort=False)
Loading maxact_receptive_fields_OutputE from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
Data key maxact_receptive_fields_OutputE not found in seed 66049_257 of network 20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized in file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
Optimizing receptive field images OutputE...
maxact_receptive_fields_OutputE saved to file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
../_images/bd4bd242c38bb178a7c076b9c063103f8f30360e6bad5491a38c84d53a290c11.png

5.2 Hidden Layer 2#

# Max-activation receptive fields of the H2 E units, cached in the HDF5 plot-data file.
receptive_fields_H2 = ut.compute_maxact_receptive_fields(
    network.H2.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=f"{root_dir}/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5",
)
eiann.plot_receptive_fields(receptive_fields_H2, sort=True)
Loading maxact_receptive_fields_H2E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
Data key maxact_receptive_fields_H2E not found in seed 66049_257 of network 20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized in file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
Optimizing receptive field images H2E...
maxact_receptive_fields_H2E saved to file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
../_images/91aa6e9dfe7bc9a780d60b73f255bea4362e9e06de033d50434cf832f298a031.png
population = 'H2E'
# Class-averaged, sorted test activity; indexed below by population name.
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)
# Compare H2 receptive fields with the class-averaged H2E activity.
eiann.plot_receptive_field_similarity(
    receptive_fields_H2,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/1d15845fb7bb2c7f3d60adffd41bdd35f654edb5eb2ef43fed328b9b3876c51a.png
# Representation metrics for the H2 E population. Pair it with its own receptive
# fields (receptive_fields_H2 was computed from network.H2.E above).
metrics_dict = ut.compute_representation_metrics(
    population=network.H2.E,
    dataloader=test_dataloader,
    receptive_fields=receptive_fields_H2,
    plot=True,
)
../_images/61c703f34cb25f6c081f4f48c88f5a8740d0e70b14b952da6238333bba24ba46.png

5.1 Hidden Layer 1#

# Max-activation receptive fields of the H1 E units, cached in the HDF5 plot-data file.
receptive_fields_H1 = ut.compute_maxact_receptive_fields(
    network.H1.E,
    test_dataloader=test_dataloader,
    export=True,
    export_path=f"{root_dir}/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5",
)
eiann.plot_receptive_fields(receptive_fields_H1, sort=True)
Loading maxact_receptive_fields_H1E from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
Data key maxact_receptive_fields_H1E not found in seed 66049_257 of network 20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized in file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
Optimizing receptive field images H1E...
maxact_receptive_fields_H1E saved to file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/model_hdf5_plot_data/plot_data_20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.h5
../_images/18f502a89c0c24938be79e8106322d3ac0f50ef0e1330280d15341d31fd08bca.png
# Representation metrics for the H1 E population, paired with its receptive fields.
metrics_dict = ut.compute_representation_metrics(
    population=network.H1.E,
    dataloader=test_dataloader,
    receptive_fields=receptive_fields_H1,
    plot=True,
)
../_images/28f36ad8a67300baca4cc137bdc5dca5ff20ffb1c86e8bf1d64a39d7590c0b53.png
population = 'H1E'
# Class-averaged, sorted test activity; indexed below by population name.
average_pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(
    network,
    test_dataloader,
    class_average=True,
    sort=True,
)
# Compare H1 receptive fields with the class-averaged H1E activity.
eiann.plot_receptive_field_similarity(
    receptive_fields_H1,
    average_pop_activity_dict[population],
    unit_labels_dict[population],
)
../_images/9e9b45ba8c34aa651c48661b861e23b26be4a50a48b214bae70ec3cd44715dbb.png