From 72f4a673b106fefc9c88841a3e4ff3a9d1d6fd88 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 20:21:34 +0800 Subject: [PATCH] First draft of new approach to learning rates + init --- .../pruned_transducer_stateless2/conformer.py | 87 ------ .../ASR/pruned_transducer_stateless2/optim.py | 254 ++++++++++++++++++ .../pruned_transducer_stateless2/scaling.py | 12 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +++- 4 files changed, 299 insertions(+), 104 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/optim.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 0deb960ad..4797cce08 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -1017,93 +1017,6 @@ class Conv2dSubsampling(nn.Module): return x -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * self.warmup ** (-0.5 - -0.333) - * min(step ** (-0.333), step * self.warmup ** (-1.333)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - if __name__ == '__main__': feature_dim = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py new file mode 100644 index 000000000..edbebcceb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -0,0 +1,254 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +class Eve(Optimizer): + r""" + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular specified value (generally 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, + target_rms=0.1): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict(lr=lr, betas=betas, eps=eps, + target_rms=target_rms) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + target_rms = group['target_rms'] + delta = exp_avg / denom + + # we'll be doing: p += delta * step_size. + # In the normal case delta_rms (the rms value of the elements of + # delta) will be very close to 1.0, but we compute it here so + # that if we don't use a particular parameter, its value won't + # shrink to zero. + # delta_var is the expected change in the variance of the parameter + # values, i.e. of E[param_elem^2], due to this step. It will + # be close to 1. + + # Let us define: + # delta_var_from_update = (delta**2).mean() * step_size * step_size + + # Suppose we are going to shrinkage with a small value epsilon (not the + # same as the eps above!), i.e. param *= (1-epsilon). Then + # if E[param_elem^2] == target_rms^2, + # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2), + # which we can put as: + # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. + # Setting delta_var_from_shrinkage = -delta_var_from_update + # because we want them to cancel, + # delta_var_from_update = 2 epsilon target_rms^2, or: + # epsilon = delta_var_from_update / (2 * target_rms^2) + # = (delta**2).mean() * 0.5 * (step_size / target_rms)**2. + # Note: step_size is close to the learning rate. For an example, if + # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where + # (delta**2).mean() == 1, we will have: + # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06. + # Note that this is close to the "traditional" value used for weight + # decay. + + # this is the weight-decay amount... + weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2) + + p.mul_(1 - weight_decay) + p.add_(delta, alpha=-step_size) + + return loss + + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * self.warmup ** (-0.5 - -0.333) + * min(step ** (-0.333), step * self.warmup ** (-1.333)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 4c45205ce..33b4ad908 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -158,7 +158,10 @@ class ScaledLinear(nn.Linear): self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + # we plan to use Eve as the optimizer, which will eventually make the stddev approach + # 0.1 as that's the target_rms we set, but we initialize with a larger stddev + # to have the same effect as a warm-up period. + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -196,7 +199,7 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -241,7 +244,7 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -476,9 +479,8 @@ class ScaledEmbedding(nn.Module): self.reset_parameters(initial_speed) - def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.01 / initial_speed + std = 0.5 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index bf7f23fab..9d074fdd4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -28,7 +28,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 300 \ - --lr-factor 1.5 + --initial-lr 0.002 \ + --lr-decay-steps 10000 \ + --num-lr-decays 4 + """ @@ -52,6 +55,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer +from optim import Eve from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -141,17 +145,24 @@ def get_parser(): ) parser.add_argument( - "--lr-factor", + "--initial-lr", type=float, - default=5.0, - help="The lr_factor for Noam optimizer", + default=0.002, + help="The initial learning rate", ) parser.add_argument( - "--warm-step", + "--lr-decay-steps", type=float, - default=60000, - help="The number of warmup steps for the (modified) Noam optimizer", + default=5000, + help="The number of steps before we decay (halve) the learning rate", + ) + + parser.add_argument( + "--num-lr-decays", + type=float, + default=4, + help="The total number of times we decay (halve) the learning rate" ) parser.add_argument( @@ -426,6 +437,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -449,6 +461,7 @@ def save_checkpoint( model=model, params=params, optimizer=optimizer, + scheduler=scheduler, sampler=sampler, rank=rank, ) @@ -574,6 +587,7 @@ def train_one_epoch( params: AttributeDict, model: nn.Module, optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -594,6 +608,8 @@ def train_one_epoch( The model for training. optimizer: The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: @@ -636,6 +652,7 @@ def train_one_epoch( loss.backward() optimizer.step() optimizer.zero_grad() + lr_scheduler.step() if params.print_diagnostics and batch_idx == 5: return @@ -651,6 +668,7 @@ def train_one_epoch( model=model, params=params, optimizer=optimizer, + scheduler=scheduler, sampler=train_dl.sampler, rank=rank, ) @@ -756,17 +774,24 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - optimizer = Noam( + optimizer = Eve( model.parameters(), - model_size=params.encoder_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) + lr=params.initial_lr, betas=(0.9, 0.98), + eps=1e-9, target_rms=0.1) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ], + gamma=0.5) + if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) + if checkpoints and "scheduler" in checkpoints: + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( @@ -839,6 +864,7 @@ def run(rank, world_size, args): params=params, model=model, optimizer=optimizer, + scheduler=scheduler, sp=sp, train_dl=train_dl, valid_dl=valid_dl,