mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
Bug-fix RE sign of target_rms
This commit is contained in:
parent
d1a669162c
commit
25724b5ce9
@@ -137,9 +137,10 @@ class Eve(Optimizer):
|
||||
delta = exp_avg / denom
|
||||
|
||||
if p.numel() > 1:
|
||||
# avoid applying this weight-decay on "scaling factors" (which are scalar).
|
||||
is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5)))
|
||||
p.mul_(1 - (weight_decay * is_below_target_rms))
|
||||
# avoid applying this weight-decay on "scaling factors"
|
||||
# (which are scalar).
|
||||
is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
|
||||
p.mul_(1 - (weight_decay * is_above_target_rms))
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
return loss
|
||||
@@ -149,5 +150,6 @@ class Eve(Optimizer):
|
||||
# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1,
|
||||
# then rms(diff) = 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2 = 0.0004. So var(diff per minibatch)
|
||||
# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04.
|
||||
# Suggested lr_schedule?
|
||||
#
|
||||
# .. 6e-05 is 1/5 of that...
|
||||
|
Loading…
x
Reference in New Issue
Block a user