From ddceb7963b5b608438c7ad1403501a937bd41ab4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 23 Jul 2022 15:27:48 +0800 Subject: [PATCH] Interpolate between iterative estimate of scale, and original value. --- egs/librispeech/ASR/pruned_transducer_stateless7/optim.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index f6539e1f1..a0d9eacc9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -757,6 +757,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of (batch_size, num_blocks, block_size, block_size) = Q.shape scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1) + + # Geometrically interpolate scale with P_proj[dim].sqrt() + scale = (scale * P_proj[dim].reshape(batch_size, num_blocks, block_size, 1).sqrt()).sqrt() + # The following normalization step will ensure the Frobenius # norm is unchanged, from applying this scale: at least, # assuming "grad / denom" gives uncorrelated outputs so that @@ -2163,7 +2167,7 @@ def _test_eve_cain(): start = timeit.default_timer() avg_loss = 0.0 - for epoch in range(150): + for epoch in range(180): scheduler.step_epoch() #if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash.