From 25724b5ce9f786f644e662de6e2636add523ce89 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:49:35 +0800 Subject: [PATCH] Bug-fix RE sign of target_rms --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index b17ebba7c..2b40dda45 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -137,9 +137,10 @@ class Eve(Optimizer): delta = exp_avg / denom if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" (which are scalar). - is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5))) - p.mul_(1 - (weight_decay * is_below_target_rms)) + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5))) + p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) return loss @@ -149,5 +150,6 @@ class Eve(Optimizer): # if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1, # then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch) # = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04. +# Suggested lr_schedule? # # .. 6e-05 is 1/5 of that...