From 7f6fe02db9ee0cdc47ab1472c46c87735d99fd0a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 14 Jul 2022 06:06:02 +0800 Subject: [PATCH] Fix formula for smoothing (was applying more smoothing than intended, and in the opposite sense to intended), also revert max_rms from 2.0 to 4.0 --- .../ASR/pruned_transducer_stateless7/optim.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 009755859..75b975711 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -953,13 +953,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of batch_size = rms.shape[0] size = rms.numel() // batch_size rms = rms ** param_pow - smooth = (smooth0 + - (smooth1 - smooth0) * size / (size + rank)) + + + # want expr to be of the form: smooth = alpha * size / (beta*rank + size) + # from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0. + # from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta), + # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1 + smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size) + mean = _mean(rms, exclude_dims=[0], keepdim=True) rms += eps + smooth * mean new_mean = (eps + (smooth + 1) * mean) # mean of modified rms. ans = rms / new_mean + if True: # Apply max_rms max_rms = 4.0