correct errors in _limit_values_sign

This commit is contained in:
yaozengwei 2023-03-22 19:31:10 +08:00
parent cbfd459df7
commit 8c91b9d0cd
2 changed files with 50 additions and 32 deletions

View File

@ -164,11 +164,9 @@ class ScaledAdam(BatchedOptimizer):
of the parameter tensor. This is provided to save a little time
in the update.
clipping_update_period: if clipping_scale is specified, this is the period
p_limit_values: The probability (e.g., 0.1) to modify the update sign so as to prevent
absolute-values of any weight tensor from being over a certain percentile of
the distribution of that parameter tensor's absolute values.
percentile_limit: The percentile (e.g., 0.9) over which the parameter absolute values would be
limited.
percentile_limit: The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile will be limited.
p_limit_values: The probability (e.g., 0.1) to modify the update sign, so as to limit the
parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.
"""
def __init__(
@ -186,8 +184,8 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True,
percentile_limit=0.05,
p_limit_values=0.0,
percentile_limit=0.9,
):
assert parameters_names is not None, (
@ -206,8 +204,8 @@ class ScaledAdam(BatchedOptimizer):
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period,
p_limit_values=p_limit_values,
percentile_limit=percentile_limit,
p_limit_values=p_limit_values,
)
super(ScaledAdam, self).__init__(params, defaults)
@ -649,7 +647,16 @@ class ScaledAdam(BatchedOptimizer):
p.add_(delta)
def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
"""Decide whether to modify the sign of the update.
"""Decide whether to modify the sign of the cerrent update.
For each parameter tensor in p, we store the centain percentile (e.g., 95%),
say stored_percentiles.
If more than 5% parameter absolute values are larger than stored_percentiles:
Increase stored_percentiles by (1.0 + lr / p_limit_values).
Else:
Decrease stored_percentiles by (1.0 - lr / p_limit_values).
For all parameters whose absolute values are larger than stored_percentiles,
modify the sign of the update so that it points 'inward'.
Args:
group: A dict which will be used to look up configuration values
@ -660,37 +667,40 @@ class ScaledAdam(BatchedOptimizer):
Returns: A tensor with same shape as p, filled with 1 or -1.
"""
lr = group["lr"]
# The probability to limit the absolute values
p_limit_values = group["p_limit_values"] # e.g., 0.1
percentile_limit = group["percentile_limit"] # e.g., 0.9
# it has a shape like (batch_size, 1, 1, 1, 1)
# The parameter absolute values over 1-percentile_limit percentile will be limited.
percentile_limit = group["percentile_limit"] # e.g., 0.05
# It stores the percentiles (i.e, 1-percentile_limit) of the parameter absolute values,
# with a shape like (batch_size, 1, 1, 1, 1)
stored_percentiles = state["stored_percentiles"]
p_abs = p.abs()
dtype = p.dtype
batch_size = p.shape[0]
numel = p.numel() // batch_size
p_abs = p.abs()
numel = p.numel() / batch_size
k = math.ceil(numel * (1 - percentile_limit))
percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1] # (batch,)
# If True, stored_percentiles should be increased
percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles
p_exceed = p_abs > stored_percentiles # same shape as p
# The proportion that exceeds stored_percentiles
proportion_exceed = p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
# Update store_percentiles
update_sign = (percentiles_exceed.to(dtype) - 0.5).sign()
update_sign = (proportion_exceed - percentile_limit).sign()
stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)
p_exceed = p_abs > stored_percentiles
# if random.random() < 0.1:
# # print(stored_percentiles)
# # print(percentiles_exceed)
# print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel)
# For these parameters that exceed stored_percentile,
# flip the update sign if they would get larger absolute values
limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0)
# Decide whether to change grad sign
limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0)
limit_sign = (limit_sign.to(dtype) - 0.5).sign()
# TODO: will remove the log bellow
if random.random() < 0.1:
logging.info(f"p.shape: {p.shape}")
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
mask = limit_sign == -1
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
logging.info(f"proportion_sign_change: {proportion_sign_change}")
return -1 * limit_sign
return limit_sign
class LRScheduler(object):

View File

@ -378,9 +378,16 @@ def get_parser():
"--p-limit-values",
type=float,
default=0.0,
help="""The probability (e.g., 0.1) to modify the update sign so as to prevent
absolute-values of any weight tensor from being over a certain percentile of
the distribution of that parameter tensor's absolute values""",
help="""The probability (e.g., 0.1) to modify the update sign, so as to limit the
parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.""",
)
parser.add_argument(
"--percentile-limit",
type=float,
default=0.05,
help="""The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile
will be limited.""",
)
add_model_arguments(parser)
@ -1026,6 +1033,7 @@ def run(rank, world_size, args):
clipping_scale=2.0,
parameters_names=parameters_names,
p_limit_values=params.p_limit_values,
percentile_limit=params.percentile_limit,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)