Bug-fix RE sign of target_rms

This commit is contained in:
Daniel Povey 2022-04-05 13:49:35 +08:00
parent d1a669162c
commit 25724b5ce9

View File

@ -137,9 +137,10 @@ class Eve(Optimizer):
delta = exp_avg / denom
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors" (which are scalar).
is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5)))
p.mul_(1 - (weight_decay * is_below_target_rms))
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
p.mul_(1 - (weight_decay * is_above_target_rms))
p.addcdiv_(exp_avg, denom, value=-step_size)
return loss
@ -149,5 +150,6 @@ class Eve(Optimizer):
# if avg-change, defined as rms(diff) / rms(params), equals 0.2, and rms(params) = 0.1,
# then rms(diff) = 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2 = 0.0004. So var(diff per minibatch)
# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04.
# Suggested lr_schedule?
#
# .. 6e-05 is 1/5 of that...