Merge 1aa6fc0122981d3a248b08edd2bf4d71f5e27bd0 into abd9437e6d5419a497707748eb935e50976c3b7b

Zengwei Yao 2025-06-27 11:33:00 +00:00 committed by GitHub
commit a84bd6ae1f
2 changed files with 133 additions and 38 deletions

View File

@@ -16,6 +16,7 @@
import contextlib
import logging
import math
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union
@@ -163,6 +164,9 @@ class ScaledAdam(BatchedOptimizer):
of the parameter tensor. This is provided to save a little time
in the update.
clipping_update_period: if clipping_scale is specified, this is the period
percentile_limit: Parameter absolute values above the (1 - percentile_limit)
percentile (e.g., the 95th percentile for percentile_limit=0.05) will be limited.
p_limit_values: The probability (e.g., 0.1) of modifying the update sign, so as to
limit parameter absolute values that exceed that percentile.
""" """
def __init__( def __init__(
@ -180,6 +184,8 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period=100, clipping_update_period=100,
parameters_names=None, parameters_names=None,
show_dominant_parameters=True, show_dominant_parameters=True,
percentile_limit=0.05,
p_limit_values=0.0,
):
assert parameters_names is not None, (
"Please prepare parameters_names,"
@@ -197,6 +203,8 @@ class ScaledAdam(BatchedOptimizer):
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period,
percentile_limit=percentile_limit,
p_limit_values=p_limit_values,
)
super(ScaledAdam, self).__init__(params, defaults)
@@ -292,6 +300,9 @@ class ScaledAdam(BatchedOptimizer):
size_update_period, *param_rms.shape, **kwargs
)
if group["p_limit_values"] > 0:
state["stored_percentiles"] = torch.ones_like(param_rms)
# exp_avg_sq is the weighted sum of scaled gradients, as in Adam.
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
@@ -598,7 +609,12 @@ class ScaledAdam(BatchedOptimizer):
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
delta = state["delta"]
if random.random() >= group["p_limit_values"]:
delta.add_(grad * alpha)
else:
delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state))
p.add_(delta)
def _step_scalar(self, group: dict, p: Tensor, state: dict):
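For intuition, here is a minimal standalone sketch (not the optimizer's code) of what the branch above does elementwise: with probability p_limit_values the step is multiplied by a ±1 mask, flipping the sign only for entries that are already large in magnitude and whose step would push them further from zero. The tensors and learning rate below are made up for illustration.

```python
import torch

# Toy illustration (not the commit's code): flip the step sign for entries
# that are already large in magnitude and would grow further.
p = torch.tensor([12.0, -11.0, 0.3])     # two large-magnitude entries, one small
grad = torch.tensor([-1.0, 1.0, 1.0])
lr = 0.1
step = -lr * grad                        # simplified step, ignoring RMS scaling and momentum

too_large = p.abs() > 10.0               # stand-in for exceeding stored_percentiles
grows = (p.sign() * grad.sign()) < 0     # the step would increase |p|
limit_sign = 1 - 2 * (too_large & grows).to(step.dtype)

print(p + step)               # tensor([12.1000, -11.1000,  0.2000]): |p| grows
print(p + step * limit_sign)  # tensor([11.9000, -10.9000,  0.2000]): |p| shrinks instead
```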
@@ -625,6 +641,67 @@ class ScaledAdam(BatchedOptimizer):
p.clamp_(min=-scalar_max, max=scalar_max)
p.add_(delta)
def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
"""Decide whether to modify the sign of the cerrent update.
For each parameter tensor in p, we store the centain percentile (e.g., 95%),
say stored_percentiles.
If more than 5% parameter absolute values are larger than stored_percentiles:
Increase stored_percentiles by (1.0 + lr / p_limit_values).
Else:
Decrease stored_percentiles by (1.0 - lr / p_limit_values).
For all parameters whose absolute values are larger than stored_percentiles,
modify the sign of the update so that it points 'inward'.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
Returns: A tensor with same shape as p, filled with 1 or -1.
"""
lr = group["lr"]
# The probability of modifying the update sign to limit large absolute values
p_limit_values = group["p_limit_values"]  # e.g., 0.1
# Parameter absolute values above the (1 - percentile_limit) percentile will be limited.
percentile_limit = group["percentile_limit"]  # e.g., 0.05
# It stores the (1 - percentile_limit) percentiles of the parameter absolute values,
# with a shape like (batch_size, 1, 1, 1, 1)
stored_percentiles = state["stored_percentiles"]
batch_size = p.shape[0]
numel = p.numel() // batch_size
p_abs = p.abs()
p_exceed = p_abs > stored_percentiles # same shape as p
# The proportion that exceeds stored_percentiles
proportion_exceed = (
p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
)
# Update stored_percentiles
update_sign = (proportion_exceed - percentile_limit).sign()
stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
min=1.0e-20
)
# For those parameters that exceed stored_percentiles, flip the update sign
# if they would otherwise get even larger absolute values
# (additionally requiring the absolute value to exceed 10)
limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10)
# TODO: will remove the logging below
if random.random() < 0.1:
mask = limit_sign == -1
if mask.any():
logging.info(f"p.shape: {p.shape}")
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
logging.info(f"proportion_sign_change: {proportion_sign_change}")
return limit_sign
class LRScheduler(object):
"""

View File

@@ -376,6 +376,22 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--p-limit-values",
type=float,
default=0.0,
help="""The probability (e.g., 0.1) to modify the update sign, so as to limit the
parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.""",
)
parser.add_argument(
"--percentile-limit",
type=float,
default=0.05,
help="""The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile
will be limited.""",
)
add_model_arguments(parser)
return parser
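A quick hedged sketch of how the two new options map from the command line to Python attributes; a standalone argparse parser is used here rather than the real get_parser(), and only the two flags added in this commit are reproduced.

```python
import argparse

# Standalone sketch: mirrors only the two add_argument calls added above.
parser = argparse.ArgumentParser()
parser.add_argument("--p-limit-values", type=float, default=0.0)
parser.add_argument("--percentile-limit", type=float, default=0.05)

args = parser.parse_args(["--p-limit-values", "0.1"])
print(args.p_limit_values)    # 0.1  (dashes become underscores)
print(args.percentile_limit)  # 0.05 (default: limit above the 95th percentile)
```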
@@ -1009,6 +1025,8 @@ def run(rank, world_size, args):
lr=params.base_lr,
clipping_scale=2.0,
parameters_names=parameters_names,
p_limit_values=params.p_limit_values,
percentile_limit=params.percentile_limit,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
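Putting the pieces together, a hedged usage sketch of constructing the optimizer with the new options outside of the training script. The toy model, hyperparameter values, and the optim module path are assumptions for illustration, not part of the commit.

```python
import torch
from optim import ScaledAdam  # assumed module path for the optimizer file modified above

model = torch.nn.Linear(80, 512)  # toy stand-in for the real model

# ScaledAdam requires the parameter names (see the assert on parameters_names).
parameters_names = [[name for name, _ in model.named_parameters()]]

optimizer = ScaledAdam(
    model.parameters(),
    lr=0.045,                 # illustrative base learning rate
    clipping_scale=2.0,
    parameters_names=parameters_names,
    # New in this commit: with probability 0.1, flip the update sign for entries
    # whose absolute value exceeds the tracked 95th percentile.
    p_limit_values=0.1,
    percentile_limit=0.05,
)
```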