EI network - Backprop (Dale’s Law, fixed inhibition)
Here we will show the analysis for a single example seed of a network described in Galloni et al. 2025.
import EIANN.EIANN as eiann
from EIANN import utils as ut
eiann.plot.update_plot_defaults()
root_dir = ut.get_project_root()
%load_ext autoreload
%autoreload 2
1. Load MNIST data
train_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders()
2. Load optimized pre-trained EIANN model:
EI network trained with Backprop
network_name = "20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized"
network_seed = 66049
data_seed = 257
If you want to train this network configuration from scratch, you can build a new network object directly from the .yaml configuration file and then train it:
# Create network object
config_file_path = f"../network_config/mnist/{network_name}.yaml"
network = ut.build_EIANN_from_config(config_file_path, network_seed=network_seed)
# Train network
data_generator.manual_seed(data_seed)
network.train(train_dataloader, val_dataloader,
              epochs=1,
              samples_per_epoch=20_000,
              val_interval=(0, -1, 100),
              store_history=True,
              store_history_interval=(0, -1, 100),
              store_dynamics=False,
              store_params=True,
              status_bar=True)
# Optional: Save network object to pickle file
saved_network_path = root_dir + f"/EIANN/data/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
ut.save_network(network, path=saved_network_path)
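The `val_interval` and `store_history_interval` arguments in the training call are both given as the tuple `(0, -1, 100)`. One plausible reading, which is an assumption here and not confirmed by the source, is `(start, stop, step)` in training steps, with negative indices counting from the end as in Python indexing. The helper `recorded_steps` below is hypothetical and only illustrates that reading; it is not part of EIANN.

```python
def recorded_steps(interval, num_steps):
    """Expand a (start, stop, step) tuple into explicit training-step indices.

    Assumptions (not taken from EIANN): negative indices count from the end,
    as in Python indexing, and the stop index is inclusive.
    """
    start, stop, step = interval
    if start < 0:
        start += num_steps
    if stop < 0:
        stop += num_steps
    return list(range(start, stop + 1, step))


# With 20,000 samples in the epoch, this reading yields one record every 100 steps.
steps = recorded_steps((0, -1, 100), num_steps=20_000)
print(len(steps), steps[:3], steps[-1])
```

Under this reading, validation and history snapshots would be taken every 100 training steps across the whole epoch, which keeps the stored history small relative to recording every step.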
In this case, since we have already trained the network, we will simply load the saved network object that is stored in a pickle (.pkl) file:
saved_network_path = root_dir + f"/EIANN/data/saved_network_pickles/mnist/{network_name}_{network_seed}_{data_seed}.pkl"
network = ut.load_network(saved_network_path)
network.name = network_name
network.seed = f"{network_seed}_{data_seed}"
Loading network from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized_66049_257.pkl'
Network successfully loaded from '/Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/mnist/20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized_66049_257.pkl'
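The save and load steps rely on Python's pickle format. The round trip can be sketched with the standard library alone. This is only an illustration: the dictionary `network_state` and the temporary file name are hypothetical stand-ins, and `ut.save_network` / `ut.load_network` may do more than this.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for a trained network object.
network_state = {
    "name": "demo_network",
    "seed": "66049_257",
    "weights": [0.1, -0.2, 0.3],
}

# Write the object to a .pkl file, then read it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    pkl_path = os.path.join(tmp_dir, "demo_network_66049_257.pkl")
    with open(pkl_path, "wb") as f:
        pickle.dump(network_state, f)
    with open(pkl_path, "rb") as f:
        restored_state = pickle.load(f)

print(restored_state == network_state)
```

Unpickling can execute arbitrary code, so `.pkl` files should only be loaded from sources you trust.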
3. Training results
eiann.plot.plot_loss_history(network, ylim=(-0.01, 0.2))
eiann.plot.plot_accuracy_history(network)
eiann.plot.plot_error_history(network)



eiann.plot.plot_loss_landscape(test_dataloader, network, num_points=20, extension=0.9, vmax_scale=1.2, scale='log')

4. Analyze population activities
eiann.plot.plot_batch_accuracy(network, test_dataloader, population='E')
Batch accuracy = 77.0199966430664%
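For orientation, a batch accuracy like the one printed above is commonly computed by taking the most active output unit as the predicted class and counting matches with the labels. The sketch below assumes that convention; the source does not show how `plot_batch_accuracy` computes it, and the function `batch_accuracy` and the toy data are hypothetical.

```python
def batch_accuracy(output_activity, labels):
    """Percentage of samples whose most active output unit matches the label."""
    correct = 0
    for activity, label in zip(output_activity, labels):
        # Index of the most active output unit is the predicted class.
        predicted = max(range(len(activity)), key=lambda i: activity[i])
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)


# Toy batch: 4 samples, 3 output units.
toy_activity = [
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],  # predicted class 0, but the label is 1
]
toy_labels = [0, 1, 2, 1]

print(f"Batch accuracy = {batch_accuracy(toy_activity, toy_labels)}%")
```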



# Evaluate the network dynamics by presenting the test dataset and recording the neuron activity.
pop_dynamics_dict = ut.compute_test_activity_dynamics(network, test_dataloader, plot=True, normalize=True)
# Because the dynamics are stored here, each population's activity has shape (timesteps, data samples, neurons).
print(pop_dynamics_dict['H1E'].shape)
torch.Size([15, 10000, 500])
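To make the `(timesteps, data samples, neurons)` layout concrete, the sketch below slices a small toy array with the same axis order. It uses nested Python lists so it runs without any dependencies; the sizes and values are made up. On a torch tensor `x` with this layout, the equivalent slices would be `x[-1]`, `x[:, 1, 0]`, and `x[-1].mean(dim=0)`.

```python
# Toy stand-in for one population's dynamics with shape
# (timesteps=3, samples=2, neurons=4). Each value encodes its own
# position as 100*t + 10*s + n so that slices are easy to read.
T, S, N = 3, 2, 4
dynamics = [[[100 * t + 10 * s + n for n in range(N)]
             for s in range(S)]
            for t in range(T)]

# Final-timestep activity for every sample: shape (samples, neurons).
final_activity = dynamics[-1]

# Time course of neuron 0 for sample 1: one value per timestep.
trajectory = [dynamics[t][1][0] for t in range(T)]

# Mean final activity of each neuron across samples: one value per neuron.
mean_final = [sum(sample[n] for sample in final_activity) / S for n in range(N)]

print(final_activity)
print(trajectory)
print(mean_final)
```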

5. Analyze learned representations
pop_activity_dict, pattern_labels, unit_labels_dict = ut.compute_test_activity(network, test_dataloader, class_average=False, sort=True)
pattern_similarity_matrix_dict, neuron_similarity_matrix_dict = ut.compute_representational_similarity_matrix(pop_activity_dict, pattern_labels, unit_labels_dict, population='E', plot=True)
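The call above returns one similarity matrix over input patterns and one over neurons. As a rough illustration of what such matrices contain, the sketch below computes both from a small activity table, assuming cosine similarity as the metric. The metric EIANN actually uses is not stated here, and the helper names and toy data are hypothetical.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two activity vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def similarity_matrix(rows):
    """Pairwise cosine similarity between all rows."""
    return [[cosine_similarity(r1, r2) for r2 in rows] for r1 in rows]


# Toy activity table: rows are input patterns, columns are neurons.
activity = [
    [1.0, 0.0, 1.0],
    [2.0, 0.0, 2.0],
    [0.0, 3.0, 0.0],
]

# Pattern similarity compares rows; neuron similarity compares columns.
pattern_similarity = similarity_matrix(activity)
neuron_similarity = similarity_matrix([list(col) for col in zip(*activity)])

for row in pattern_similarity:
    print([round(value, 3) for value in row])
```

In this toy table the first two patterns drive the same neurons at different strengths, so their cosine similarity is about 1, while the third pattern is orthogonal to both.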



(within_class_pattern_similarity_dict, between_class_pattern_similarity_dict,
 within_class_unit_similarity_dict, between_class_unit_similarity_dict) = ut.compute_within_class_representational_similarity(
    network, test_dataloader, population='E', plot=True)
Plotting pattern similarity for population: H1E
Plotting unit similarity for population: H1E
Plotting pattern similarity for population: H2E
Plotting unit similarity for population: H2E
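The within-class versus between-class comparison can be sketched as follows: compute a similarity for every pair of patterns, then average separately over pairs that share a label and pairs that do not. This is a simplified illustration that assumes cosine similarity and a plain mean; the function `within_between_similarity` and the toy data are hypothetical and do not reproduce EIANN's implementation.

```python
import math
from itertools import combinations


def cosine_similarity(u, v):
    """Cosine of the angle between two activity vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def within_between_similarity(patterns, labels):
    """Mean pairwise similarity for same-label pairs and for different-label pairs."""
    within, between = [], []
    for i, j in combinations(range(len(patterns)), 2):
        similarity = cosine_similarity(patterns[i], patterns[j])
        if labels[i] == labels[j]:
            within.append(similarity)
        else:
            between.append(similarity)

    def mean(values):
        return sum(values) / len(values) if values else float("nan")

    return mean(within), mean(between)


# Toy data: two classes that activate disjoint neurons.
patterns = [[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 3.0]]
labels = [0, 0, 1, 1]

within_mean, between_mean = within_between_similarity(patterns, labels)
print(within_mean, between_mean)
```

A large gap between the two means indicates that patterns from the same class evoke similar activity while patterns from different classes do not.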

5.3 Output layer
receptive_fields_output = ut.compute_maxact_receptive_fields(network.Output.E, test_dataloader=test_dataloader, export=True, export_path=root_dir+f"/EIANN/data/Figure_data_hdf5/plot_data_{network_name}.h5")
eiann.plot.plot_receptive_fields(receptive_fields_output, sort=False)
Loading maxact_receptive_fields_OutputE from file: /Users/ag1880/github-repos/Milstein-Lab/EIANN/EIANN/data/Figure_data_hdf5/plot_data_20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized.h5
