def plot_learning_rule_diagram(axes_list):
    """Draw schematic weight-update (dW) curves for five learning rules.

    One panel per rule, in order:
      0: Linear Dendritic State (LDS, used in Backprop)
      1: BTSP (plateau-gated updates across three time steps, several w values)
      2: Temporally Contrastive Hebb
      3: BCM (with sliding threshold theta)
      4: Supervised Hebb with weight normalization

    Parameters
    ----------
    axes_list : sequence of at least 5 matplotlib Axes
        Target axes; panels are drawn into axes_list[0] .. axes_list[4].

    Notes
    -----
    Relies on the project helper ``ut.get_scaled_rectified_sigmoid`` for the
    BTSP depression nonlinearity. All math labels use raw strings so no TeX
    backslash is mangled by Python escape processing.
    """
    a_pre = 1  # presynaptic activation, held fixed at its maximum
    mathfont = 'cm'  # math text family for all TeX labels
    # mathfont = 'stix'
    pad = -1  # shared x-label padding

    # --- Panel 0: Linear Dendritic State (LDS, used in Backprop) ---
    d = np.linspace(-1, 1, 100)  # dendritic state
    dW_BP = d * a_pre  # update is linear in the dendritic state
    ax = axes_list[0]
    ax.plot(d, dW_BP, color='black', linewidth=1.)
    ax.hlines(0, -1, 1, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.vlines(0, -1, 1, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.set_ylabel(r'$\Delta W$', math_fontfamily=mathfont, fontsize=10, rotation=0, labelpad=5, y=0.4)
    ax.set_xlabel(r' $\hat{d}$', math_fontfamily=mathfont, fontsize=10, labelpad=pad, loc='center')
    ax.set_xticks([0])
    ax.set_yticks([-1, 0, 1])
    ax.set_title('Linear Dendritic State (LDS)', fontsize=8)

    # --- Panel 1: BTSP ---
    w_max = 2  # saturation bound for potentiation
    temporal_discount = 0.1  # attenuation of the plateau signal one step away
    dep_th = 0.01  # depression threshold
    dep_width = 0.01  # width of the depression sigmoid transition
    q_dep = ut.get_scaled_rectified_sigmoid(dep_th, dep_th + dep_width)
    colors = ['deepskyblue', 'royalblue', 'darkblue']
    ax = axes_list[1]
    ax.hlines(0, 0, 3, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    for i, w in enumerate(torch.tensor([0., 0.5, 1])):
        # Updates at the time steps flanking the plateau (x_-1, x_1) are
        # identical: both see the temporally discounted signal.
        dW_side = (w_max - w) * a_pre * temporal_discount - w * q_dep(torch.tensor(a_pre * temporal_discount))
        dW_curr = (w_max - w) * a_pre - w * q_dep(torch.tensor(a_pre))
        ax.plot([0, 0.9, 1.05, 1.95, 2.1, 3],
                [dW_side, dW_side, dW_curr, dW_curr, dW_side, dW_side],
                color=colors[i], linewidth=1)
        ax.text(3.1, dW_side - 0.05, fr'$w={w}$', fontsize=6, color=colors[i], math_fontfamily=mathfont)
    ax.set_ylim(-1.5, 3)
    ax.set_yticks([0])
    ax.set_xticks([0.5, 1.5, 2.5])
    ax.set_xticklabels(['$x_{-1}$', '$x_0$', '$x_1$'], math_fontfamily=mathfont, fontsize=10)
    for label in ax.get_xticklabels():
        # Nudge time-step labels slightly upward toward the axis
        label.set_y(label.get_position()[1] + 0.05)
    ax.set_title('BTSP', fontsize=8)

    # --- Panel 2: Temporally Contrastive Hebb ---
    delta_a = np.linspace(-1, 1, 100)  # change in postsynaptic activity
    dW_HTC = delta_a * a_pre
    ax = axes_list[2]
    ax.hlines(0, -1, 1, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.vlines(0, -1, 1, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.plot(delta_a, dW_HTC, color='black', linewidth=1.)
    ax.set_xlabel(r'$\Delta c = \tilde{c} - c$', math_fontfamily=mathfont, fontsize=10, labelpad=pad)
    ax.set_ylabel(r'$\Delta W$', math_fontfamily=mathfont, fontsize=10, rotation=0, labelpad=5, y=0.4)
    ax.set_xticks([0])
    ax.set_yticks([-1, 0, 1])
    ax.set_title('Temp. Contrastive Hebb', fontsize=8, x=0.4)

    # --- Panel 3: BCM ---
    a_post = np.linspace(0, 1, 100)  # postsynaptic activation
    theta = 0.5  # modification threshold: depression below, potentiation above
    dW_BCM = a_pre * a_post * (a_post - theta)
    ax = axes_list[3]
    ax.hlines(0, -0.2, 1, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.vlines(0, -0.2, 0.4, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.plot(a_post, dW_BCM, color='black', linewidth=1.)
    theta2 = 0.8  # slid threshold, drawn in gray to illustrate metaplasticity
    dW_BCM2 = a_pre * a_post * (a_post - theta2)
    ax.plot(a_post, dW_BCM2, color='gray', linewidth=1.)
    ax.set_ylim(-0.2, 0.3)
    ax.set_xticks([0])
    ax.set_yticks([0])
    ax.set_xlabel(r'$\tilde{a}$', math_fontfamily=mathfont, fontsize=10, labelpad=pad)
    ax.vlines(theta, -0.07, 0.07, linestyle='--', color='k', linewidth=0.55, alpha=1)
    ax.text(theta - 0.1, 0.018, r'$\theta$', fontsize=8, ha='center', math_fontfamily=mathfont)
    # Red arrow indicating the threshold sliding from theta toward theta2
    ax.annotate('', xy=(theta + 0.02, 0.016), xytext=(theta2 + 0.05, 0.016),
                arrowprops=dict(arrowstyle='<|-', color='red', linewidth=0.8, ), ha='center')
    ax.set_title('BCM', fontsize=8)

    # --- Panel 4: Supervised Hebb + weight normalization ---
    a_post = np.linspace(0, 1, 100)
    dW_hebb = a_pre * a_post  # plain Hebbian product
    ax = axes_list[4]
    ax.hlines(0, -0.3, 1, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.vlines(0, -0.3, 1, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.plot(a_post, dW_hebb, color='black', linewidth=1.)
    ax.set_ylim(-0.3, 1)
    ax.set_xlim(-0.3, 1)
    ax.set_xticks([0])
    ax.set_yticks([0])
    ax.set_xlabel(r'$\tilde{a}$', math_fontfamily=mathfont, fontsize=10, labelpad=pad)
    ax.set_title('Sup. Hebb + W Norm.', fontsize=8)