# Single-row, three-panel figure: one panel per population key below.
fig = plt.figure(figsize=(6.5, 1.6))
axes = gs.GridSpec(
    nrows=1, ncols=3, figure=fig,
    left=0.07, right=0.98,
    top=0.9, bottom=0.25,
    wspace=0.4, hspace=0.4,
)

# Panel order (left to right) follows the order of this tuple.
panel_populations = ('OutputE', 'H2E', 'H1E')
ax = {
    pop_name: fig.add_subplot(axes[col])
    for col, pop_name in enumerate(panel_populations)
}

root_dir = ut.get_project_root()

# Overlay every requested model on the same three panels.
for model_key in model_list:
    model_dict = model_dict_all[model_key]
    # The network name is the config entry up to its first '.'.
    network_name = model_dict['config'].split('.')[0]
    hdf5_path = root_dir + f"/EIANN/data/model_hdf5_plot_data/plot_data_{network_name}.h5"

    # The file is read through a top-level group named after the network;
    # all plotting happens while the file is still open.
    with h5py.File(hdf5_path, 'r') as f:
        data_dict = f[network_name]

        # NOTE(review): these two names are not read anywhere in the visible
        # code; kept in case later code depends on them -- TODO confirm.
        population = 'H2DendI'
        example_seed = model_dict['seeds'][0]  # example seed to plot

        for pop in panel_populations:
            plot_confusion_all_seeds(
                data_dict, model_dict,
                ax=ax[pop], population=pop, type='line',
            )
            ax[pop].set_title(pop, fontsize=8, y=0.85)

# One shared legend, anchored below the first panel.
ax['OutputE'].legend(loc='upper left', bbox_to_anchor=(0, -0.2), ncol=5)

# Save the figure in both vector and raster form.
fig.savefig(f"{root_dir}/EIANN/figures/{figure_name}.svg", dpi=300)
fig.savefig(f"{root_dir}/EIANN/figures/{figure_name}.png", dpi=300)