Feedforward ANN (Backprop)

Here we show the analysis for a single example seed of one of the networks described in Galloni et al. (2025), a feedforward ANN trained with backpropagation.

import EIANN as eiann
from EIANN import utils as ut
eiann.update_plot_defaults()
root_dir = ut.get_project_root()

Load MNIST data

train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()

Load optimized pre-trained model: Feedforward ANN trained with Backprop

network_name = "20231129_EIANN_2_hidden_mnist_van_bp_relu_SGD_config_G_complete_optimized"
network_seed = 66049
data_seed = 257

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object
config_file_path = f"{root_dir}/EIANN/network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader, 
              epochs = 1,
              samples_per_epoch = 20_000, 
              val_interval = (0, -1, 100), 
              store_history = True,
              store_history_interval = (0, -1, 100), 
              store_dynamics = False, 
              store_params = True,
              status_bar = True)

# Optional: Save network object to pickle file
saved_network_path = f"{root_dir}/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)
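The notebook passes `network_seed` to `ut.build_EIANN_from_config` and `data_seed` to `data_generator.manual_seed`. A reasonable reading, though the internals are not shown here, is that the first seed fixes the initial weights and the second fixes the order in which training samples are drawn. Under that reading, each seed can be varied independently. The sketch below illustrates the separation with two independent `random.Random` generators. `make_run` is a hypothetical stand-in, not part of EIANN.

```python
import random

def make_run(network_seed, data_seed, n_weights=4, n_samples=8):
    """Hypothetical stand-in for one training run (assumed seed roles, not EIANN code)."""
    # One generator controls weight initialization...
    init_rng = random.Random(network_seed)
    # ...and an independent one controls the order of training samples.
    data_rng = random.Random(data_seed)
    weights = [init_rng.gauss(0.0, 1.0) for _ in range(n_weights)]
    order = list(range(n_samples))
    data_rng.shuffle(order)
    return weights, order

run_a = make_run(66049, 257)
run_b = make_run(66049, 257)  # identical seeds: identical run
run_c = make_run(12345, 257)  # new network seed, same data seed

print("same weights:", run_a[0] == run_b[0], "| same order:", run_a[1] == run_b[1])
print("network seed changed -> same order:", run_c[1] == run_a[1])
```

Keeping the two generators separate means that comparing network seeds does not also change the data order, and vice versa.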

In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:

saved_network_path = f"{root_dir}/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20231129_EIANN_2_hidden_mnist_van_bp_relu_SGD_config_G_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20231129_EIANN_2_hidden_mnist_van_bp_relu_SGD_config_G_complete_optimized_66049_257.pkl'
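`ut.save_network` and `ut.load_network` are EIANN helpers whose internals are not shown here. The sketch below is a generic illustration of a pickle round trip using Python's standard `pickle` module. It saves and reloads a plain dictionary that stands in for a network's state; the dictionary contents are hypothetical.

Keep two general points about pickle in mind. First, pickling a custom class such as a network object also requires that the class be importable under the same module path when the file is loaded. Second, `pickle.load` should only be used on files from trusted sources, because unpickling can execute arbitrary code.

```python
import os
import pickle
import tempfile

def save_object(obj, path):
    # Serialize the whole object graph to a binary .pkl file.
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def load_object(path):
    # Only unpickle files you trust: loading can run arbitrary code.
    with open(path, "rb") as f:
        return pickle.load(f)

# Hypothetical stand-in for a trained network's state.
state = {
    "weights": [[0.1, -0.2], [0.3, 0.4]],
    "history": {"loss": [0.9, 0.5, 0.2]},
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example_network.pkl")
    save_object(state, path)
    restored = load_object(path)

print(restored == state)
```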

Training results

eiann.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot_accuracy_history(network)
eiann.plot_error_history(network)
eiann.plot_loss_landscape(test_dataloader, network, num_points=20, extension=0.9, vmax_scale=1.2, scale='log')
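A loss landscape of this kind is commonly drawn by evaluating the loss at parameter vectors `theta + a*d1 + b*d2` on a grid of `(a, b)` values, where `theta` are the trained parameters and `d1`, `d2` are two chosen directions. This document does not say which directions `eiann.plot_loss_landscape` uses (for example, random directions or principal components of the stored parameter history). It also does not say exactly how `num_points` and `extension` map onto the grid. The sketch below is therefore only an illustration of the grid evaluation, on a toy two-parameter quadratic loss with hypothetical helper names.

```python
def toy_loss(params):
    # Hypothetical loss with its minimum at (1, -2); stands in for the test loss.
    x, y = params
    return (x - 1.0) ** 2 + 3.0 * (y + 2.0) ** 2

def loss_landscape(loss_fn, center, d1, d2, num_points=5, extent=1.0):
    """Evaluate loss_fn(center + a*d1 + b*d2) on a num_points x num_points grid,
    with a and b evenly spaced in [-extent, extent]."""
    step = 2.0 * extent / (num_points - 1)
    coords = [-extent + i * step for i in range(num_points)]
    grid = []
    for b in coords:
        row = []
        for a in coords:
            point = [c + a * u + b * v for c, u, v in zip(center, d1, d2)]
            row.append(loss_fn(point))
        grid.append(row)
    return coords, grid

# Use the minimizer as the "trained" center and the two axes as directions.
coords, grid = loss_landscape(toy_loss, center=[1.0, -2.0], d1=[1.0, 0.0], d2=[0.0, 1.0])
for row in grid:
    print(" ".join(f"{value:5.2f}" for value in row))
```

Because the toy loss is minimized at `center`, the middle of the grid holds the smallest value. Plotting on a log color scale, as `scale='log'` in the call above suggests, helps show structure near such a minimum.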

Analyze population activities

eiann.plot_batch_accuracy(network, test_dataloader, population='all')
Batch accuracy = 96.11000061035156%
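For a classifier like this one, accuracy is usually computed by taking the most active output unit as the predicted class and counting how often it matches the label. The sketch below assumes that rule; it is not taken from the source of `eiann.plot_batch_accuracy`. The helper name and activity values are hypothetical. Separately, the long tail of digits in the printed `96.11000061035156%` is exactly what 96.11 becomes after rounding to 32-bit floating point, so the reported value corresponds to 96.11%.

```python
def batch_accuracy(outputs, labels):
    """Percentage of samples whose most active output unit matches the label
    (assumed decision rule: argmax over output activity)."""
    correct = 0
    for activity, label in zip(outputs, labels):
        predicted = max(range(len(activity)), key=lambda i: activity[i])
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)

# Hypothetical output-layer activity for four samples and three classes.
outputs = [
    [0.1, 0.8, 0.1],  # predicts class 1
    [0.7, 0.2, 0.1],  # predicts class 0
    [0.2, 0.3, 0.5],  # predicts class 2
    [0.6, 0.3, 0.1],  # predicts class 0
]
labels = [1, 0, 2, 2]
print(f"Batch accuracy = {batch_accuracy(outputs, labels)}%")  # 3 of 4 correct
```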

Analyze learned representations

pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=False, sort=True)
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = ut.compute_representational_similarity_matrix(pop_activity_dict, pattern_labels, unit_labels_dict, population='E', plot=True)
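Judging by the returned names, `ut.compute_representational_similarity_matrix` produces two kinds of matrices:

- pattern-by-pattern similarities, which measure how alike the population responses to two test images are;
- neuron-by-neuron similarities, which measure how alike two units' response profiles across images are.

The similarity measure it uses is not shown here. The sketch below assumes cosine similarity and computes both matrices from a small hypothetical activity table, with rows as patterns and columns as units.

```python
import math

def cosine(u, v):
    # Cosine similarity; defined as 0 when either vector is all zeros.
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0

def similarity_matrix(vectors):
    # Pairwise similarity between every pair of vectors.
    return [[cosine(u, v) for v in vectors] for u in vectors]

# Hypothetical activity: rows = test patterns, columns = units.
activity = [
    [1.0, 0.0, 0.5],
    [0.9, 0.1, 0.4],
    [0.0, 1.0, 0.0],
]

pattern_sim = similarity_matrix(activity)                           # pattern x pattern
neuron_sim = similarity_matrix([list(col) for col in zip(*activity)])  # unit x unit

for row in pattern_sim:
    print(" ".join(f"{value:.2f}" for value in row))
```

Patterns 0 and 1 drive the same units and come out highly similar. Pattern 2 activates a disjoint unit, so its similarity to pattern 0 is zero.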

Hidden Layer 1

receptive_fields_H1 = ut.compute_maxact_receptive_fields(network.H1.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_H1, sort=True)
Optimizing receptive field images H1E...
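The log message suggests that `ut.compute_maxact_receptive_fields` finds each unit's receptive field by optimizing an input image to maximally activate that unit. Its exact objective and constraints are not shown here. As a minimal sketch under stated assumptions, the code below runs gradient ascent on the input `x` of a single unit with linear drive `w . x`, renormalizing `x` to unit length after every step. Whenever the drive is positive, maximizing it also maximizes a ReLU unit's output. The weights and helper names are hypothetical.

```python
import math

def normalize(v):
    # Rescale a vector to unit Euclidean length.
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]

def max_activation_input(weights, steps=200, lr=0.1):
    """Gradient ascent on x to maximize the drive w . x, keeping x on the
    unit sphere (a simple stand-in for constraints on the input image)."""
    x = normalize([1.0] * len(weights))  # start from a uniform input
    for _ in range(steps):
        # The gradient of w . x with respect to x is simply w.
        x = normalize([xi + lr * wi for xi, wi in zip(x, weights)])
    return x

weights = [0.5, -1.0, 2.0]  # hypothetical input weights of one unit
x_star = max_activation_input(weights)
drive = sum(w * xi for w, xi in zip(weights, x_star))
relu_response = max(0.0, drive)
print([round(xi, 4) for xi in x_star], round(relu_response, 4))
```

For a unit whose drive is linear in the input, the ascent converges to `w / ||w||`. For units further from the input, the drive depends on the input through earlier layers. Their gradient has to be backpropagated through those layers, and the optimum no longer has a closed form.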
metrics_dict = ut.compute_representation_metrics(population=network.H1.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H1, plot=True)

Hidden Layer 2

receptive_fields_H2 = ut.compute_maxact_receptive_fields(network.H2.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_H2, sort=True)
Optimizing receptive field images H2E...
metrics_dict = ut.compute_representation_metrics(population=network.H2.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H2, plot=True)

Output Layer

receptive_fields_output = ut.compute_maxact_receptive_fields(network.Output.E, test_dataloader=test_dataloader)
eiann.plot_receptive_fields(receptive_fields_output, sort=False)
Optimizing receptive field images OutputE...