Figure S5: Robustness to perturbations

# Figure S5: Robustness to perturbations

import numpy as np
from scipy.stats import kurtosis
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import h5py

import EIANN as eiann
import EIANN.utils as ut
from EIANN.generate_figures import *
# NOTE(review): presumably installs the project's matplotlib defaults; called
# before any figure is created - confirm against EIANN.
eiann.update_plot_defaults()

# Base file name (without extension) used when the figure is saved.
figure_name = "FigS5_robustness_to_perturbations"

# Models compared in this figure. The list order is the column order of the
# weight-histogram panels.
model_list = [
    "vanBP",
    "bpDale_fixed",
    "HebbWN_topsup",
    "bpLike_WT_hebbdend",
]
model_dict_all = load_model_dict()

# NOTE(review): presumably produces the per-model HDF5 plot data that the
# plotting loop reads back - confirm in EIANN.generate_figures.
generate_hdf5_all_seeds(
    model_list,
    model_dict_all,
    variables_to_save=["weights", "noise_sensitivity", "robustness_to_pruning"],
    recompute=None,
)

Hide code cell source

# One figure holding two stacked grids: weight histograms in the upper part,
# kurtosis / noise / pruning summaries in the lower part.
fig = plt.figure(figsize=(6.5, 3.5))

# Upper grid: 2 rows x 4 columns (one column per model).
axes = gs.GridSpec(
    nrows=2, ncols=4, figure=fig,
    left=0.07, right=0.98,
    bottom=0.5, top=0.95,
    wspace=0.3, hspace=0.5,
)

# Lower grid: a narrow first column split into two kurtosis panels, followed
# by two full-height panels for noise sensitivity and pruning robustness.
axes_perturbations = gs.GridSpec(
    nrows=2, ncols=3, figure=fig,
    left=0.07, right=0.98,
    bottom=0.08, top=0.35,
    wspace=0.35, hspace=0.5,
    width_ratios=[0.6, 1, 1],
)
ax_kurtosisH1E = fig.add_subplot(axes_perturbations[0, 0])
ax_kurtosisH2E = fig.add_subplot(axes_perturbations[1, 0])
ax_noise = fig.add_subplot(axes_perturbations[:, 1])
ax_pruning = fig.add_subplot(axes_perturbations[:, 2])

root_dir = ut.get_project_root()

# Per-model text: (panel title shown above the histograms, short legend label).
_model_text = {
    "vanBP": ("Feedforward ANN \nBackprop", "Backprop (ANN)"),
    "bpLike_WT_hebbdend": ("EIANN \nDendritic Target Propagation", "Dend Target Prop"),
    "bpDale_fixed": ("EIANN \nBackprop", "Backprop (EIANN)"),
    "HebbWN_topsup": ("EIANN \nHebb + Weight Norm.", "Hebb"),
}
for _key, (_display_name, _label) in _model_text.items():
    model_dict_all[_key]["display_name"] = _display_name
    model_dict_all[_key]["label"] = _label

def _plot_normalized_accuracy(ax, x_values, accuracy_per_seed, color, label):
    """Draw the across-seed mean accuracy curve with a +/- 1 std band on *ax*.

    Both the mean curve and the std band are divided by the maximum of the
    mean curve, so the plotted mean peaks at 1.

    Args:
        ax: Matplotlib axes to draw on.
        x_values: x coordinates shared by every seed's accuracy curve.
        accuracy_per_seed: sequence with one accuracy curve per seed.
        color: line / band color.
        label: legend label attached to the mean curve.
    """
    accuracy_array = np.array(accuracy_per_seed)
    mean_accuracy = np.mean(accuracy_array, axis=0)
    peak_accuracy = np.max(mean_accuracy)
    mean_norm = mean_accuracy / peak_accuracy
    std_norm = np.std(accuracy_array, axis=0) / peak_accuracy
    ax.plot(x_values, mean_norm, label=label, linewidth=1, color=color, alpha=0.7)
    ax.fill_between(x_values, mean_norm - std_norm, mean_norm + std_norm,
                    alpha=0.2, color=color, linewidth=0.)


def _format_accuracy_axis(ax, xlabel, title):
    """Apply the shared limits, ticks, grid and text to an accuracy panel."""
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_yticks([0, 0.5, 1], minor=False)
    ax.set_yticks(np.linspace(0, 1, 5), minor=True)
    ax.grid(True, axis='y', color='gray', linewidth=0.5, alpha=0.3, which='both')
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Normalized accuracy")
    ax.set_title(title, fontsize=8)


