diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 009755859..75b975711 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -953,13 +953,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of batch_size = rms.shape[0] size = rms.numel() // batch_size rms = rms ** param_pow - smooth = (smooth0 + - (smooth1 - smooth0) * size / (size + rank)) + + + # want expr to be of the form: smooth = alpha * size / (beta*rank + size) + # from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0. + # from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta), + # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1 + smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size) + mean = _mean(rms, exclude_dims=[0], keepdim=True) rms += eps + smooth * mean new_mean = (eps + (smooth + 1) * mean) # mean of modified rms. ans = rms / new_mean + if True: # Apply max_rms max_rms = 4.0