diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 4c9010124..f2fa6e85a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -688,17 +688,18 @@ class ScaledAdam(BatchedOptimizer): # For these parameters that exceed stored_percentile, # flip the update sign if they would get larger absolute values - limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) + limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10) # TODO: will remove the log bellow if random.random() < 0.1: - logging.info(f"p.shape: {p.shape}") - logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") - logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") - logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") mask = limit_sign == -1 - proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel - logging.info(f"proportion_sign_change: {proportion_sign_change}") + if mask.any(): + logging.info(f"p.shape: {p.shape}") + logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") + logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") + logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") + proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel + logging.info(f"proportion_sign_change: {proportion_sign_change}") return limit_sign