diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 7df4260ec..bad0693b3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -483,7 +483,10 @@ class LearnedGradient(Optimizer):
             # get the total magnitude of the change right. The 2 changes below
             # are the final changes, the only 2 we make in this loop that have
             # side effects.
-            delta.add_(this_delta, alpha=-meta_lr*(1-beta1))
+
+            # delta_scale will make it update the learning rates faster than it otherwise would.
+            delta_scale=0.2
+            delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
 
             # there is no momentum on Q.
             Q.add_(Q_delta, alpha=-meta_lr)
@@ -640,67 +643,12 @@ class LearnedGradient(Optimizer):
         var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
         var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
-        # esetimate var0 one more time.
+        # estimate var0 one more time.
         var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
         return var1 * var0
 
-    def _update_lr(self,
-                   lr_mat: Tensor,
-                   meta_lr: float) -> None:
-        """
-        Do one step of update for learning rate matrix 'lr_mat', which we assume has already
-        has its `grad` property populated with the grad of the negative loss (our special
-        loss that we use to train the learning rate matrix). Also delete lr_mat.grad
-
-        Args:
-           lr_mat: The symmetric positive-semidefinite (PSD) learning-rate matrix.
-               Will have its .grad populated with the derivative of a negated
-               loss function (i.e. to be maximized), that we use to optimize
-               this learning-rate matrix. lr_mat.grad will be deleted and
-               lr_mat will be updated by one step.
-           meta_lr: The speed with which we learn lr_mat.
-        """
-        grad = lr_mat.grad
-
-        if random.random() < 0.1:
-            # A check.
-            inner = (grad * lr_mat).sum() / ((grad ** 2).sum() * (lr_mat ** 2).sum()).sqrt()
-            if inner.abs() > 0.01:
-                logging.info(f"Warning: inner product of lr with grad is {inner}, should be zero; dim is {lr_mat.shape[0]}")
-
-        # OK. suppose lr_mat = M, which is symmetric positive definite.
-        # Decompose: M = A N A^T
-        # where N is numerically equal to M (but a separate variable), and A is a variable numerically
-        # equal to I. To make it easier to keep the result positive definite, we learn A rather than N.
-        # We'll use a convention where A_grad has the same orientation as A, i.e. A_grad_ij equals the grad
-        # of loss w.r.t A_ij.
-        # The grad w.r.t A equals (N M_grad)^T + (M_grad N) = 2 (M_grad N). We don't care about scalar
-        # factors at this point (we'll normalize them out), so ignore the factor of 2, and say;
-        # A_grad = M_grad N.
-        A_grad = torch.matmul(grad, lr_mat)
-
-
-        # Normalize the size of `A_grad` (there are arbitrary scalar factors in this loss function),
-        # and include meta_lr
-        A_grad = A_grad * (meta_lr / ((A_grad ** 2).mean().sqrt() + 1.0e-20))
-
-        # So the updated A is just I plus lr_mat * normalize[A_grad]
-
-        size = A.shape[0]
-        A = torch.eye(size, device=A.device, dtype=A.dtype) + meta_lr * A_grad
-
-        # the result is positive semidefinite if the initial LR_mat was positive semidefinite,
-        # because it is of the form X X^T where X = A lr_mat^{0.5}
-        torch.matmul(A, torch.matmul(lr_mat, A.t()), out=lr_mat)
-
-        # Normalize the scale of lr_mat: we always ensure lr_mat.diag().mean() is 1.0.
-        # Also ensure it is perfectly symmetric, or it may drift away from symmetry due to
-        # roundoff.
-        lr_mat[:] = (lr_mat + lr_mat.t()) / (2.0 * lr_mat.diag().mean())
-
-
     def _step(self,
               group: dict,
               p: Tensor,
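Note on the second hunk, for reviewers: the context lines around the fixed comment compute a factored estimate of the per-element variance of a matrix. Below is a minimal, self-contained sketch of that factorization, reconstructed from the hunk's context lines; the function name `factored_var` and the `eps` default are assumptions for illustration, not part of the patch.

```python
import torch


def factored_var(x: torch.Tensor, eps: float = 1.0e-10) -> torch.Tensor:
    # Approximate the per-element second moment of x as an outer product of a
    # per-row factor (var0) and a per-column factor (var1).
    var0 = (x ** 2).mean(dim=1, keepdim=True) + eps * eps
    var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps * eps
    # estimate var0 one more time, now that the column factor var1 is known;
    # one round of this alternating refinement makes var0 * var1 a closer
    # fit to x**2 than the initial row-only estimate.
    var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps * eps
    return var1 * var0
```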