Feedforward ANN (Backprop)#

Here we show the analysis for a single example seed of a network described in Galloni et al. 2025: a feedforward ANN trained with backprop.

import EIANN.EIANN as eiann
from EIANN import utils as ut
eiann.plot.update_plot_defaults()
root_dir = ut.get_project_root()

Load MNIST data#

train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()

Load optimized pre-trained model: Feedforward ANN trained with Backprop#

network_name = "20231129_EIANN_2_hidden_mnist_van_bp_relu_SGD_config_G_complete_optimized"
network_seed = 66049
data_seed = 257

If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:

# Create network object
config_file_path = f"../network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)

# Train network
data_generator.manual_seed(data_seed)
network.train(
    train_dataloader,
    val_dataloader,
    epochs=1,
    samples_per_epoch=20_000,
    val_interval=(0, -1, 100),
    store_history=True,
    store_history_interval=(0, -1, 100),
    store_dynamics=False,
    store_params=True,
    status_bar=True,
)

# Optional: Save network object to pickle file
saved_network_path = root_dir + f"/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)

Since this network has already been trained, we instead load the saved network object from a pickle (.pkl) file:

saved_network_path = root_dir + f"/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20231129_EIANN_2_hidden_mnist_van_bp_relu_SGD_config_G_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/saved_network_pickles/mnist/20231129_EIANN_2_hidden_mnist_van_bp_relu_SGD_config_G_complete_optimized_66049_257.pkl'

Training results#

eiann.plot.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot.plot_accuracy_history(network)
eiann.plot.plot_error_history(network)
../_images/821cb24ac5246cf5783aecc07c4fe9419a76640e53952a945511fe6920f2c536.png ../_images/d7c1a8af0a83b4e6de8ef2747818e2141daef2423e4d563cd62b0197ececbabd.png ../_images/e25e39cefe28f842bb37a1b86f1fb357214d4657d39e14063c4429f98c910a79.png
eiann.plot.plot_loss_landscape(test_dataloader, network, num_points=20, extension=0.9, vmax_scale=1.2, scale='log')
../_images/c3d382e4e12fa15d672f16d98855c359c69103cd814a2bd4829f46353c8e7e7e.png
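
A loss landscape plot of this kind shows the loss evaluated on a 2D grid of parameter settings around the trained network. The following is a minimal NumPy sketch of that idea on a toy quadratic loss. It is not the EIANN implementation: how `plot_loss_landscape` chooses the plane and interprets `extension` is not shown on this page, and the names `loss_on_plane`, `toy_loss`, and `span` are hypothetical.

```python
import numpy as np

def loss_on_plane(loss_fn, center, dir_1, dir_2, num_points=20, span=1.0):
    """Evaluate loss_fn on a num_points x num_points grid in the plane
    spanned by dir_1 and dir_2 and centered on `center`.
    Hypothetical helper for illustration; not part of the EIANN API."""
    coords = np.linspace(-span, span, num_points)
    grid = np.empty((num_points, num_points))
    for i, a in enumerate(coords):
        for j, b in enumerate(coords):
            grid[i, j] = loss_fn(center + a * dir_1 + b * dir_2)
    return coords, grid

rng = np.random.default_rng(0)
trained_weights = rng.normal(size=10)  # stand-in for a flattened weight vector

def toy_loss(weights):
    # Quadratic bowl with its minimum at the "trained" weights
    return float(np.sum((weights - trained_weights) ** 2))

# Two orthonormal directions in weight space (assumption: random directions)
basis, _ = np.linalg.qr(rng.normal(size=(10, 2)))
coords, grid = loss_on_plane(toy_loss, trained_weights, basis[:, 0], basis[:, 1])

# A log scale makes shallow regions near the minimum easier to see
log_grid = np.log10(grid)
```

With an even `num_points`, the grid does not contain the exact center point, so the toy loss stays strictly positive and the log transform is well defined.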

Analyze population activities#

eiann.plot.plot_batch_accuracy(network, test_dataloader, population='all')
Batch accuracy = 96.11000061035156%
../_images/266979086ac0e23bd354916fa81bcbee7a283e3feb1dbce6cd0f5667418c1a6c.png ../_images/391621e7ec5f59d99a03917942851c77cf34c29a2e24822d21dff07dcc103b4d.png ../_images/81dcde7ff59c6d9d425e5d98892b9a3420058a89134a7ce390461375da09a746.png ../_images/a66d7464db25e9e79f7a2741260cf088a8d43376ceaad9eaf618737eb2cc7dac.png
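
Batch accuracy is conventionally the percentage of patterns whose most active output unit matches the target class. The sketch below illustrates that calculation with NumPy, assuming outputs and one-hot targets of shape (n_patterns, n_classes). The helper `batch_accuracy` and the toy arrays are hypothetical and not part of EIANN. The long decimal tail in the printed value above is consistent with a 32-bit float being printed at double precision, which the last line reproduces.

```python
import numpy as np

def batch_accuracy(output, target):
    """Percent of patterns whose most active output unit matches the target class.
    Hypothetical helper for illustration; not part of the EIANN API."""
    predicted = np.argmax(output, axis=1)
    labels = np.argmax(target, axis=1)
    return 100.0 * np.mean(predicted == labels)

# Toy example: 4 patterns, 3 classes, one misclassified pattern
output = np.array([[0.9, 0.1, 0.0],
                   [0.2, 0.7, 0.1],
                   [0.3, 0.4, 0.3],   # predicted class 1, true class 2
                   [0.0, 0.1, 0.8]])
target = np.eye(3)[[0, 1, 2, 2]]      # one-hot targets
accuracy = batch_accuracy(output, target)
print(f"Batch accuracy = {accuracy}%")  # Batch accuracy = 75.0%

# 96.11 stored as float32 and widened to a Python float
as_float32 = float(np.float32(96.11))
print(as_float32)  # 96.11000061035156
```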

Analyze learned representations#

pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=False, sort=True)
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = ut.compute_representational_similarity_matrix(pop_activity_dict, pattern_labels, unit_labels_dict, population='E', plot=True)
../_images/e36d5834d664a82447b185ec6e172ef5c80091236557ef49821bd9cdd96da8fe.png ../_images/4cd139c953c9ca21adb99850c77818f2d7c723290361d895b3257d11e226ac4b.png ../_images/ec5781e912cb86202bd77764d2f0bb11600ba9b35b2989c0868e3a3b4b6aee64.png
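
A representational similarity matrix compares activity vectors pairwise, either across patterns (rows of the activity matrix) or across units (columns). The sketch below uses cosine similarity on random toy data. The similarity metric that `compute_representational_similarity_matrix` actually uses is not stated on this page, so cosine similarity is an assumption, and `cosine_similarity_matrix` is a hypothetical helper.

```python
import numpy as np

def cosine_similarity_matrix(activity):
    """Pairwise cosine similarity between the rows of `activity`.
    Hypothetical helper for illustration; not part of the EIANN API."""
    norms = np.linalg.norm(activity, axis=1, keepdims=True)
    normalized = activity / np.maximum(norms, 1e-12)  # guard against silent rows
    return normalized @ normalized.T

rng = np.random.default_rng(0)
activity = rng.random((6, 20))  # toy data: 6 patterns x 20 units

pattern_similarity = cosine_similarity_matrix(activity)    # shape (6, 6)
unit_similarity = cosine_similarity_matrix(activity.T)     # shape (20, 20)
```

Sorting patterns by class label before computing the matrix, as `sort=True` suggests above, would make within-class blocks show up along the diagonal.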

Hidden Layer 1#

receptive_fields_H1 = ut.compute_maxact_receptive_fields(network.H1.E, test_dataloader=test_dataloader)
eiann.plot.plot_receptive_fields(receptive_fields_H1, sort=True)
Optimizing receptive field images...
../_images/7f3a256d1f1e7859510611ce8b089a659e5eadf2f61ad31bc860c04197bf6592.png
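
The "Optimizing receptive field images..." message suggests that each receptive field is found by optimizing an input image to maximize a unit's activation. The sketch below shows that idea for a single linear unit with projected gradient ascent in NumPy. It is an illustration under stated assumptions, not the EIANN procedure: `maxact_input`, the random weights, and the step settings are hypothetical, and in a real multi-layer network the gradient would be recomputed by automatic differentiation at every step.

```python
import numpy as np

def maxact_input(weights, image, steps=200, lr=0.05):
    """Gradient ascent on the input to maximize the pre-activation weights @ image,
    keeping pixel values in [0, 1]. Hypothetical helper; not part of the EIANN API."""
    image = image.copy()
    for _ in range(steps):
        gradient = weights  # d(weights @ image) / d(image) for a linear unit
        image = np.clip(image + lr * gradient, 0.0, 1.0)
    return image

rng = np.random.default_rng(1)
weights = rng.normal(size=784)                # toy weights of one unit onto 28x28 inputs
initial = rng.uniform(0.4, 0.6, size=784)     # start from a near-gray image
optimized = maxact_input(weights, initial)

receptive_field = optimized.reshape(28, 28)   # image-shaped result for plotting
activation_gain = weights @ optimized - weights @ initial
```

For a linear unit the optimized image simply saturates pixels according to the sign of their weights, so the receptive field mirrors the weight pattern. Units in deeper layers generally need the iterative, gradient-based version.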
metrics_dict = ut.compute_representation_metrics(population=network.H1.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H1, plot=True)
../_images/45b4e8bcf2ba087e5c0c0111d2c6687dc6e72159d0c4c421edbeb90422b975f7.png

Hidden Layer 2#

receptive_fields_H2 = ut.compute_maxact_receptive_fields(network.H2.E, test_dataloader=test_dataloader)
eiann.plot.plot_receptive_fields(receptive_fields_H2, sort=True)
Optimizing receptive field images...
../_images/297d71bcd6878ac65bd3ab6207578b4d2a63f936f8fd84514eb9bb80de249cd3.png
metrics_dict = ut.compute_representation_metrics(population=network.H2.E, dataloader=test_dataloader, receptive_fields=receptive_fields_H2, plot=True)
../_images/ce359ab9a4b53d3f198349bb7f2b13137b191443d51397aebd47c9f9cd430528.png

Output layer#

receptive_fields_output = ut.compute_maxact_receptive_fields(network.Output.E, test_dataloader=test_dataloader)
eiann.plot.plot_receptive_fields(receptive_fields_output, sort=False)
Optimizing receptive field images...
../_images/4c83daa8490d39a421a1c3f5dfa28bdc1c08d14b346f5c887c2bd8a470c5e5e6.png