diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 5ee4bab98..837e4297f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -22,6 +22,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.nn import Embedding as ScaledEmbedding
 
 
 def _ntuple(n):
@@ -154,10 +155,8 @@ class BasicNorm(torch.nn.Module):
 
 class ScaledLinear(nn.Linear):
     """
-    A modified version of nn.Linear where the parameters are scaled before
-    use, via:
-         weight = self.weight * self.weight_scale.exp()
-         bias = self.bias * self.bias_scale.exp()
+    A modified version of nn.Linear that gives an easy way to set the
+    default initial parameter scale.
 
     Args:
         Accepts the standard args and kwargs that nn.Linear accepts
@@ -168,59 +167,26 @@ class ScaledLinear(nn.Linear):
            (affects the initialization of weight_scale and bias_scale).
            Another option, if you want to do something like this, is
            to re-initialize the parameters.
-        initial_speed: this affects how fast the parameter will
-           learn near the start of training; you can set it to a
-           value less than one if you suspect that a module
-           is contributing to instability near the start of training.
-           Nnote: regardless of the use of this option, it's best to
-           use schedulers like Noam that have a warm-up period.
-           Alternatively you can set it to more than 1 if you want it to
-           initially train faster.  Must be greater than 0.
     """
 
     def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
-       initial_speed: float = 1.0,
        **kwargs
     ):
         super(ScaledLinear, self).__init__(*args, **kwargs)
-        initial_scale = torch.tensor(initial_scale).log()
-        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
-        if self.bias is not None:
-            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
-        else:
-            self.register_parameter("bias_scale", None)
-
-        self._reset_parameters(
-            initial_speed
-        )  # Overrides the reset_parameters in nn.Linear
-
-    def _reset_parameters(self, initial_speed: float):
-        std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
-        nn.init.uniform_(self.weight, -a, a)
-        if self.bias is not None:
-            nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += torch.tensor(scale / std).log()
+            self.weight[:] *= initial_scale
+            if self.bias is not None:
+                self.bias[:] *= initial_scale
 
-    def get_weight(self):
-        return self.weight * self.weight_scale.exp()
+    def get_weight(self):  # not needed any more but kept for back compatibility
+        return self.weight
 
     def get_bias(self):
-        if self.bias is None or self.bias_scale is None:
-            return None
+        return self.bias
-        return self.bias * self.bias_scale.exp()
-
-    def forward(self, input: Tensor) -> Tensor:
-        return torch.nn.functional.linear(
-            input, self.get_weight(), self.get_bias()
-        )
 
 
 class ScaledConv1d(nn.Conv1d):
@@ -229,66 +195,20 @@ class ScaledConv1d(nn.Conv1d):
         self,
         *args,
         initial_scale: float = 1.0,
-        initial_speed: float = 1.0,
         **kwargs
     ):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
-        initial_scale = torch.tensor(initial_scale).log()
-        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
-        if self.bias is not None:
-            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
-        else:
-            self.register_parameter("bias_scale", None)
-        self._reset_parameters(
-            initial_speed
-        )  # Overrides the reset_parameters in base class
-
-    def _reset_parameters(self, initial_speed: float):
-        std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
-        nn.init.uniform_(self.weight, -a, a)
-        if self.bias is not None:
-            nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += torch.tensor(scale / std).log()
+            self.weight[:] *= initial_scale
+            if self.bias is not None:
+                self.bias[:] *= initial_scale
 
-    def get_weight(self):
-        return self.weight * self.weight_scale.exp()
-    def get_bias(self):
-        bias = self.bias
-        bias_scale = self.bias_scale
-        if bias is None or bias_scale is None:
-            return None
-        return bias * bias_scale.exp()
+    def get_weight(self):  # TODO: delete
+        return self.weight
 
-    def forward(self, input: Tensor) -> Tensor:
-        F = torch.nn.functional
-        if self.padding_mode != "zeros":
-            return F.conv1d(
-                F.pad(
-                    input,
-                    self._reversed_padding_repeated_twice,
-                    mode=self.padding_mode,
-                ),
-                self.get_weight(),
-                self.get_bias(),
-                self.stride,
-                (0,),
-                self.dilation,
-                self.groups,
-            )
-        return F.conv1d(
-            input,
-            self.get_weight(),
-            self.get_bias(),
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
+    def get_bias(self):  # TODO: delete
+        return self.bias
 
 
 class ScaledConv2d(nn.Conv2d):
@@ -297,70 +217,20 @@ class ScaledConv2d(nn.Conv2d):
         self,
         *args,
         initial_scale: float = 1.0,
-        initial_speed: float = 1.0,
         **kwargs
     ):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
-        initial_scale = torch.tensor(initial_scale).log()
-        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
-        if self.bias is not None:
-            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
-        else:
-            self.register_parameter("bias_scale", None)
-        self._reset_parameters(
-            initial_speed
-        )  # Overrides the reset_parameters in base class
-
-    def _reset_parameters(self, initial_speed: float):
-        std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
-        nn.init.uniform_(self.weight, -a, a)
-        if self.bias is not None:
-            nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += torch.tensor(scale / std).log()
+            self.weight[:] *= initial_scale
+            if self.bias is not None:
+                self.bias[:] *= initial_scale
 
     def get_weight(self):
-        return self.weight * self.weight_scale.exp()
+        return self.weight
 
     def get_bias(self):
-        # see https://github.com/pytorch/pytorch/issues/24135
-        bias = self.bias
-        bias_scale = self.bias_scale
-        if bias is None or bias_scale is None:
-            return None
-        return bias * bias_scale.exp()
+        return self.bias
 
-    def _conv_forward(self, input, weight):
-        F = torch.nn.functional
-        if self.padding_mode != "zeros":
-            return F.conv2d(
-                F.pad(
-                    input,
-                    self._reversed_padding_repeated_twice,
-                    mode=self.padding_mode,
-                ),
-                weight,
-                self.get_bias(),
-                self.stride,
-                (0, 0),
-                self.dilation,
-                self.groups,
-            )
-        return F.conv2d(
-            input,
-            weight,
-            self.get_bias(),
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
-
-    def forward(self, input: Tensor) -> Tensor:
-        return self._conv_forward(input, self.get_weight())
 
 
 class ActivationBalancer(torch.nn.Module):
@@ -464,179 +334,6 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)
 
 
-class ScaledEmbedding(nn.Module):
-    r"""This is a modified version of nn.Embedding that introduces a learnable scale
-    on the parameters.  Note: due to how we initialize it, it's best used with
-    schedulers like Noam that have a warmup period.
-
-    It is a simple lookup table that stores embeddings of a fixed dictionary and size.
-
-    This module is often used to store word embeddings and retrieve them using indices.
-    The input to the module is a list of indices, and the output is the corresponding
-    word embeddings.
-
-    Args:
-        num_embeddings (int): size of the dictionary of embeddings
-        embedding_dim (int): the size of each embedding vector
-        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
-                                     (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
-        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
-                                                the words in the mini-batch. Default ``False``.
-        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
-                                 See Notes for more details regarding sparse gradients.
-
-        initial_speed (float, optional): This affects how fast the parameter will
-           learn near the start of training; you can set it to a value less than
-           one if you suspect that a module is contributing to instability near
-           the start of training.  Nnote: regardless of the use of this option,
-           it's best to use schedulers like Noam that have a warm-up period.
-           Alternatively you can set it to more than 1 if you want it to
-           initially train faster.  Must be greater than 0.
-
-
-    Attributes:
-        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
-                         initialized from :math:`\mathcal{N}(0, 1)`
-
-    Shape:
-        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
-        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
-
-    .. note::
-        Keep in mind that only a limited number of optimizers support
-        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
-        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
-
-    .. note::
-        With :attr:`padding_idx` set, the embedding vector at
-        :attr:`padding_idx` is initialized to all zeros. However, note that this
-        vector can be modified afterwards, e.g., using a customized
-        initialization method, and thus changing the vector used to pad the
-        output. The gradient for this vector from :class:`~torch.nn.Embedding`
-        is always zero.
-
-    Examples::
-
-        >>> # an Embedding module containing 10 tensors of size 3
-        >>> embedding = nn.Embedding(10, 3)
-        >>> # a batch of 2 samples of 4 indices each
-        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
-        >>> embedding(input)
-        tensor([[[-0.0251, -1.6902,  0.7172],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 1.4970,  1.3448, -0.9685],
-                 [-0.3677, -2.7265, -0.1685]],
-
-                [[ 1.4970,  1.3448, -0.9685],
-                 [ 0.4362, -0.4004,  0.9400],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 0.9124, -2.3616,  1.1151]]])
-
-
-        >>> # example with padding_idx
-        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
-        >>> input = torch.LongTensor([[0,2,0,5]])
-        >>> embedding(input)
-        tensor([[[ 0.0000,  0.0000,  0.0000],
-                 [ 0.1535, -2.0309,  0.9315],
-                 [ 0.0000,  0.0000,  0.0000],
-                 [-0.1655,  0.9897,  0.0635]]])
-
-    """
-    __constants__ = [
-        "num_embeddings",
-        "embedding_dim",
-        "padding_idx",
-        "scale_grad_by_freq",
-        "sparse",
-    ]
-
-    num_embeddings: int
-    embedding_dim: int
-    padding_idx: int
-    scale_grad_by_freq: bool
-    weight: Tensor
-    sparse: bool
-
-    def __init__(
-        self,
-        num_embeddings: int,
-        embedding_dim: int,
-        padding_idx: Optional[int] = None,
-        scale_grad_by_freq: bool = False,
-        sparse: bool = False,
-        initial_speed: float = 1.0,
-    ) -> None:
-        super(ScaledEmbedding, self).__init__()
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        if padding_idx is not None:
-            if padding_idx > 0:
-                assert (
-                    padding_idx < self.num_embeddings
-                ), "Padding_idx must be within num_embeddings"
-            elif padding_idx < 0:
-                assert (
-                    padding_idx >= -self.num_embeddings
-                ), "Padding_idx must be within num_embeddings"
-                padding_idx = self.num_embeddings + padding_idx
-        self.padding_idx = padding_idx
-        self.scale_grad_by_freq = scale_grad_by_freq
-
-        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
-        self.sparse = sparse
-
-        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters(initial_speed)
-
-    def reset_parameters(self, initial_speed: float = 1.0) -> None:
-        std = 0.1 / initial_speed
-        nn.init.normal_(self.weight, std=std)
-        nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
-
-        if self.padding_idx is not None:
-            with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
-
-    def forward(self, input: Tensor) -> Tensor:
-        F = torch.nn.functional
-        scale = self.scale.exp()
-        if input.numel() < self.num_embeddings:
-            return (
-                F.embedding(
-                    input,
-                    self.weight,
-                    self.padding_idx,
-                    None,
-                    2.0,  # None, 2.0 relate to normalization
-                    self.scale_grad_by_freq,
-                    self.sparse,
-                )
-                * scale
-            )
-        else:
-            return F.embedding(
-                input,
-                self.weight * scale,
-                self.padding_idx,
-                None,
-                2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq,
-                self.sparse,
-            )
-
-    def extra_repr(self) -> str:
-        s = "{num_embeddings}, {embedding_dim}, scale={scale}"
-        if self.padding_idx is not None:
-            s += ", padding_idx={padding_idx}"
-        if self.scale_grad_by_freq is not False:
-            s += ", scale_grad_by_freq={scale_grad_by_freq}"
-        if self.sparse is not False:
-            s += ", sparse=True"
-        return s.format(**self.__dict__)
 
 
 def _test_activation_balancer_sign():
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
index acf4e42a4..921df6c63 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
@@ -205,9 +205,16 @@ class Abel(Optimizer):
         eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-08).
         target_rms (float, optional): target root-mean-square value of
-           x_norm in factorization (conceptually).  Actually this now just becomes
-           a factor in the learning rate, we may remove it at some point.
-
+           parameters when we (conceptually) normalize them.  In practice
+           this just becomes a factor in the learning rate for
+           non-scalar parameters.
+        rms_eps (float, optional): epsilon value such that once a parameter's
+           rms value falls below this, we stop slowing its learning any further,
+           i.e. we treat the rms as if it were rms_eps.
+        rms_max (float, optional): maximum root-mean-square value for
+           non-scalar parameters, such that we don't let the parameter
+           values exceed this; we'll shrink any parameter matrices that become
+           larger than this.
 
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -438,6 +445,9 @@ class Cain(Optimizer):
     co-ordinates every so often, just in the optimizer, to separate
     big/small directions.
 
+    This version of Cain absorbs the learned scales of the parameter matrices into
+    the optimizer.
+
     Eve is a modified version of AdamW with a special way of setting the
     weight-decay / shrinkage-factor, which is designed to make the rms of
     the parameters approach a particular target_rms (default: 0.1).  This is
@@ -456,12 +466,17 @@ class Cain(Optimizer):
            running averages of gradient and its square (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
-           this value means that the weight would decay significantly after
-           about 3k minibatches.  Is not multiplied by learning rate, but
-           is conditional on RMS-value of parameter being > target_rms.
         target_rms (float, optional): target root-mean-square value of
            parameters, if they fall below this we will stop applying weight decay.
+        rms_eps (float, optional): epsilon value such that once a parameter's
+           rms value falls below this, we stop slowing its learning any further,
+           i.e. we treat the rms as if it were rms_eps.
+        rms_max (float, optional): maximum root-mean-square value for
+           non-scalar parameters, such that we don't let the parameter
+           values exceed this; we'll shrink any parameter matrices that become
+           larger than this.
+
+    .. _Adam\: A Method for Stochastic Optimization:
@@ -478,8 +493,9 @@
         lr=1e-3,
         betas=(0.9, 0.98),
         eps=1e-8,
-        weight_decay=1e-3,
         target_rms=0.1,
+        rms_eps=1.0e-05,
+        rms_max=10.0,
     ):
 
         if not 0.0 <= lr:
@@ -494,18 +510,20 @@
             raise ValueError(
                 "Invalid beta parameter at index 1: {}".format(betas[1])
             )
-        if not 0 <= weight_decay <= 0.1:
-            raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
-            )
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
+        if not 0.1 < rms_max <= 1000.0:
+            raise ValueError("Invalid rms_max value: {}".format(rms_max))
+        if not 0.0 < rms_eps < 0.1:
+            raise ValueError("Invalid rms_eps value: {}".format(rms_eps))
+
         defaults = dict(
             lr=lr,
             betas=betas,
             eps=eps,
-            weight_decay=weight_decay,
             target_rms=target_rms,
+            rms_eps=rms_eps,
+            rms_max=rms_max,
         )
         super(Cain, self).__init__(params, defaults)
@@ -526,6 +544,12 @@
                 loss = closure()
 
         for group in self.param_groups:
+            beta1, beta2 = group["betas"]
+            lr = group["lr"]
+            target_rms = group["target_rms"]
+            rms_eps = group["rms_eps"]
+            rms_max = group["rms_max"]
+
             for p in group["params"]:
                 if p.grad is None:
                     continue
@@ -550,14 +574,61 @@
                     state["exp_avg_sq"] = torch.zeros_like(
                         p, memory_format=torch.preserve_format
                     )
+                    if p.numel() > 1:
+                        # we keep track of the RMS value of the parameters to
+                        # help set the learning rate... this emulates Eve.
+                        state["param_rms"] = (p**2).mean().sqrt()
+                        # we learn the scale on the parameters in a way
+                        # that also emulates Eve.
+                        state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
+                                                                dtype=p.dtype)
+                        state["scale_exp_avg"] = torch.zeros((), device=p.device,
+                                                             dtype=p.dtype)
+
+
+                step = state["step"]
+
+                if p.numel() > 1:
+                    # This part learns the scale on the parameters.
+                    # scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
+                    # if we imagine: p = underlying_param * scale.exp(), this is the derivative
+                    # w.r.t. scale (it does not actually matter what value `scale` currently has).
+                    scale_deriv = (p * grad).sum()
+                    scale_exp_avg_sq = state["scale_exp_avg_sq"]
+                    scale_exp_avg = state["scale_exp_avg"]
+                    scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
+                                                          value=1 - beta2)
+
+                    scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))
+                    scale_bias_correction2 = 1 - beta2 ** step
+
+                    scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
+
+                    scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
+
+                    # scale_delta is going to be the change in log-scale, i.e. if
+                    # p = underlying_param * scale.exp(),
+                    # delta is the change in `scale`.
+                    scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
+
+                    # the following is equivalent to:
+                    #    if torch.all(state["param_rms"] > rms_max):
+                    #        delta = -1.0e-03
+                    # which means: if the parameter root-mean-square value goes above the
+                    # specified limit, start to decay the parameters (and ignore the computed
+                    # delta).
+                    scale_delta = torch.where(state["param_rms"] > rms_max,
+                                              torch.tensor(-1.0e-03, device=scale_delta.device,
+                                                           dtype=scale_delta.dtype),
+                                              scale_delta)
+
+                    p.mul_(1 + scale_delta)
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                beta1, beta2 = group["betas"]
-
                 # forget bias_correction1.  We use normal momentum.  exp_avg really
                 # just stores the moving-average gradient step.
-                bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(state["step"])
+                bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(step)
 
                 grad = self._change_coordinates(grad, state, forward=True)
 
@@ -566,23 +637,22 @@
                 denom = (exp_avg_sq.sqrt()).add_(group["eps"])
 
-                step_size = group["lr"]  # forget bias_correction1
-                target_rms = group["target_rms"]
-                weight_decay = group["weight_decay"]
-
                 this_delta = grad / denom
                 this_delta = self._change_coordinates(this_delta, state, forward=False)
 
-                # treat/repurpose exp_avg as moving-average of `this_delta`
-                exp_avg.mul_(beta1).add_(this_delta, alpha=-step_size*(1-beta1)*(bias_correction2 ** 0.5))
+                alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
                 if p.numel() > 1:
-                    # avoid applying this weight-decay on "scaling factors"
-                    # (which are scalar).
-                    is_above_target_rms = p.norm() > (
-                        target_rms * (p.numel() ** 0.5)
-                    )
-                    p.mul_(1 - (weight_decay * is_above_target_rms))
+                    if step % 10 == 0:
+                        state["param_rms"].fill_((p**2).mean().sqrt())
+                    # imagine param was normalized to rms=target_rms and we stored the
+                    # scale as a separate scalar.  This is how fast we'd be learning.
+                    alpha = (alpha / target_rms) * state["param_rms"].clamp(min=rms_eps)
+
+                # treat/repurpose exp_avg as a moving average of `this_delta`,
+                # scaled by alpha (which folds in the lr and, for tensors, the rms factor).
+                exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
+                p.add_(exp_avg)
                 state["step"] += 1
 
@@ -1059,7 +1129,7 @@ def _test_eve_cain():
     input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
     output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
 
-    for iter in [0,1]:
+    for iter in [1,0]:
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
         m = torch.nn.Sequential(Linear(E, 200),
@@ -1071,7 +1141,7 @@
         if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         else: optim = Cain(m.parameters(), lr=0.003)
 
-        scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
+        scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
 
         start = timeit.default_timer()
         for epoch in range(150):
@@ -1121,5 +1191,7 @@
 
 
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     _test_eve_cain()
     #_test_eden()
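Note (not part of the patch): a minimal standalone sketch of the identity the Cain scale-learning step above relies on. If we imagine p = underlying_param * scale.exp(), then the derivative of the loss w.r.t. the log-space `scale` is (p * grad).sum(), regardless of the current value of `scale` -- which is why the optimizer can compute scale_deriv without ever storing a scale. The tensor shape, seed, and cubic loss below are arbitrary choices for illustration only.

    import torch

    torch.manual_seed(0)
    u = torch.randn(5, 4)                       # "underlying" parameter (arbitrary shape)
    s = torch.tensor(0.25, requires_grad=True)  # log-space scale
    p = u * s.exp()
    p.retain_grad()                             # keep the gradient on the non-leaf tensor p

    loss = (p ** 3).sum()                       # any differentiable loss works here
    loss.backward()

    # autograd's d(loss)/d(scale) matches the closed form used in the Cain step:
    print(torch.allclose(s.grad, (p * p.grad).sum()))   # -> True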