add _limit_values_sign in ScaledAdam

yaozengwei 2023-03-18 21:47:08 +08:00
parent 6196b4a407
commit cbfd459df7
2 changed files with 73 additions and 1 deletion


@@ -16,6 +16,7 @@
 import contextlib
 import logging
+import math
 import random
 from collections import defaultdict
 from typing import List, Optional, Tuple, Union
@@ -163,6 +164,11 @@ class ScaledAdam(BatchedOptimizer):
            of the parameter tensor.  This is provided to save a little time
            in the update.
     clipping_update_period: if clipping_scale is specified, this is the period
+    p_limit_values: The probability (e.g., 0.1) with which we modify the update sign,
+           so as to prevent the absolute values of any weight tensor from exceeding a
+           certain percentile of the distribution of that tensor's absolute values.
+    percentile_limit: The percentile (e.g., 0.9) above which parameter absolute
+           values are limited.
     """

     def __init__(
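For intuition: with percentile_limit = 0.9, the limiter tracks the 90th percentile of each tensor's absolute values. A minimal sketch (hypothetical tensor) of how that threshold is obtained, mirroring the topk computation in _limit_values_sign below:

    import math
    import torch

    w = torch.randn(1, 1000)  # a batch holding one 1000-element weight tensor
    percentile_limit = 0.9
    # the k-th largest absolute value is the percentile_limit quantile
    k = math.ceil(w[0].numel() * (1 - percentile_limit))  # k = 100
    threshold = w.abs().view(1, -1).topk(k=k, dim=-1)[0][:, -1]  # shape (1,)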
@@ -180,6 +186,8 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period=100,
         parameters_names=None,
         show_dominant_parameters=True,
+        p_limit_values=0.0,
+        percentile_limit=0.9,
     ):
         assert parameters_names is not None, (
@@ -198,6 +206,8 @@ class ScaledAdam(BatchedOptimizer):
             scalar_max=scalar_max,
             size_update_period=size_update_period,
             clipping_update_period=clipping_update_period,
+            p_limit_values=p_limit_values,
+            percentile_limit=percentile_limit,
         )
         super(ScaledAdam, self).__init__(params, defaults)
@@ -296,6 +306,9 @@ class ScaledAdam(BatchedOptimizer):
                 size_update_period, *param_rms.shape, **kwargs
             )
+            if group["p_limit_values"] > 0:
+                state["stored_percentiles"] = torch.ones_like(param_rms)
+
             # exp_avg_sq is the weighted sum of scaled gradients, as in Adam.
             state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
@@ -603,7 +616,12 @@ class ScaledAdam(BatchedOptimizer):
         alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

         delta = state["delta"]
-        delta.add_(grad * alpha)
+        if random.random() >= group["p_limit_values"]:
+            delta.add_(grad * alpha)
+        else:
+            delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state))
         p.add_(delta)

     def _step_scalar(self, group: dict, p: Tensor, state: dict):
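The gate is a per-step Bernoulli draw: with probability 1 - p_limit_values the update is the usual one, and otherwise it is multiplied elementwise by the +/-1 mask from _limit_values_sign. A standalone sketch with hypothetical values:

    import random
    import torch

    p_limit_values = 0.1
    delta = torch.zeros(4)
    update = torch.tensor([0.2, -0.1, 0.3, -0.4])     # plays the role of grad * alpha
    sign_mask = torch.tensor([-1.0, 1.0, 1.0, -1.0])  # pretend _limit_values_sign output
    if random.random() >= p_limit_values:
        delta.add_(update)              # usual path
    else:
        delta.add_(update * sign_mask)  # flipped where a weight would exceed the percentile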
@@ -630,6 +648,50 @@ class ScaledAdam(BatchedOptimizer):
         p.clamp_(min=-scalar_max, max=scalar_max)
         p.add_(delta)

+    def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
+        """Decide whether to modify the sign of the update.
+
+        Args:
+            group: A dict used to look up configuration values
+            p: The parameter to be updated
+            grad: The grad of p
+            state: The state-dict corresponding to parameter p
+
+        Returns: A tensor with the same shape as p, filled with 1 or -1.
+        """
+        lr = group["lr"]
+        p_limit_values = group["p_limit_values"]  # e.g., 0.1
+        percentile_limit = group["percentile_limit"]  # e.g., 0.9
+        # stored_percentiles has a shape like (batch_size, 1, 1, 1, 1)
+        stored_percentiles = state["stored_percentiles"]
+
+        p_abs = p.abs()
+        dtype = p.dtype
+        batch_size = p.shape[0]
+        numel = p.numel() // batch_size
+        k = math.ceil(numel * (1 - percentile_limit))
+        percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1]  # (batch,)
+        # If True, stored_percentiles should be increased
+        percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles
+        # Update stored_percentiles
+        update_sign = (percentiles_exceed.to(dtype) - 0.5).sign()
+        stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)
+
+        p_exceed = p_abs > stored_percentiles
+        # if random.random() < 0.1:
+        #     # print(stored_percentiles)
+        #     # print(percentiles_exceed)
+        #     print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel)
+
+        # Decide whether to change the grad sign
+        limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0)
+        limit_sign = (limit_sign.to(dtype) - 0.5).sign()
+        return -1 * limit_sign
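An update gets its sign flipped only when three conditions hold at once: the tensor's current percentile has not exceeded the stored estimate, that particular weight's absolute value is over the stored estimate, and the plain update (proportional to -grad, since alpha < 0) would push the weight further from zero. A tiny worked example with hypothetical numbers:

    import torch

    p = torch.tensor([[2.0, -0.1, 1.5, -2.5]])     # batch of one 4-element tensor
    grad = torch.tensor([[-1.0, 1.0, 1.0, 1.0]])
    stored = torch.tensor([[1.0]])                 # stored percentile estimate
    percentiles_exceed = torch.tensor([[False]])   # current percentile below stored
    p_exceed = p.abs() > stored                    # [[True, False, True, True]]
    away_from_zero = (p.sign() * grad.sign()) < 0  # an update ~ -grad grows |p| here
    limit = (~percentiles_exceed * p_exceed) * away_from_zero
    mask = -1 * (limit.to(p.dtype) - 0.5).sign()   # [[-1., 1., 1., -1.]]
    # Only p[0, 0] and p[0, 3] are over the threshold AND moving away from zero,
    # so only their updates are sign-flipped.

Note also that stored_percentiles is nudged by a factor of (1 +/- lr / p_limit_values) on each call; since the routine only runs with probability p_limit_values per step, the expected per-step drift of the estimate is on the order of lr, which is presumably why the factor is divided by p_limit_values.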
 class LRScheduler(object):
     """


@@ -374,6 +374,15 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--p-limit-values",
+        type=float,
+        default=0.0,
+        help="""The probability (e.g., 0.1) with which we modify the update sign, so as
+        to prevent the absolute values of any weight tensor from exceeding a certain
+        percentile of the distribution of that tensor's absolute values""",
+    )
+
     add_model_arguments(parser)

     return parser
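A minimal standalone sketch of the new flag (the real get_parser() defines many other options; this just mirrors the add_argument call above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--p-limit-values", type=float, default=0.0)
    args = parser.parse_args(["--p-limit-values", "0.1"])
    print(args.p_limit_values)  # 0.1; argparse maps dashes to underscores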
@@ -1016,6 +1025,7 @@ def run(rank, world_size, args):
         lr=params.base_lr,
         clipping_scale=2.0,
         parameters_names=parameters_names,
+        p_limit_values=params.p_limit_values,
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
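Putting it together: the flag flows from the command line into the optimizer, while percentile_limit is not exposed by the script and so keeps its 0.9 default. A hedged sketch of the construction as it appears in run() above (model, params, and parameters_names are assumed from the surrounding training code):

    optimizer = ScaledAdam(
        model.parameters(),
        lr=params.base_lr,
        clipping_scale=2.0,
        parameters_names=parameters_names,
        p_limit_values=params.p_limit_values,  # the default 0.0 disables the limiter
    )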