# One pass per model: a column of weight histograms in the upper grid, plus one
# entry in each of the shared kurtosis / noise / pruning panels.
for col, model_key in enumerate(model_list):
    model_dict = model_dict_all[model_key]
    network_name = model_dict['config'].split('.')[0]
    hdf5_path = root_dir + f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5"
    color = model_dict["color"]
    seeds = model_dict['seeds']

    # Everything that reads from the file stays inside the `with` block so the
    # HDF5 handles are still open when they are accessed.
    with h5py.File(hdf5_path, 'r') as f:
        data_dict = f[network_name]

        #########################
        # Weight distributions
        #########################
        # Pool the final weights of every seed into one flat array per
        # projection (concatenating arrays avoids building huge Python lists).
        H1E_weights = np.concatenate([
            data_dict[seed]['weights']['final_weights']['H1E_InputE'][:].ravel()
            for seed in seeds
        ])
        H2E_weights = np.concatenate([
            data_dict[seed]['weights']['final_weights']['H2E_H1E'][:].ravel()
            for seed in seeds
        ])

        ax = fig.add_subplot(axes[0, col])
        ax.hist(H1E_weights, bins=70, alpha=0.5, color=color, density=False)
        ax.set_xlabel("H1E Weights", labelpad=0)
        if col == 0:
            ax.set_ylabel("Count")
        ax.set_title(model_dict['display_name'])
        ax.set_yscale('log')

        ax = fig.add_subplot(axes[1, col])
        ax.hist(H2E_weights, bins=60, alpha=0.5, color=color, density=False)
        ax.set_xlabel("H2E Weights", labelpad=0)
        if col == 0:
            ax.set_ylabel("Count")
        ax.set_yscale('log')

        plot_kurtosis_all_seeds(data_dict, model_dict, 'H1E_InputE', ax=ax_kurtosisH1E)
        plot_kurtosis_all_seeds(data_dict, model_dict, 'H2E_H1E', ax=ax_kurtosisH2E)

        ###############################
        # Robustness to input noise
        ###############################
        # NOTE(review): assumes each entry unpacks into (x values, accuracies)
        # and that the x values are identical across seeds - confirm layout.
        noise_accuracy_per_seed = []
        for seed in seeds:
            noise_stds, noise_accuracy = data_dict[seed]['noise_sensitivity']
            noise_accuracy_per_seed.append(noise_accuracy)
        _plot_normalized_accuracy(ax_noise, noise_stds, noise_accuracy_per_seed,
                                  color=color, label=model_dict['label'])

        ##################################
        # Robustness to synaptic pruning
        ##################################
        pruning_accuracy_per_seed = []
        for seed in seeds:
            fraction_to_prune, accuracy = data_dict[seed]['robustness_to_pruning']
            pruning_accuracy_per_seed.append(accuracy)
        # Same human-readable label as the noise panel, for consistency.
        _plot_normalized_accuracy(ax_pruning, fraction_to_prune, pruning_accuracy_per_seed,
                                  color=color, label=model_dict['label'])

# Axis formatting does not depend on the model, so it is applied once after
# all curves are drawn; the legend then picks up every model's curve.
_format_accuracy_axis(ax_noise, xlabel="Input noise std", title="Noise sensitivity")
ax_noise.legend(frameon=False, fontsize=6, loc='lower left', bbox_to_anchor=(0., -0.05))
_format_accuracy_axis(ax_pruning, xlabel="Proportion pruned", title="Robustness to pruning")

# Label the two kurtosis panels; only the lower (H2E) y-label gets an explicit
# labelpad.
for _kurtosis_ax, _panel_title, _ylabel_kwargs in (
    (ax_kurtosisH1E, "H1E", {}),
    (ax_kurtosisH2E, "H2E", {"labelpad": 1}),
):
    _kurtosis_ax.set_ylabel("Kurtosis", **_ylabel_kwargs)
    _kurtosis_ax.set_title(_panel_title, x=0.1, y=0.55)

# Blank out the x tick labels of the upper (H1E) kurtosis panel.
ax_kurtosisH1E.set_xticklabels([])

plt.show()

# Write the figure in both vector and raster form, using the retained `fig`
# reference.
_output_paths = (
    f"{root_dir}/EIANN/figures/{figure_name}.svg",
    f"{root_dir}/EIANN/figures/{figure_name}.png",
)
for _output_path in _output_paths:
    fig.savefig(_output_path, dpi=300)
../_images/22198ef951025ad874a959662d99b42a8ec84728d9e44d26ede3cea4337f4f07.png