correct errors in _limit_values_sign

This commit is contained in:
yaozengwei 2023-03-22 19:31:10 +08:00
parent cbfd459df7
commit 8c91b9d0cd
2 changed files with 50 additions and 32 deletions

View File

@ -164,11 +164,9 @@ class ScaledAdam(BatchedOptimizer):
of the parameter tensor. This is provided to save a little time of the parameter tensor. This is provided to save a little time
in the update. in the update.
clipping_update_period: if clipping_scale is specified, this is the period clipping_update_period: if clipping_scale is specified, this is the period
p_limit_values: The probability (e.g., 0.1) to modify the update sign so as to prevent percentile_limit: The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile will be limited.
absolute-values of any weight tensor from being over a certain percentile of p_limit_values: The probability (e.g., 0.1) to modify the update sign, so as to limit the
the distribution of that parameter tensor's absolute values. parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.
percentile_limit: The percentile (e.g., 0.9) over which the parameter absolute values would be
limited.
""" """
def __init__( def __init__(
@ -186,8 +184,8 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period=100, clipping_update_period=100,
parameters_names=None, parameters_names=None,
show_dominant_parameters=True, show_dominant_parameters=True,
percentile_limit=0.05,
p_limit_values=0.0, p_limit_values=0.0,
percentile_limit=0.9,
): ):
assert parameters_names is not None, ( assert parameters_names is not None, (
@ -206,8 +204,8 @@ class ScaledAdam(BatchedOptimizer):
scalar_max=scalar_max, scalar_max=scalar_max,
size_update_period=size_update_period, size_update_period=size_update_period,
clipping_update_period=clipping_update_period, clipping_update_period=clipping_update_period,
p_limit_values=p_limit_values,
percentile_limit=percentile_limit, percentile_limit=percentile_limit,
p_limit_values=p_limit_values,
) )
super(ScaledAdam, self).__init__(params, defaults) super(ScaledAdam, self).__init__(params, defaults)
@ -649,7 +647,16 @@ class ScaledAdam(BatchedOptimizer):
p.add_(delta) p.add_(delta)
def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict): def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
"""Decide whether to modify the sign of the update. """Decide whether to modify the sign of the current update.
For each parameter tensor in p, we store the certain percentile (e.g., 95%),
say stored_percentiles.
If more than 5% parameter absolute values are larger than stored_percentiles:
Increase stored_percentiles by (1.0 + lr / p_limit_values).
Else:
Decrease stored_percentiles by (1.0 - lr / p_limit_values).
For all parameters whose absolute values are larger than stored_percentiles,
modify the sign of the update so that it points 'inward'.
Args: Args:
group: A dict which will be used to look up configuration values group: A dict which will be used to look up configuration values
@ -660,37 +667,40 @@ class ScaledAdam(BatchedOptimizer):
Returns: A tensor with same shape as p, filled with 1 or -1. Returns: A tensor with same shape as p, filled with 1 or -1.
""" """
lr = group["lr"] lr = group["lr"]
# The probability to limit the absolute values
p_limit_values = group["p_limit_values"] # e.g., 0.1 p_limit_values = group["p_limit_values"] # e.g., 0.1
percentile_limit = group["percentile_limit"] # e.g., 0.9 # The parameter absolute values over 1-percentile_limit percentile will be limited.
# it has a shape like (batch_size, 1, 1, 1, 1) percentile_limit = group["percentile_limit"] # e.g., 0.05
# It stores the percentiles (i.e, 1-percentile_limit) of the parameter absolute values,
# with a shape like (batch_size, 1, 1, 1, 1)
stored_percentiles = state["stored_percentiles"] stored_percentiles = state["stored_percentiles"]
p_abs = p.abs()
dtype = p.dtype
batch_size = p.shape[0] batch_size = p.shape[0]
numel = p.numel() // batch_size
p_abs = p.abs()
numel = p.numel() / batch_size p_exceed = p_abs > stored_percentiles # same shape as p
k = math.ceil(numel * (1 - percentile_limit)) # The proportion that exceeds stored_percentiles
percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1] # (batch,) proportion_exceed = p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
# If True, stored_percentiles should be increased
percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles
# Update stored_percentiles # Update stored_percentiles
update_sign = (percentiles_exceed.to(dtype) - 0.5).sign() update_sign = (proportion_exceed - percentile_limit).sign()
stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20) stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)
p_exceed = p_abs > stored_percentiles # For these parameters that exceed stored_percentile,
# if random.random() < 0.1: # flip the update sign if they would get larger absolute values
# # print(stored_percentiles) limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0)
# # print(percentiles_exceed)
# print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel)
# Decide whether to change grad sign # TODO: will remove the log below
limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0) if random.random() < 0.1:
limit_sign = (limit_sign.to(dtype) - 0.5).sign() logging.info(f"p.shape: {p.shape}")
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
mask = limit_sign == -1
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
logging.info(f"proportion_sign_change: {proportion_sign_change}")
return -1 * limit_sign return limit_sign
class LRScheduler(object): class LRScheduler(object):

View File

@ -378,9 +378,16 @@ def get_parser():
"--p-limit-values", "--p-limit-values",
type=float, type=float,
default=0.0, default=0.0,
help="""The probability (e.g., 0.1) to modify the update sign so as to prevent help="""The probability (e.g., 0.1) to modify the update sign, so as to limit the
absolute-values of any weight tensor from being over a certain percentile of parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.""",
the distribution of that parameter tensor's absolute values""", )
parser.add_argument(
"--percentile-limit",
type=float,
default=0.05,
help="""The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile
will be limited.""",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -1026,6 +1033,7 @@ def run(rank, world_size, args):
clipping_scale=2.0, clipping_scale=2.0,
parameters_names=parameters_names, parameters_names=parameters_names,
p_limit_values=params.p_limit_values, p_limit_values=params.p_limit_values,
percentile_limit=params.percentile_limit,
) )
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)