From 5e58d2de75cd619cf2b4eb7e1c1e3a51f49d377f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 23 Mar 2023 20:06:03 +0800 Subject: [PATCH] add condition of p_abs > 10 --- .../ASR/pruned_transducer_stateless7/optim.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 4c9010124..f2fa6e85a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -688,17 +688,18 @@ class ScaledAdam(BatchedOptimizer): # For these parameters that exceed stored_percentile, # flip the update sign if they would get larger absolute values - limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) + limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10) # TODO: will remove the log bellow if random.random() < 0.1: - logging.info(f"p.shape: {p.shape}") - logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") - logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") - logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") mask = limit_sign == -1 - proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel - logging.info(f"proportion_sign_change: {proportion_sign_change}") + if mask.any(): + logging.info(f"p.shape: {p.shape}") + logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") + logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") + logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") + proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel + logging.info(f"proportion_sign_change: {proportion_sign_change}") return limit_sign