Remove batching from ScaledAdam, in preparation to add gradient norm clipping
commit 8f876b3f54 (parent 3b450c2682)
@@ -25,73 +25,9 @@ from torch.optim import Optimizer
 import logging
 
-class BatchedOptimizer(Optimizer):
-    """
-    This class adds to class Optimizer the capability to optimize parameters in batches:
-    it will stack the parameters and their grads for you so the optimizer can work
-    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
-    as it reduces the number of kernels launched in the optimizer.
-
-    Args:
-      params:
-    """
-
-    def __init__(self, params, defaults):
-        super(BatchedOptimizer, self).__init__(params, defaults)
-
-    def batched_params(self, param_group):
-        """
-        This function returns an iterator over tuples (p, state), where
-        p is a `fake` parameter that is stacked (over axis 0) from real parameters
-        that share the same shape, and its gradient is also stacked;
-        `state` is the state corresponding to this fake parameter
-        (it will be physically located in the "state" for one of the real
-        parameters, the first one that has any particular shape and dtype).
-
-        The idea is, instead of doing:
-        <code>
-          for p in group["params"]:
-              state = self.state[p]
-              ...
-        </code>
-        you can do:
-        <code>
-          for p, state in self.batched_params(group["params"]):
-              ...
-        </code>
-
-        Args:
-          group: a parameter group, which is a list of parameters; should be
-                 one of self.param_groups.
-        """
-        # `batches` maps from tuple (dtype_as_str, *shape) to a list of nn.Parameter.
-        batches = defaultdict(list)
-
-        for p in param_group:
-            assert p.grad is not None, "All parameters must have a grad when you call optimizer.step()"
-            assert not p.grad.is_sparse, "Sparse gradients not supported."
-            key = (str(p.dtype), *p.shape)
-            batches[key].append(p)
-
-        for p in param_group:
-            key = (str(p.dtype), *p.shape)
-            batch = batches[key]
-            if p is batch[0]:  # if this is the 1st param in the batch...
-                # We arbitrarily store the state in the state corresponding to
-                # the 1st parameter in the batch; class Optimizer will take
-                # care of saving/loading state.
-                state = self.state[p]
-                p_stacked = torch.stack(batch)
-                grad = torch.stack([p.grad for p in batch])
-                p_stacked.grad = grad
-                yield p_stacked, state  # <-- calling code will do the actual optimization here!
-                # Now un-stack the parameter changes.
-                for i, p in enumerate(batch):
-                    p.copy_(p_stacked[i])
-
-
-class ScaledAdam(BatchedOptimizer):
+class ScaledAdam(Optimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
     proportional to the norm of that parameter; and also learn the scale of the parameter,
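For context on what this commit deletes: batched_params grouped parameters by (dtype, shape), stacked each group into a single tensor with a leading batch dimension so the optimizer launched one kernel per group instead of one per parameter, then un-stacked the results. A minimal standalone sketch of that round trip, using toy parameters and a plain SGD-style update (illustrative only, not icefall code):

<code>
from collections import defaultdict

import torch

# Five toy parameters that share a shape, so they land in one batch.
params = [torch.nn.Parameter(torch.randn(3, 4)) for _ in range(5)]
for p in params:
    p.grad = torch.randn_like(p)

# Group by (dtype_as_str, *shape), exactly as batched_params keyed its dict.
batches = defaultdict(list)
for p in params:
    batches[(str(p.dtype), *p.shape)].append(p)

with torch.no_grad():
    for batch in batches.values():
        p_stacked = torch.stack(batch)               # shape (5, 3, 4)
        grad = torch.stack([p.grad for p in batch])  # stacked grads
        p_stacked -= 0.1 * grad                      # one kernel per op, not per param
        for i, p in enumerate(batch):                # un-stack the update
            p.copy_(p_stacked[i])
</code>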
@@ -170,7 +106,9 @@ class ScaledAdam(BatchedOptimizer):
 
         batch = True
         for group in self.param_groups:
-            for p, state in self.batched_params(group["params"]):
+            for p in group["params"]:
+                state = self.state[p]
 
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
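The commit title says this unbatching is preparation for gradient norm clipping; with an ordinary per-parameter loop, a global-norm clip can run over group["params"] before any update. A hedged sketch of what such a clip could look like (this helper is hypothetical, not part of this commit; PyTorch ships the same logic as torch.nn.utils.clip_grad_norm_):

<code>
import torch

def clip_grad_norm(params, max_norm: float) -> torch.Tensor:
    # Hypothetical helper: rescale all grads so their global L2 norm
    # does not exceed max_norm, before stepping the parameters.
    grads = [p.grad for p in params if p.grad is not None]
    total_norm = torch.norm(torch.stack([g.norm() for g in grads]))
    if total_norm > max_norm:
        for g in grads:
            g.mul_(max_norm / (total_norm + 1e-6))
    return total_norm
</code>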
@@ -191,14 +129,11 @@ class ScaledAdam(BatchedOptimizer):
                     p: Tensor,
                     state: dict):
         """
-        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
-        is actually the batch dimension, corresponding to batched-together
-        parameters of a given shape.
+        Initializes state dict for parameter 'p'.
 
         Args:
            group: Dict to look up configuration values.
-           p: The parameter that we are initializing the state for, actually
-              will be stacked-together "real" parameters
+           p: The parameter that we are initializing the state for.
            state: Dict from string to whatever state we are initializing
         """
         eps = group["eps"]
@@ -218,20 +153,17 @@ class ScaledAdam(BatchedOptimizer):
                 p, memory_format=torch.preserve_format
             )
 
-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
+        numel = p.numel()
 
         if numel > 1:
             # "param_rms" just periodically records the scalar root-mean-square value of
             # the parameter tensor.
-            # it has a shape like (batch_size, 1, 1, 1, 1)
-            param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
-                                    keepdim=True).sqrt() + eps
+            param_rms = (p**2).mean().sqrt() + eps
             state["param_rms"] = param_rms
 
             state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-            state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
-                                               **kwargs)
+            state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
 
         # exp_avg_sq is the weighted sum of squared gradients, as in Adam.
         state["exp_avg_sq"] = torch.zeros_like(
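The param_rms change above is the scalar special case of the old batched reduction: with no leading batch dimension there is nothing to reduce per-row, so a single RMS over the whole tensor suffices. A quick equivalence check on a toy tensor (the eps value is a stand-in for group["eps"]):

<code>
import torch

eps = 1.0e-8  # stand-in for group["eps"]
p = torch.randn(3, 4)

# New (unbatched) form: one scalar RMS for the whole tensor.
rms_new = (p**2).mean().sqrt() + eps

# Old (batched) form applied to a "batch" of one: reduce over every dim
# except dim 0, then index out the single element.
p_batched = p.unsqueeze(0)
rms_old = (p_batched**2).mean(dim=list(range(1, p_batched.ndim)),
                              keepdim=True).sqrt() + eps

assert torch.allclose(rms_new, rms_old.flatten()[0])
</code>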
@@ -244,12 +176,10 @@ class ScaledAdam(BatchedOptimizer):
                     p: Tensor,
                     state: dict):
         """
-        Do the step for one parameter, which is actually going to be a batch of
-        `real` parameters, with dim 0 as the batch dim.
+        Do the step for parameter p.
         Args:
             group: dict to look up configuration values
-            p: parameter to update (actually multiple parameters stacked together
-               as a batch)
+            p: parameter to update
             state: state-dict for p, to look up the optimizer state
         """
         lr = group["lr"]
@@ -262,17 +192,14 @@ class ScaledAdam(BatchedOptimizer):
         delta = state["delta"]
 
         delta.mul_(beta1)
-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
+        numel = p.numel()
         if numel > 1:
             # Update the size/scale of p, and set param_rms
             scale_grads = state["scale_grads"]
-            scale_grads[step % size_update_period] = (p * grad).sum(
-                dim=list(range(1, p.ndim)), keepdim=True)
+            scale_grads[step % size_update_period] = (p * grad).sum()
             if step % size_update_period == size_update_period - 1:
-                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
-                param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
-                                            keepdim=True).sqrt().clamp_(min=eps))
+                param_rms = state["param_rms"]
+                param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
             if step > 0:
                 # self._size_update() learns the overall scale on the
                 # parameter, by shrinking or expanding it.
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
group: dict to look up configuration values
|
group: dict to look up configuration values
|
||||||
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
scale_grads: a tensor of shape (size_update_period) containing
|
||||||
grads w.r.t. the scales.
|
grads w.r.t. the scales.
|
||||||
p: The parameter to update
|
p: The parameter to update
|
||||||
state: The state-dict of p
|
state: The state-dict of p
|
||||||
@@ -315,16 +242,15 @@ class ScaledAdam(BatchedOptimizer):
         param_max_rms = group["param_max_rms"]
         step = state["step"]
 
-        batch_size = p.shape[0]
         size_update_period = scale_grads.shape[0]
         # correct beta2 for the size update period: we will have
         # faster decay at this level.
         beta2_corr = beta2 ** size_update_period
 
-        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
+        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # scalar
         scale_exp_avg_sq.mul_(beta2_corr).add_(
-            (scale_grads ** 2).mean(dim=0),  # mean over dim `size_update_period`
-            alpha=1-beta2_corr)  # shape is (batch_size, 1, 1, ...)
+            (scale_grads ** 2).mean(),  # mean over `size_update_period` elements
+            alpha=1-beta2_corr)
 
         # The 1st time we reach here is when size_step == 1.
         size_step = (step + 1) // size_update_period
@@ -333,9 +259,9 @@ class ScaledAdam(BatchedOptimizer):
         # at the start of training.
 
         eps = 1.0e-10  # this value is not critical.
-        denom = scale_exp_avg_sq.sqrt() + eps  # shape: (batch_size, 1, 1, ..)
+        denom = scale_exp_avg_sq.sqrt() + eps
 
-        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
+        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
 
         is_too_small = (param_rms < param_min_rms)
         is_too_large = (param_rms > param_max_rms)
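Taken together, the two hunks above reduce the size update to ordinary Adam arithmetic on one scalar. A worked sketch with made-up hyperparameters (the values of size_lr and beta2, and the Adam-style formula for bias_correction2, are assumptions here; the diff does not show where bias_correction2 is computed):

<code>
import torch

size_lr, beta2, size_update_period = 0.01, 0.98, 4  # assumed values
scale_grads = torch.tensor([0.5, -0.2, 0.1, 0.4])   # one entry per step

beta2_corr = beta2 ** size_update_period            # faster decay at this level
scale_exp_avg_sq = torch.tensor(0.0)                # scalar, as the new comment notes
scale_exp_avg_sq.mul_(beta2_corr).add_((scale_grads ** 2).mean(),
                                       alpha=1 - beta2_corr)

size_step = 1                                       # first time the update fires
bias_correction2 = 1 - beta2_corr ** size_step      # assumed Adam-style correction
denom = scale_exp_avg_sq.sqrt() + 1.0e-10
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
</code>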
@@ -353,17 +279,13 @@ class ScaledAdam(BatchedOptimizer):
                     state: dict):
         """
         This function does the core update of self.step(), in the case where the tensor
-        has more than 1 elements.  It multiplies the moving-average gradient tensor by a
-        state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
-        and takes a step in that direction.
+        has more than 1 element.
 
         Args:
           group: A dict which will be used to look up configuration values
           p: The parameter to be updated
           grad: The grad of p
           state: The state-dict corresponding to parameter p
-          batch: if true, the first dimension of p is a batch dim, i.e. a batch of
-                parameter matrices.
 
         This function modifies p.
         """
@@ -401,7 +323,7 @@ class ScaledAdam(BatchedOptimizer):
                     state: dict):
         """
         A simplified form of the core update for scalar tensors, where we cannot get a good
-        estimate of the parameter rms.  p will not be a scalar, because it has a batch dim.
+        estimate of the parameter rms.
         """
         beta1, beta2 = group["betas"]
         scalar_max = group["scalar_max"]
@@ -409,7 +331,7 @@ class ScaledAdam(BatchedOptimizer):
         lr = group["lr"]
         grad = p.grad
 
-        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
+        exp_avg_sq = state["exp_avg_sq"]  # scalar
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                         value=1-beta2)
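For completeness, the scalar path is now textbook Adam second-moment tracking on a 0-dim tensor. A self-contained sketch of one such update; the explicit step and the clamp to scalar_max are assumptions about surrounding code this hunk does not show:

<code>
import torch

beta2, lr, scalar_max, eps = 0.98, 3.0e-3, 10.0, 1.0e-8  # assumed values

p = torch.tensor(0.5)           # a scalar parameter
grad = torch.tensor(0.1)        # its gradient
exp_avg_sq = torch.tensor(0.0)  # scalar, as the new comment notes

exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt() + eps
p = (p - lr * grad / denom).clamp(min=-scalar_max, max=scalar_max)
</code>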