mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
add condition of p_abs > 10
This commit is contained in:
parent
8c91b9d0cd
commit
5e58d2de75
@ -688,17 +688,18 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
# For these parameters that exceed stored_percentile,
|
||||
# flip the update sign if they would get larger absolute values
|
||||
limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0)
|
||||
limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10)
|
||||
|
||||
# TODO: will remove the log bellow
|
||||
if random.random() < 0.1:
|
||||
logging.info(f"p.shape: {p.shape}")
|
||||
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
|
||||
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
|
||||
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
|
||||
mask = limit_sign == -1
|
||||
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
|
||||
logging.info(f"proportion_sign_change: {proportion_sign_change}")
|
||||
if mask.any():
|
||||
logging.info(f"p.shape: {p.shape}")
|
||||
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
|
||||
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
|
||||
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
|
||||
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
|
||||
logging.info(f"proportion_sign_change: {proportion_sign_change}")
|
||||
|
||||
return limit_sign
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user