add condition of p_abs > 10

This commit is contained in:
yaozengwei 2023-03-23 20:06:03 +08:00
parent 8c91b9d0cd
commit 5e58d2de75

View File

@ -688,17 +688,18 @@ class ScaledAdam(BatchedOptimizer):
# For these parameters that exceed stored_percentile,
# flip the update sign if they would get larger absolute values
limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0)
limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10)
# TODO: will remove the log bellow
if random.random() < 0.1:
logging.info(f"p.shape: {p.shape}")
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
mask = limit_sign == -1
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
logging.info(f"proportion_sign_change: {proportion_sign_change}")
if mask.any():
logging.info(f"p.shape: {p.shape}")
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
logging.info(f"proportion_sign_change: {proportion_sign_change}")
return limit_sign