Figure 2: Multilayer learning: Backprop vs Hebb

Figure 2: Multilayer learning: Backprop vs Hebb

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import h5py

import EIANN as eiann
import EIANN.utils as ut
from EIANN.generate_figures import *
# Apply EIANN's project-wide plotting defaults before any figure is created
# (presumably updates matplotlib style/rcParams — confirm in EIANN).
eiann.update_plot_defaults()

Hide code cell source

def _set_fig2_display_names(model_dict_all):
    """Write the Figure 2 panel titles into ``model_dict_all`` in place.

    Models that are not present in ``model_dict_all`` are skipped instead of
    raising ``KeyError``.
    """
    display_names = {
        "vanBP": "Backprop\n(Feedforward ANN)",
        "bpDale_learned": "Backprop\n(EIANN)",
        "HebbWN_topsup": "Hebb + Weight Norm.",
        "BCM_topsup": "BCM",
    }
    for key, name in display_names.items():
        if key in model_dict_all:
            model_dict_all[key]["display_name"] = name
    # NOTE(review): unlike the entries above this sets "label", not
    # "display_name", and this model is not in Figure 2's model_list; kept for
    # backward compatibility with callers that may read it — confirm intent.
    if "BTSP_WT_hebbdend" in model_dict_all:
        model_dict_all["BTSP_WT_hebbdend"]["label"] = "BTSP"


def _add_activity_colorbar(fig, ax):
    """Draw a thin vertical viridis colorbar (ticks at 0 and 1) just right of ``ax``."""
    pos = ax.get_position()
    cbar_ax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.008, pos.height])
    cbar = matplotlib.colorbar.ColorbarBase(cbar_ax, cmap='viridis', orientation='vertical')
    cbar.set_label('Normalized class-\naveraged activity', labelpad=14, rotation=270)
    cbar.set_ticks([0, 1])


def plot_E_pop_representations(model_list, model_dict_all, figure_name, plot_accuracy=True, plot_confusion=True):
    """Plot excitatory-population representations and summary metrics (Figure 2).

    For every model in ``model_list`` one column of heatmaps is drawn (rows:
    OutputE, H2E, H1E) from the ``average_pop_activity_dict`` of the model's
    first seed. Each model's accuracy, selectivity and (optionally) confusion
    results are then overlaid on a shared bottom row of metric axes. The
    figure is saved as SVG and PNG under ``<project root>/EIANN/figures/``.

    Parameters
    ----------
    model_list : list of str
        Keys into ``model_dict_all``; one heatmap column per key, in order.
    model_dict_all : dict
        Model-configuration mapping (e.g. from ``load_model_dict()``). Each
        entry must provide ``'config'`` (its stem names both the HDF5 file and
        the top-level HDF5 group) and ``'seeds'``. Display names for the
        Figure 2 models are written into this mapping in place.
    figure_name : str
        Base file name (no extension) of the saved figure.
    plot_accuracy : bool, optional
        If True, plot accuracy for all seeds in the first metric panel.
    plot_confusion : bool, optional
        If True, plot the H1E confusion metric in the third metric panel.

    Returns
    -------
    matplotlib.figure.Figure
        The assembled figure.
    """
    import os

    fig = plt.figure(figsize=(6.5, 5))
    axes = gs.GridSpec(nrows=3, ncols=len(model_list), figure=fig,
                       left=0.06, right=0.92, top=0.93, bottom=0.3,
                       wspace=0.32, hspace=0.35)
    axes_metrics = gs.GridSpec(nrows=1, ncols=3, figure=fig,
                               left=0.06, right=0.98, top=0.22, bottom=0.05,
                               wspace=0.45, hspace=0.6)
    ax_accuracy = fig.add_subplot(axes_metrics[0]) if plot_accuracy else None
    ax_selectivity = fig.add_subplot(axes_metrics[1])
    ax_confusion = fig.add_subplot(axes_metrics[2]) if plot_confusion else None

    root_dir = ut.get_project_root()
    _set_fig2_display_names(model_dict_all)

    last_col = len(model_list) - 1
    for col, model_key in enumerate(model_list):
        model_dict = model_dict_all[model_key]
        # Fall back to 'label', then the key, so models without a hard-coded
        # display name can still be plotted.
        title = model_dict.get("display_name", model_dict.get("label", model_key))
        network_name = model_dict['config'].split('.')[0]
        # f-string (not '+') so this works whether root_dir is a str or a Path.
        hdf5_path = f"{root_dir}/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5"

        # All plotting that reads data_dict must happen while the file is open.
        with h5py.File(hdf5_path, 'r') as f:
            data_dict = f[network_name]

            # Example heatmaps for E populations (first seed only).
            example_seed = model_dict['seeds'][0]
            average_pop_activity_dict = data_dict[example_seed]['average_pop_activity_dict']
            for row, population in enumerate(['OutputE', 'H2E', 'H1E']):
                ax = fig.add_subplot(axes[row, col])
                num_units = average_pop_activity_dict[population].shape[1]
                eiann.plot_batch_accuracy_from_data(average_pop_activity_dict, population=population,
                                                    sort=True, ax=ax, cbar=False)
                ax.set_yticks([0, num_units - 1])
                ax.set_yticklabels([1, num_units])
                if col == 0:
                    ax.set_ylabel(f'{population} unit', labelpad=-7 if row == 0 else -10)
                else:
                    ax.set_ylabel('')
                # Checked independently of col == 0 so a single-model figure
                # (first column == last column) still gets its colorbars.
                if col == last_col:
                    _add_activity_colorbar(fig, ax)
                if row == 0:
                    ax.set_title(title)
                ax.set_xlabel(ax.get_xlabel(), labelpad=0)

            # Summary metrics: each model is overlaid on the shared axes.
            if plot_accuracy:
                plot_accuracy_all_seeds(data_dict, model_dict, ax=ax_accuracy, legend=True)
            if plot_confusion:
                plot_confusion_all_seeds(data_dict, model_dict, ax=ax_confusion, population='H1E', type='line')
            plot_metric_all_seeds(data_dict, model_dict, populations_to_plot=['H1E', 'H2E'], ax=ax_selectivity,
                                  metric_name='selectivity', plot_type='violin', side='both')

    # Set once after all models are drawn so it overrides any per-model label.
    ax_selectivity.set_ylabel("Selectivity\n(hidden layers)", labelpad=-2)

    figures_dir = f"{root_dir}/EIANN/figures"
    os.makedirs(figures_dir, exist_ok=True)
    fig.savefig(f"{figures_dir}/{figure_name}.svg", dpi=300)
    fig.savefig(f"{figures_dir}/{figure_name}.png", dpi=300)
    return fig
# --- Figure 2 driver: backprop variants vs. Hebbian-style learning rules ---
model_list = ["vanBP", "bpDale_learned", "HebbWN_topsup", "BCM_topsup"]
figure_name = "Fig2_vanBP_bpDale_hebb"
model_dict_all = load_model_dict()

# Make sure per-seed plot data exists in HDF5 before plotting.
generate_hdf5_all_seeds(
    model_list,
    model_dict_all,
    variables_to_save=[
        "average_pop_activity_dict",
        "metrics_dict",
        "test_accuracy_history",
    ],
    recompute=None,
)

fig = plot_E_pop_representations(model_list, model_dict_all, figure_name)
../_images/96cf710421b02b8f613e7dd5d13928a9e2f01870032a794899c4c6c1cc0cad4e.png