diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index b86948067..79acd7ca3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -25,73 +25,9 @@ from torch.optim import Optimizer
import logging
-class BatchedOptimizer(Optimizer):
- """
- This class adds to class Optimizer the capability to optimize parameters in batches:
- it will stack the parameters and their grads for you so the optimizer can work
- on tensors with an extra leading dimension. This is intended for speed with GPUs,
- as it reduces the number of kernels launched in the optimizer.
-
- Args:
- params:
- """
- def __init__(self, params, defaults):
- super(BatchedOptimizer, self).__init__(params, defaults)
- def batched_params(self, param_group):
- """
- This function returns an iterator over tuples (p, state), where
- p is a `fake` parameter that is stacked (over axis 0) from real parameters
- that share the same shape, and its gradient is also stacked;
- `state` is the state corresponding to this fake parameter
- (it will be physically located in the "state" for one of the real
- parameters, the last one that has any particular shape and dtype).
-
- The idea is, instead of doing:
-
- for p in group["params"]:
- state = self.state[p]
- ...
-
- you can do:
-
- for p, state in self.batched_params(group["params"]):
- ...
-
-
- Args:
- group: a parameter group, which is a list of parameters; should be
- one of self.groups.
- """
- batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
-
- for p in param_group:
- assert p.grad is not None and "All parameters must have a grad when you call optimizer.step()"
- assert not p.grad.is_sparse and "Sparse gradients not supported."
- key = (str(p.dtype), *p.shape)
- batches[key].append(p)
-
- for p in param_group:
- key = (str(p.dtype), *p.shape)
- batch = batches[key]
- if p is batch[0]: # if this is the 1st param in the batch...
- # we arbitrarily store the state in the
- # state corresponding to the 1st parameter in the
- # group. class Optimizer will take care of saving/loading state.
- state = self.state[p]
- p_stacked = torch.stack(batch)
- grad = torch.stack([p.grad for p in batch])
- p_stacked.grad = grad
- yield p_stacked, state # <-- calling code will do the actual optimization here!
- # Now un-stack the parameter changes
- for i,p in enumerate(batch):
- p.copy_(p_stacked[i])
-
-
-
-
-class ScaledAdam(BatchedOptimizer):
+class ScaledAdam(Optimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
proportional to the norm of that parameter; and also learn the scale of the parameter,
@@ -170,7 +106,9 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group in self.param_groups:
- for p, state in self.batched_params(group["params"]):
+ for p in group["params"]:
+ state = self.state[p]
+
# Perform optimization step
grad = p.grad
if grad.is_sparse:
@@ -191,14 +129,11 @@ class ScaledAdam(BatchedOptimizer):
p: Tensor,
state: dict):
"""
- Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
- is actually the batch dimension, corresponding to batched-together
- parameters of a given shape.
+ Initializes state dict for parameter 'p'.
Args:
group: Dict to look up configuration values.
- p: The parameter that we are initializing the state for, actually
- will be stacked-together "real" parameters
+ p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing
"""
eps = group["eps"]
@@ -218,20 +153,17 @@ class ScaledAdam(BatchedOptimizer):
p, memory_format=torch.preserve_format
)
- batch_size = p.shape[0]
- numel = p.numel() // batch_size
+ numel = p.numel()
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
- # it has a shape like (batch_size, 1, 1, 1, 1)
- param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
- keepdim=True).sqrt() + eps
+ param_rms = (p**2).mean().sqrt() + eps
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
- state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
- **kwargs)
+ state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
+
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(
@@ -244,12 +176,10 @@ class ScaledAdam(BatchedOptimizer):
p: Tensor,
state: dict):
"""
- Do the step for one parameter, which is actually going to be a batch of
- `real` parameters, with dim 0 as the batch dim.
+ Do the step for parameter p.
Args:
group: dict to look up configuration values
- p: parameter to update (actually multiple parameters stacked together
- as a batch)
+ p: parameter to update
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
@@ -262,17 +192,14 @@ class ScaledAdam(BatchedOptimizer):
delta = state["delta"]
delta.mul_(beta1)
- batch_size = p.shape[0]
- numel = p.numel() // batch_size
+ numel = p.numel()
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
- scale_grads[step % size_update_period] = (p * grad).sum(
- dim=list(range(1, p.ndim)), keepdim=True)
+ scale_grads[step % size_update_period] = (p * grad).sum()
if step % size_update_period == size_update_period - 1:
- param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
- param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
- keepdim=True).sqrt().clamp_(min=eps))
+ param_rms = state["param_rms"]
+ param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
@@ -302,7 +229,7 @@ class ScaledAdam(BatchedOptimizer):
Args:
group: dict to look up configuration values
- scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
+ scale_grads: a tensor of shape (size_update_period) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
@@ -315,16 +242,15 @@ class ScaledAdam(BatchedOptimizer):
param_max_rms = group["param_max_rms"]
step = state["step"]
- batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2 ** size_update_period
- scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
+ scale_exp_avg_sq = state["scale_exp_avg_sq"] # scalar
scale_exp_avg_sq.mul_(beta2_corr).add_(
- (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
- alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
+ (scale_grads ** 2).mean(), # mean over `size_update_period` elements
+ alpha=1-beta2_corr)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@@ -333,9 +259,9 @@ class ScaledAdam(BatchedOptimizer):
# at the start of training.
eps = 1.0e-10 # this value is not critical.
- denom = scale_exp_avg_sq.sqrt() + eps # shape: (batch_size, 1, 1, ..)
+ denom = scale_exp_avg_sq.sqrt() + eps
- scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
+ scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
@@ -353,17 +279,13 @@ class ScaledAdam(BatchedOptimizer):
state: dict):
"""
This function does the core update of self.step(), in the case where the tensor
- has more than 1 elements. It multiplies the moving-average gradient tensor by a
- state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
- and takes a step in that direction.
+ has more than 1 elements.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
- batch: if true, the first dimension of p is a batch dim, i.e. a batch of
- parameter matrices.
This function modifies p.
"""
@@ -401,7 +323,7 @@ class ScaledAdam(BatchedOptimizer):
state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
- estimate of the parameter rms. p will not be a scalar, because it has a batch dim.
+ estimate of the parameter rms.
"""
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
@@ -409,7 +331,7 @@ class ScaledAdam(BatchedOptimizer):
lr = group["lr"]
grad = p.grad
- exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
+ exp_avg_sq = state["exp_avg_sq"] # scalar
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)