diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index c06c6693b..fcdec8c1d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -21,7 +21,6 @@ import torch
 import random
 from torch import Tensor
 from torch.optim import Optimizer
-from icefall import diagnostics # only for testing code
 import logging
 
 class PrAdam(Optimizer):
@@ -120,15 +119,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 loss = closure()
 
         for group in self.param_groups:
-            lr = group["lr"]
-            size_lr = lr * group["size_lr_scale"]
-            beta1, beta2 = group["betas"]
-            scalar_max = group["scalar_max"]
             eps = group["eps"]
             size_update_period = group["size_update_period"]
-            param_min_rms = group["param_min_rms"]
-            param_max_rms = group["param_max_rms"]
-            lr_update_period = group["lr_update_period"]
 
             for p in group["params"]:
                 if p.grad is None:
@@ -207,43 +199,59 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                         # instead of just using a temporary and smoothing the scalar factor.
                         state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
 
-
-                step = state["step"]
-                delta = state["delta"]
-                delta.mul_(beta1)
-                numel = p.numel()
-                if numel > 1:
-                    # Update the size/scale of p, and set param_rms
-                    scale_grads = state["scale_grads"]
-                    scale_grads[step % size_update_period] = (p * grad).sum()
-                    if step % size_update_period == size_update_period - 1:
-                        # this learns the overall scale on the parameter, by shrinking or
-                        # expanding it.
-                        param_rms = state["param_rms"]
-                        param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
-                        if step > 0:
-                            self._size_update(p, state,
-                                              scale_grads, param_rms,
-                                              beta1, beta2, step, size_lr,
-                                              param_min_rms, param_max_rms)
-
-
-                if numel == 1:
-                    # For parameters with very few elements we just use a form
-                    # of Adam with a scale factor to reflect the overall
-                    # parameter rms. Updates delta.
-                    self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
-                else:
-                    if step % lr_update_period == 0 and step > 0:
-                        self._accum_param_covs(group, p, state)
-                        self._update_lrs(group, p, state)
-                        self._zero_exp_avg_sq(state)
-                    self._step(group, p, grad, state)
-                p.add_(delta)
-                state["step"] = step + 1
+                self._step_one_param(group, p, state)
 
         return loss
 
+    def _step_one_param(self,
+                        group: dict,
+                        p: Tensor,
+                        state: dict):
+        lr = group["lr"]
+        size_lr = lr * group["size_lr_scale"]
+        beta1, beta2 = group["betas"]
+        scalar_max = group["scalar_max"]
+        eps = group["eps"]
+        size_update_period = group["size_update_period"]
+        param_min_rms = group["param_min_rms"]
+        param_max_rms = group["param_max_rms"]
+        lr_update_period = group["lr_update_period"]
+
+        grad = p.grad
+        step = state["step"]
+        delta = state["delta"]
+        delta.mul_(beta1)
+        numel = p.numel()
+        if numel > 1:
+            # Update the size/scale of p, and set param_rms
+            scale_grads = state["scale_grads"]
+            scale_grads[step % size_update_period] = (p * grad).sum()
+            if step % size_update_period == size_update_period - 1:
+                # this learns the overall scale on the parameter, by shrinking or
+                # expanding it.
+                param_rms = state["param_rms"]
+                param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
+                if step > 0:
+                    self._size_update(p, state,
+                                      scale_grads, param_rms,
+                                      beta1, beta2, step, size_lr,
+                                      param_min_rms, param_max_rms)
+
+        if numel == 1:
+            # For parameters with very few elements we just use a form
+            # of Adam with a scale factor to reflect the overall
+            # parameter rms. Updates delta.
+            self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
+        else:
+            if step % lr_update_period == 0 and step > 0:
+                self._accum_param_covs(group, p, state)
+                self._update_lrs(group, p, state)
+                self._zero_exp_avg_sq(state)
+            self._step(group, p, grad, state)
+        p.add_(delta)
+        state["step"] = step + 1
+
+
     def _size_update(self,
                      p: Tensor,
                      state: dict,
@@ -1343,7 +1351,8 @@ def _test_eve_cain():
     B = 4
     T = 2
     logging.info("in test_eve_cain")
-    device = torch.device('cuda')
+    #device = torch.device('cuda')
+    device = torch.device('cpu')
     dtype = torch.float32
     fix_random_seed(42)
 
@@ -1376,11 +1385,11 @@ def _test_eve_cain():
 
             #if epoch == 100 and iter in [2,3]:
             #    optim.reset_speedup() # check it doesn't crash.
-            if epoch == 130:
-                opts = diagnostics.TensorDiagnosticOptions(
-                    2 ** 22
-                ) # allow 4 megabytes per sub-module
-                diagnostic = diagnostics.attach_diagnostics(m, opts)
+            #if epoch == 130:
+            #    opts = diagnostics.TensorDiagnosticOptions(
+            #        2 ** 22
+            #    ) # allow 4 megabytes per sub-module
+            #    diagnostic = diagnostics.attach_diagnostics(m, opts)
 
 
             for n, (x,y) in enumerate(train_pairs):