diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index a4a58a0c7..0ef4c7d19 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -58,6 +58,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of learning the scale on the parameters (we'll keep it >= this size) param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of learning the scale on the parameters (we'll keep it <= this size) + scalar_max: Maximum absolute value for scalar parameters size_update_period: The periodicity, in steps, with which we update the size (scale) of the parameter tensor. This is provided to save a little time. lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices. @@ -75,6 +76,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of eps=1.0e-08, param_min_rms=1.0e-05, param_max_rms=2.0, + scalar_max=2.0, size_update_period=4, lr_update_period=200, grad_cov_period=3, @@ -92,6 +94,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of eps=eps, param_min_rms=param_min_rms, param_max_rms=param_max_rms, + scalar_max=scalar_max, size_update_period=size_update_period, lr_update_period=lr_update_period, grad_cov_period=grad_cov_period, @@ -514,9 +517,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of - - - def _step(self, group: dict, p: Tensor, @@ -618,7 +618,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of delta = state["delta"] delta.add_(grad / denom, alpha=alpha) - + scalar_max = state["scalar_max"] + p.clamp_(min=-scalar_max, max=scalar_max) def _project(self, x: Tensor,