mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00

add _limit_values_sign in ScaledAdam

parent 6196b4a407
commit cbfd459df7
@@ -16,6 +16,7 @@
 import contextlib
 import logging
+import math
 import random
 from collections import defaultdict
 from typing import List, Optional, Tuple, Union
@@ -163,6 +164,11 @@ class ScaledAdam(BatchedOptimizer):
            of the parameter tensor.  This is provided to save a little time
            in the update.
     clipping_update_period: if clipping_scale is specified, this is the period
+    p_limit_values: The probability (e.g., 0.1) to modify the update sign so as to prevent
+        absolute-values of any weight tensor from being over a certain percentile of
+        the distribution of that parameter tensor's absolute values.
+    percentile_limit: The percentile (e.g., 0.9) over which the parameter absolute values would be
+        limited.
    """

    def __init__(
@@ -180,6 +186,8 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period=100,
         parameters_names=None,
         show_dominant_parameters=True,
+        p_limit_values=0.0,
+        percentile_limit=0.9,
     ):

         assert parameters_names is not None, (
@@ -198,6 +206,8 @@ class ScaledAdam(BatchedOptimizer):
             scalar_max=scalar_max,
             size_update_period=size_update_period,
             clipping_update_period=clipping_update_period,
+            p_limit_values=p_limit_values,
+            percentile_limit=percentile_limit,
         )

         super(ScaledAdam, self).__init__(params, defaults)
@@ -296,6 +306,9 @@ class ScaledAdam(BatchedOptimizer):
                 size_update_period, *param_rms.shape, **kwargs
             )

+            if group["p_limit_values"] > 0:
+                state["stored_percentiles"] = torch.ones_like(param_rms)
+
             # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
             state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
@@ -603,7 +616,12 @@ class ScaledAdam(BatchedOptimizer):
             alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

         delta = state["delta"]
-        delta.add_(grad * alpha)
+
+        if random.random() >= group["p_limit_values"]:
+            delta.add_(grad * alpha)
+        else:
+            delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state))
+
         p.add_(delta)

     def _step_scalar(self, group: dict, p: Tensor, state: dict):
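Note on the gating above: the else branch, and hence the sign mask produced by _limit_values_sign (added in the next hunk), is taken on only about a p_limit_values fraction of steps, so the limiting is applied sparsely rather than on every update. A minimal standalone sketch of the pattern, with illustrative stand-in values rather than the optimizer's internals:

# Sketch only: `grad_alpha` and `limit_sign` are stand-ins for grad * alpha
# and for the +/-1 mask returned by _limit_values_sign.
import random

import torch

p_limit_values = 0.1                                       # probability of applying the sign mask
delta = torch.zeros(4, 8)
grad_alpha = torch.randn(4, 8)                             # stands in for grad * alpha
limit_sign = torch.randint(0, 2, (4, 8)).float() * 2 - 1   # +/-1 mask, stands in for the real one

if random.random() >= p_limit_values:    # most steps: the usual ScaledAdam update
    delta.add_(grad_alpha)
else:                                    # occasionally: flip the update sign element-wise
    delta.add_(grad_alpha * limit_sign)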
@@ -630,6 +648,50 @@ class ScaledAdam(BatchedOptimizer):
         p.clamp_(min=-scalar_max, max=scalar_max)
         p.add_(delta)

+    def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
+        """Decide whether to modify the sign of the update.
+
+        Args:
+            group: A dict which will be used to look up configuration values
+            p: The parameter to be updated
+            grad: The grad of p
+            state: The state-dict corresponding to parameter p
+
+        Returns: A tensor with same shape as p, filled with 1 or -1.
+        """
+        lr = group["lr"]
+        p_limit_values = group["p_limit_values"]  # e.g., 0.1
+        percentile_limit = group["percentile_limit"]  # e.g., 0.9
+        # it has a shape like (batch_size, 1, 1, 1, 1)
+        stored_percentiles = state["stored_percentiles"]
+
+        p_abs = p.abs()
+        dtype = p.dtype
+        batch_size = p.shape[0]
+
+        numel = p.numel() / batch_size
+        k = math.ceil(numel * (1 - percentile_limit))
+        percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1]  # (batch,)
+
+        # If True, stored_percentiles should be increased
+        percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles
+
+        # Update stored_percentiles
+        update_sign = (percentiles_exceed.to(dtype) - 0.5).sign()
+        stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)
+
+        p_exceed = p_abs > stored_percentiles
+        # if random.random() < 0.1:
+        #     # print(stored_percentiles)
+        #     # print(percentiles_exceed)
+        #     print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel)
+
+        # Decide whether to change grad sign
+        limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0)
+        limit_sign = (limit_sign.to(dtype) - 0.5).sign()
+
+        return -1 * limit_sign
+

 class LRScheduler(object):
     """
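The remaining hunks below touch the training script (get_parser() and run()). Before that, a note on what _limit_values_sign is doing: stored_percentiles is a per-tensor running estimate of the percentile_limit quantile of |p|, nudged up or down multiplicatively each time this branch runs; the returned mask is -1 only for elements whose magnitude already exceeds that estimate, while the tensor's current quantile does not, and whose update would push the magnitude further from zero. A hedged sketch of the percentile computation via topk, using illustrative shapes only:

# Sketch only (not part of the diff): per-tensor percentile of |p| via topk,
# for a batched parameter of shape (batch_size, ...).
import math

import torch

percentile_limit = 0.9
p = torch.randn(3, 5, 7)                        # (batch_size, 5, 7), illustrative
batch_size = p.shape[0]
numel = p.numel() // batch_size                 # elements per tensor in the batch

# the k-th largest |p| is approximately the percentile_limit quantile of |p|
k = math.ceil(numel * (1 - percentile_limit))
percentiles = p.abs().view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1]  # (batch_size,)

# broadcastable like stored_percentiles; elements above it would be candidates for limiting
stored = percentiles.view(batch_size, 1, 1)
print((p.abs() > stored).float().mean(dim=(1, 2)))  # roughly (1 - percentile_limit) per tensor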
@@ -374,6 +374,15 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--p-limit-values",
+        type=float,
+        default=0.0,
+        help="""The probability (e.g., 0.1) to modify the update sign so as to prevent
+        absolute-values of any weight tensor from being over a certain percentile of
+        the distribution of that parameter tensor's absolute values""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -1016,6 +1025,7 @@ def run(rank, world_size, args):
         lr=params.base_lr,
         clipping_scale=2.0,
         parameters_names=parameters_names,
+        p_limit_values=params.p_limit_values,
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
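For reference, a hedged sketch of how the optimizer ends up being constructed in run() with the new option; the first positional argument and the explicit percentile_limit are assumptions for illustration (the diff itself only forwards params.p_limit_values and leaves percentile_limit at its 0.9 default):

# Sketch only: the positional argument is an assumption about the surrounding code.
optimizer = ScaledAdam(
    model.parameters(),                      # assumed; not shown in this diff
    lr=params.base_lr,
    clipping_scale=2.0,
    parameters_names=parameters_names,
    p_limit_values=params.p_limit_values,    # 0.0 disables limiting; e.g. --p-limit-values 0.1
    percentile_limit=0.9,                    # default value from ScaledAdam.__init__
)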