def _apply_display_name_overrides(model_dict_all):
    """Set this figure's display names / labels on ``model_dict_all`` in place.

    Model keys that are not present in ``model_dict_all`` are skipped, so the
    figure can be drawn for any subset of models without a KeyError.
    """
    overrides = [
        ("vanBP", "display_name", "Backprop\n(Feedforward ANN)"),
        ("bpDale_learned", "display_name", "Backprop\n(EIANN)"),
        ("HebbWN_topsup", "display_name", "Hebb + Weight Norm."),
        ("BCM_topsup", "display_name", "BCM"),
        # NOTE(review): this entry sets "label" rather than "display_name"; the
        # column title falls back to "label" when "display_name" is missing.
        ("BTSP_WT_hebbdend", "label", "BTSP"),
    ]
    for model_key, field, value in overrides:
        if model_key in model_dict_all:
            model_dict_all[model_key][field] = value


def _add_activity_colorbar(fig, ax):
    """Attach a thin viridis colorbar (ticks at 0 and 1) just right of ``ax``."""
    pos = ax.get_position()
    cbar_ax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.008, pos.height])
    cbar = matplotlib.colorbar.ColorbarBase(cbar_ax, cmap='viridis', orientation='vertical')
    cbar.set_label('Normalized class-\naveraged activity', labelpad=14, rotation=270)
    cbar.set_ticks([0, 1])


def _plot_example_heatmaps(fig, axes, data_dict, model_dict, model_key, col, num_cols):
    """Draw one grid column of sorted activity heatmaps (rows: OutputE, H2E, H1E).

    Uses the first entry of ``model_dict['seeds']`` as the example seed. The
    leftmost column carries the y-labels; the rightmost column gets colorbars.
    """
    example_seed = model_dict['seeds'][0]
    # Same lookup for every row, so read it once per model.
    average_pop_activity_dict = data_dict[example_seed]['average_pop_activity_dict']
    is_last_col = col == num_cols - 1

    for row, population in enumerate(['OutputE', 'H2E', 'H1E']):
        ax = fig.add_subplot(axes[row, col])
        num_units = average_pop_activity_dict[population].shape[1]
        eiann.plot_batch_accuracy_from_data(average_pop_activity_dict, population=population,
                                            sort=True, ax=ax, cbar=False)
        # Only label the first and last unit (1-based) on the y-axis.
        ax.set_yticks([0, num_units - 1])
        ax.set_yticklabels([1, num_units])

        if col == 0:
            ax.set_ylabel(f'{population} unit', labelpad=-7 if row == 0 else -10)
        else:
            ax.set_ylabel('')

        if is_last_col:
            _add_activity_colorbar(fig, ax)

        if row == 0:
            # Fall back to "label", then the model key, when no display name is set.
            title = model_dict.get("display_name", model_dict.get("label", model_key))
            ax.set_title(title)

        ax.set_xlabel(ax.get_xlabel(), labelpad=0)


def plot_E_pop_representations(model_list, model_dict_all, figure_name, plot_accuracy=True, plot_confusion=True):
    """Plot example E-population heatmaps for several models, plus summary metrics.

    Top: a 3 x len(model_list) grid of sorted, class-averaged activity heatmaps
    (rows OutputE, H2E, H1E; one column per model). Bottom: accuracy (optional),
    hidden-layer selectivity (always), and H1E confusion (optional) across seeds.

    For each model, plot data is read from
    ``<project_root>/EIANN/data/model_hdf5_plot_data/plot_data_<network>.h5``, group
    ``<network>``. ``<network>`` is ``model_dict['config']`` up to its first '.'.

    Args:
        model_list: Model keys into ``model_dict_all``; one grid column each.
        model_dict_all: Mapping of model key -> model dict. Each dict needs at
            least 'config' and 'seeds'. Some entries get display names/labels
            set in place.
        figure_name: Base filename for ``EIANN/figures/<figure_name>.svg/.png``.
        plot_accuracy: Whether to draw the accuracy panel.
        plot_confusion: Whether to draw the confusion panel.

    Returns:
        The matplotlib Figure (also saved to disk as SVG and PNG).

    Raises:
        ValueError: If ``model_list`` is empty.
    """
    if not model_list:
        raise ValueError("model_list must contain at least one model key")

    fig = plt.figure(figsize=(6.5, 5))
    axes = gs.GridSpec(nrows=3, ncols=len(model_list), figure=fig,
                       left=0.06, right=0.92, top=0.93, bottom=0.3,
                       wspace=0.32, hspace=0.35)
    axes_metrics = gs.GridSpec(nrows=1, ncols=3, figure=fig,
                               left=0.06, right=0.98, top=0.22, bottom=0.05,
                               wspace=0.45, hspace=0.6)
    ax_accuracy = fig.add_subplot(axes_metrics[0]) if plot_accuracy else None
    # Selectivity is always plotted below, so its axis must always exist
    # (previously it was only created when plot_accuracy was True).
    ax_selectivity = fig.add_subplot(axes_metrics[1])
    ax_confusion = fig.add_subplot(axes_metrics[2]) if plot_confusion else None

    root_dir = ut.get_project_root()
    _apply_display_name_overrides(model_dict_all)

    for col, model_key in enumerate(model_list):
        model_dict = model_dict_all[model_key]
        network_name = model_dict['config'].split('.')[0]
        hdf5_path = root_dir + f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5"
        # All reads from data_dict must happen while the HDF5 file is open.
        with h5py.File(hdf5_path, 'r') as f:
            data_dict = f[network_name]

            _plot_example_heatmaps(fig, axes, data_dict, model_dict, model_key,
                                   col, len(model_list))

            if ax_accuracy is not None:
                plot_accuracy_all_seeds(data_dict, model_dict, ax=ax_accuracy, legend=True)
            if ax_confusion is not None:
                plot_confusion_all_seeds(data_dict, model_dict, ax=ax_confusion,
                                         population='H1E', type='line')
            plot_metric_all_seeds(data_dict, model_dict, populations_to_plot=['H1E', 'H2E'],
                                  ax=ax_selectivity, metric_name='selectivity',
                                  plot_type='violin', side='both')

    ax_selectivity.set_ylabel("Selectivity\n(hidden layers)", labelpad=-2)

    fig.savefig(f"{root_dir}/EIANN/figures/{figure_name}.svg", dpi=300)
    fig.savefig(f"{root_dir}/EIANN/figures/{figure_name}.png", dpi=300)
    return fig