diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
index b0d269571..432bf8220 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
@@ -15,11 +15,9 @@
 # limitations under the License.


-import random
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union

 import torch
-from torch import Tensor
 from torch.optim import Optimizer


@@ -59,24 +57,41 @@ class Eve(Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """

-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
-                 weight_decay=1e-3, target_rms=0.1):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.98),
+        eps=1e-8,
+        weight_decay=1e-3,
+        target_rms=0.1,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+            raise ValueError(
+                "Invalid beta parameter at index 0: {}".format(betas[0])
+            )
         if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+            raise ValueError(
+                "Invalid beta parameter at index 1: {}".format(betas[1])
+            )
         if not 0 <= weight_decay <= 0.1:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay,
-                        target_rms=target_rms)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            target_rms=target_rms,
+        )
         super(Eve, self).__init__(params, defaults)

     def __setstate__(self, state):
@@ -96,83 +111,98 @@ class Eve(Optimizer):
                 loss = closure()

         for group in self.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is None:
                     continue

                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
-                    raise RuntimeError('AdamW does not support sparse gradients')
+                    raise RuntimeError(
+                        "AdamW does not support sparse gradients"
+                    )

                 state = self.state[p]

                 # State initialization
                 if len(state) == 0:
-                    state['step'] = 0
+                    state["step"] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

-                beta1, beta2 = group['betas']
+                beta1, beta2 = group["betas"]

-                state['step'] += 1
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
+                state["step"] += 1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]

                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

-                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
+                    group["eps"]
+                )

-                step_size = group['lr'] / bias_correction1
-                target_rms = group['target_rms']
-                weight_decay = group['weight_decay']
-                delta = exp_avg / denom
+                step_size = group["lr"] / bias_correction1
+                target_rms = group["target_rms"]
+                weight_decay = group["weight_decay"]

                 if p.numel() > 1:
                     # avoid applying this weight-decay on "scaling factors"
                     # (which are scalar).
-                    is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
+                    is_above_target_rms = p.norm() > (
+                        target_rms * (p.numel() ** 0.5)
+                    )
                     p.mul_(1 - (weight_decay * is_above_target_rms))
                 p.addcdiv_(exp_avg, denom, value=-step_size)

         return loss

+
 class LRScheduler(object):
     """
     Base-class for learning rate schedulers where the learning-rate depends on both the
     batch and the epoch.
     """
+
     def __init__(self, optimizer: Optimizer, verbose: bool = False):
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
-            raise TypeError('{} is not an Optimizer'.format(
-                type(optimizer).__name__))
+            raise TypeError(
+                "{} is not an Optimizer".format(type(optimizer).__name__)
+            )
         self.optimizer = optimizer
         self.verbose = verbose

         for group in optimizer.param_groups:
-            group.setdefault('initial_lr', group['lr'])
+            group.setdefault("initial_lr", group["lr"])

-        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+        self.base_lrs = [
+            group["initial_lr"] for group in optimizer.param_groups
+        ]

         self.epoch = 0
         self.batch = 0

-
     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.

         It contains an entry for every variable in self.__dict__ which
         is not the optimizer.
         """
-        return {'base_lrs': self.base_lrs,
-                'epoch': self.epoch,
-                'batch': self.batch}
+        return {
+            "base_lrs": self.base_lrs,
+            "epoch": self.epoch,
+            "batch": self.batch,
+        }

     def load_state_dict(self, state_dict):
         """Loads the schedulers state.
@@ -184,8 +214,7 @@ class LRScheduler(object):
         self.__dict__.update(state_dict)

     def get_last_lr(self) -> List[float]:
-        """ Return last computed learning rate by current scheduler.  Will be a list of float.
-        """
+        """Return last computed learning rate by current scheduler. Will be a list of float."""
         return self._last_lr

     def get_lr(self):
@@ -194,7 +223,6 @@ class LRScheduler(object):
         # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
         raise NotImplementedError

-
     def step_batch(self, batch: Optional[int] = None) -> None:
         # Step the batch index, or just set it.  If `batch` is specified, it
         # must be the batch index from the start of training, i.e. summed over
@@ -217,24 +245,23 @@ class LRScheduler(object):
             self.epoch = self.epoch + 1
         self._set_lrs()

-
     def _set_lrs(self):
         values = self.get_lr()
         assert len(values) == len(self.optimizer.param_groups)
         for i, data in enumerate(zip(self.optimizer.param_groups, values)):
             param_group, lr = data
-            param_group['lr'] = lr
+            param_group["lr"] = lr
             self.print_lr(self.verbose, i, lr)
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-
+        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

     def print_lr(self, is_verbose, group, lr):
-        """Display the current learning rate.
-        """
+        """Display the current learning rate."""
         if is_verbose:
-            print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
-                  f' of group {group} to {lr:.4e}.')
+            print(
+                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
+                f" of group {group} to {lr:.4e}."
+            )


 class Eden(LRScheduler):
@@ -254,18 +281,27 @@ class Eden(LRScheduler):
     20 to 40 epochs, but may need smaller number if dataset is huge
     and you will do few epochs.
     """
-    def __init__(self, optimizer: Optimizer,
-                 lr_batches: Union[int, float],
-                 lr_epochs: Union[int, float],
-                 verbose: bool = False):
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        lr_batches: Union[int, float],
+        lr_epochs: Union[int, float],
+        verbose: bool = False,
+    ):
         super(Eden, self).__init__(optimizer, verbose)
         self.lr_batches = lr_batches
         self.lr_epochs = lr_epochs

     def get_lr(self):
-        factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
-                  (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
-        return [ x * factor for x in self.base_lrs ]
+        factor = (
+            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
+        ) ** -0.25 * (
+            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
+            ** -0.25
+        )
+        return [x * factor for x in self.base_lrs]
+

 def _test_eden():
     m = torch.nn.Linear(100, 100)
@@ -290,5 +326,6 @@ def _test_eden():
     print("last lr = ", scheduler.get_last_lr())
     print("state dict = ", scheduler.state_dict())

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     _test_eden()
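Usage note (not part of the patch): the sketch below shows how the Eve optimizer and Eden scheduler touched by this diff are typically wired together, modelled on the _test_eden() helper in this file. The toy model, the learning-rate and loop constants, the module import path, and the step_epoch() call are illustrative assumptions rather than lines quoted from the diff.

import torch

from optim import Eden, Eve  # assumes this file is importable as `optim`

model = torch.nn.Linear(100, 100)
optimizer = Eve(model.parameters(), lr=3e-3)
# Eden decays the LR smoothly in both the batch and the epoch index:
#   factor = ((batch^2 + lr_batches^2) / lr_batches^2) ** -0.25
#          * ((epoch^2 + lr_epochs^2) / lr_epochs^2) ** -0.25
scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6, verbose=False)

for epoch in range(10):
    scheduler.step_epoch(epoch)  # assumed per-epoch hook (see the epoch-stepping hunk above)
    for _ in range(20):
        x = torch.randn(32, 100)
        loss = model(x).pow(2).mean()  # dummy objective for illustration
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step_batch()  # per-batch LR update, as in _test_eden()

print("last lr =", scheduler.get_last_lr())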