diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index c5d9ac658..51881f9a4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -52,6 +52,10 @@ class NeutralGradient(Optimizer):
                  decreasing with the parameter norm.
       param_rel_eps: An epsilon that limits how different different parameter rms estimates can be
                  within the same tensor
+      param_rel_max: Limits how big we allow the parameter variance eigs to be relative to the
+                 original mean.  Setting this to a not-very-large value like 1 ensures that
+                 we only apply the param scaling for smaller-than-average, not larger-than-average,
+                 eigenvalues, which limits the dispersion of eigenvalues of the parameter tensors.
       param_max: To prevent parameter tensors getting too large, we will clip elements to
                  -param_max..param_max.
       max_fullcov_size: We will only use the full-covariance (not-diagonal) update for
@@ -85,11 +89,12 @@ class NeutralGradient(Optimizer):
         self,
         params,
         lr=3e-02,
-        betas=(0.9, 0.98),
+        betas=(0.9, 0.98, 0.95),
         scale_speed=0.1,
         grad_eps=1e-10,
         param_eps=1.0e-06,
         param_rel_eps=1.0e-04,
+        param_rel_max=1.0,
         param_max_rms=2.0,
         param_min_rms=1.0e-05,
         max_fullcov_size=1023,
@@ -140,6 +145,7 @@ class NeutralGradient(Optimizer):
             grad_eps=grad_eps,
             param_eps=param_eps,
             param_rel_eps=param_rel_eps,
+            param_rel_max=param_rel_max,
             param_max_rms=param_max_rms,
             param_min_rms=param_min_rms,
             max_fullcov_size=max_fullcov_size,
@@ -174,11 +180,12 @@ class NeutralGradient(Optimizer):
         for group in self.param_groups:
             lr = group["lr"]
-            beta1, beta2 = group["betas"]
+            beta1, beta2, beta3 = group["betas"]
             scale_speed = group["scale_speed"]
             grad_eps = group["grad_eps"]
             param_eps = group["param_eps"]
             param_rel_eps = group["param_rel_eps"]
+            param_rel_max = group["param_rel_max"]
             param_max_rms = group["param_max_rms"]
             param_min_rms = group["param_min_rms"]
             max_fullcov_size = group["max_fullcov_size"]
@@ -271,8 +278,7 @@ class NeutralGradient(Optimizer):
                     state["step_within_period"] = random.randint(0,
                                                                  estimate_period-stats_steps)

-                    if param_pow != 1.0:
-                        state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
+                    state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)

                     used_scale = False
                     for dim in range(p.ndim):
@@ -285,6 +291,11 @@ class NeutralGradient(Optimizer):
                             state[f"proj_{dim}"] = torch.ones(size, **kwargs)
                         else:
                             state[f"proj_{dim}"] = torch.eye(size, **kwargs)
+
+                        # we will decay the param stats via beta3; they are only
+                        # accumulated once every `estimate_period`
+                        state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
+
                         if not used_scale:
                             param_rms = (p**2).mean().sqrt().add_(param_eps)
                             state[f"proj_{dim}"] *= param_rms
@@ -389,6 +400,7 @@ class NeutralGradient(Optimizer):
                 step_within_period = state["step_within_period"]
                 if step_within_period == estimate_period:
                     self._estimate_projections(p, state, param_eps, param_rel_eps,
+                                               param_rel_max, beta3,
                                                param_pow, grad_pow, grad_min_rand)
                     state["step_within_period"] = 0
                 elif step_within_period >= estimate_period - stats_steps:
@@ -437,7 +449,8 @@ class NeutralGradient(Optimizer):
                             logging.info(f"cos_angle = {cos_angle}, shape={grad.shape}")

                 alpha = -lr * (1-beta1)
-                if param_pow != 1.0 or grad_pow != 1.0:
+
+                if True:
                     # Renormalize scale of cur_grad
                     scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
                     scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean()/n_cached_grads,
                                                        alpha=(1-beta2))
@@ -509,6 +522,8 @@ class NeutralGradient(Optimizer):
                                state: dict,
                                param_eps: float,
                                param_rel_eps: float,
+                               param_rel_max: float,
+                               beta3: float,
                                param_pow: float,
                                grad_pow: float,
                                grad_min_rand: float,
@@ -523,7 +538,6 @@ class NeutralGradient(Optimizer):
         co-ordinates.  (Thus, applying with forward==True and then forward==False
         is not a round trip).
-
         Args:
           p: The parameter for which we are etimating the projections.  Supplied
              because we need this to estimate parameter covariance matrices in
@@ -535,6 +549,7 @@ class NeutralGradient(Optimizer):
           param_rel_eps: Another constraint on minimum parameter rms, relative to
              sqrt(overall_param_rms), where overall_param_rms is the sqrt of the
              parameter variance averaged over the entire tensor.
+          beta3: discounting parameter for param_cov, applied once every estimate_period.
           param_pow: param_pow==1.0 means fully normalizing for parameter scale;
              setting to a smaller value closer to 0.0 means partial normalization.
           grad_min_rand: minimum proportion of random tensor to mix with the
@@ -557,10 +572,13 @@ class NeutralGradient(Optimizer):
                 grad_cov = state[f"grad_cov_{dim}"] / count
                 del state[f"grad_cov_{dim}"]  # save memory
-                self._randomize_lowrank_cov(grad_cov, count, p.shape,
-                                            grad_min_rand)
+                #self._randomize_lowrank_cov(grad_cov, count, p.shape,
+                #                            grad_min_rand)

-                param_cov = self._get_param_cov(p, dim)
+                cur_param_cov = self._get_param_cov(p, dim)
+                cur_param_cov = cur_param_cov / (cur_param_cov.diag().mean() + 1.0e-20) # normalize size..
+                param_cov = state[f"param_cov_{dim}"]
+                param_cov.mul_(beta3).add_(cur_param_cov, alpha=(1-beta3))

                 # We want the orthogonal matrix U that diagonalizes P, where
                 # P is the SPD matrix such that P G^{param_pow} P^T == C^{param_pow},
@@ -570,7 +588,11 @@ class NeutralGradient(Optimizer):
                 # since the P's are related by taking-to-a-power.
                 P = self._estimate_proj(grad_cov,
                                         param_cov,
-                                        param_pow / grad_pow)
+                                        param_pow / grad_pow,
+                                        param_rel_eps,
+                                        param_rel_max)
+
+
                 # The only thing we want from P is the basis that diagonalizes
                 # it, i.e. if we do the symmetric svd P = U S U^T, we can
                 # interpret the shape of U as (canonical_coordinate,
@@ -585,7 +607,8 @@ class NeutralGradient(Optimizer):
                 proj.fill_(1.0)

         self._estimate_param_scale(p, state, param_eps,
-                                   param_rel_eps, param_pow)
+                                   param_rel_eps, param_rel_max,
+                                   param_pow)

     def _estimate_param_scale(self,
@@ -593,6 +616,7 @@ class NeutralGradient(Optimizer):
                               state: dict,
                               param_eps: float,
                               param_rel_eps: float,
+                              param_rel_max: float,
                               param_pow: float) -> None:
         """
         This is called from _estimate_projections() after suitably "diagonalizing" bases
@@ -615,8 +639,10 @@ class NeutralGradient(Optimizer):
         params_sq_norm = params_sq
-        if param_pow != 1.0:
-            params_sq_partnorm = params_sq
+
+        # param_diag_vars is the diagonals of the parameter variances
+        # on each tensor dim, with an arbitrary scalar factor.
+        param_diag_vars = [None] * p.ndim

         for _ in range(3):  # Iterate 3 times, this should be enough to converge.
@@ -630,35 +656,39 @@ class NeutralGradient(Optimizer):
                 # (this is after normalizing previous dimensions' variance)
                 this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
                 params_sq_norm = params_sq_norm / this_var
-                if param_pow != 1.0:
-                    params_sq_partnorm = params_sq_partnorm / (this_var ** param_pow)
-                this_var = this_var.reshape(size)
-                this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
-                proj = state[f"proj_{dim}"]
-                if proj.ndim == 1:
-                    proj *= this_scale
+                if param_diag_vars[dim] == None:
+                    param_diag_vars[dim] = this_var.reshape(size)
                 else:
-                    # scale the rows; dim 0 of `proj` corresponds to the
-                    # diagonalized space.
-                    proj *= this_scale.unsqueeze(-1)
+                    param_diag_vars[dim] *= this_var.reshape(size)
+
+        for dim in range(p.ndim):
+            size = p.shape[dim]
+            if size == 1:
+                continue
+            param_diag_var = param_diag_vars[dim]
+            param_diag_var = self._smooth_param_diag_var(param_diag_var,
+                                                         param_pow,
+                                                         param_rel_eps,
+                                                         param_rel_max)
+            param_scale = param_diag_var ** 0.5
+            proj = state[f"proj_{dim}"]
+
+            if proj.ndim == 1:
+                proj *= param_scale
+            else:
+                # scale the rows; dim 0 of `proj` corresponds to the
+                # diagonalized space.
+                proj *= param_scale.unsqueeze(-1)
+        # the scalar factors on the various dimensions' projections are a little
+        # arbitrary, and their product is arbitrary too.  we normalize the scalar
+        # factor outside this function, see scalar_exp_avg_sq.

-        if param_pow != 1.0:
-            # need to get the overall scale correct, as if we had param_pow == 1.0
-            scale = (params_sq_partnorm.mean() ** 0.5)
-            for dim in range(p.ndim):
-                size = p.shape[dim]
-                if size == 1:
-                    continue
-                else:
-                    state[f"proj_{dim}"] *= scale
-                    break

         # Reset the squared-gradient stats because we have changed the basis.
         state["exp_avg_sq"].fill_(0.0)
-        if param_pow != 1.0:
-            state["scalar_exp_avg_sq"].fill_(0.0)
+        state["scalar_exp_avg_sq"].fill_(0.0)
@@ -726,8 +756,8 @@ class NeutralGradient(Optimizer):
             param_cov = torch.matmul(p.t(), p) / num_outer_products

-        self._randomize_lowrank_cov(param_cov, num_outer_products,
-                                    p.shape)
+        #self._randomize_lowrank_cov(param_cov, num_outer_products,
+        #                            p.shape)

         return param_cov
@@ -794,8 +824,31 @@ class NeutralGradient(Optimizer):
                             (proj.t() if forward else proj)).transpose(-1, dim)

-    def _estimate_proj(self, grad_cov: Tensor, param_cov: Tensor,
-                       param_pow: float = 1.0) -> Tensor:
+    def _smooth_param_diag_var(self,
+                               param_diag: Tensor,
+                               param_pow: float,
+                               param_rel_eps: float,
+                               param_rel_max: float) -> Tensor:
+        """
+        Applies the smoothing formula to the eigenvalues of the parameter covariance
+        tensor.
+        (Actually when we use this function in _get_param_cov, they won't exactly
+        be the eigenvalues because we diagonalize relative to the grad_cov,
+        but they will be parameter covariances in certain directions).
+        """
+        # normalize so mean = 1.0
+        param_diag = param_diag / (param_diag.mean() + 1.0e-20)
+        # use 1/(1/x + 1/y) to apply soft-max of param_diag with param_rel_max.
+        # use param_rel_eps here to prevent division by zero.
+        param_diag = 1. / (1. / (param_diag + param_rel_eps) + 1. / param_rel_max)
+        return param_diag ** param_pow
+
+    def _estimate_proj(self,
+                       grad_cov: Tensor,
+                       param_cov: Tensor,
+                       param_pow: float = 1.0,
+                       param_rel_eps: float = 0.0,
+                       param_rel_max: float = 1.0e+10) -> Tensor:
         """
         Return a symmetric positive definite matrix P such that
               (P grad_cov P^T == param_cov^{param_pow}),
@@ -839,8 +892,11 @@ class NeutralGradient(Optimizer):
         #    P = Q^T (Q G Q^T)^{-0.5} Q.
         # because C is symmetric, C == U S U^T, we can ignore V.
-        # S_sqrt is S.sqrt() in the normal case where param_pow == 1.0
-        S_sqrt = S ** (0.5 * param_pow)
+        # S_sqrt is S.sqrt() in the limit where param_pow == 1.0,
+        # param_rel_eps=0, param_rel_max=inf
+        S_smoothed = self._smooth_param_diag_var(S, param_pow,
+                                                 param_rel_eps, param_rel_max)
+        S_sqrt = S_smoothed ** 0.5
         Q = (U * S_sqrt).t()

         X = torch.matmul(Q, torch.matmul(G, Q.t()))
@@ -856,23 +912,24 @@ class NeutralGradient(Optimizer):
         P = torch.matmul(Y, Y.t())

-        if random.random() < 0.025:
-
-            # TEMP:
-            _,s,_ = P.svd()
-            print(f"Min,max eig of P: {s.min()},{s.max()}")
-
+        if random.random() < 1.0: #0.025:
             # TODO: remove this testing code.
             assert (P - P.t()).abs().mean() < 0.01 # make sure symmetric.
+            try:
+                P = 0.5 * (P + P.t())
+                _,s,_ = P.svd()
+                print(f"Min,max eig of P: {s.min()},{s.max()}")
+            except:
+                pass
             # testing... note, this is only true modulo "eps"
             C_check = torch.matmul(torch.matmul(P, G), P)
-            # C_check should equal C
-            C_diff = C_check - C
+            C_smoothed = torch.matmul(Q.t(), Q)
+            # C_check should equal C_smoothed
+            C_diff = C_check - C_smoothed
             # Roundoff can cause significant differences, so use a fairly large
             # threshold of 0.001.  We may increase this later or even remove the check.
-            if not C_diff.abs().mean() < 0.01 * C.diag().mean():
-                print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C.diag().mean()}")
-
+            if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
+                print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")

         return P
@@ -899,7 +956,7 @@ class NeutralGradient(Optimizer):
                 return

-        _, beta2 = group["betas"]
+        _, beta2, _ = group["betas"]

         for p in group["params"]:
             if p.grad is None:
@@ -1602,9 +1659,11 @@ def _test_eve_cain():
         if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
         elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=10,
-                                                estimate_period=500, stats_steps=100)
+                                                estimate_period=500, stats_steps=100,
+                                                lr_for_speedup=0.02)
         elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=1000,
-                                                estimate_period=500, stats_steps=100)
+                                                estimate_period=500, stats_steps=100,
+                                                lr_for_speedup=0.02)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

         start = timeit.default_timer()
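
For reference, the eigenvalue smoothing that the new `_smooth_param_diag_var` applies can be exercised on its own.  Below is a minimal sketch, not part of the patch: the free-standing function name and the example values are illustrative, but the formula is the one added above.

    import torch

    def smooth_param_diag_var(param_diag: torch.Tensor,
                              param_pow: float = 1.0,
                              param_rel_eps: float = 1.0e-04,
                              param_rel_max: float = 1.0) -> torch.Tensor:
        # Normalize so the mean eigenvalue is 1.0.
        param_diag = param_diag / (param_diag.mean() + 1.0e-20)
        # 1/(1/x + 1/y) softly caps x at y, so eigenvalues much larger than
        # param_rel_max are limited to about param_rel_max, while param_rel_eps
        # keeps very small values away from zero.
        param_diag = 1. / (1. / (param_diag + param_rel_eps) + 1. / param_rel_max)
        return param_diag ** param_pow

    eigs = torch.tensor([0.01, 0.5, 1.0, 10.0])
    print(smooth_param_diag_var(eigs))
    # With param_rel_max=1.0 the largest eigenvalue is capped below 1, so only
    # smaller-than-average directions end up getting extra parameter scaling.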
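
The third element added to `betas` (beta3) decays the new per-dimension parameter-covariance statistics, which are only refreshed once per `estimate_period`.  A rough sketch of that update, assuming the parameter has already been reshaped to (num_outer_products, size) for the dimension of interest; the helper name here is invented, and the patch does this inside `_estimate_projections` using `_get_param_cov`:

    import torch

    beta3 = 0.95
    size = 4
    param_cov = torch.zeros(size, size)   # plays the role of state[f"param_cov_{dim}"]

    def accumulate_param_cov(param_cov: torch.Tensor, p2d: torch.Tensor, beta3: float) -> None:
        # One outer product per row of p2d, averaged.
        cur = torch.matmul(p2d.t(), p2d) / p2d.shape[0]
        # Normalize the overall size so only the shape of the covariance,
        # not its scale, is remembered across estimation periods.
        cur = cur / (cur.diag().mean() + 1.0e-20)
        param_cov.mul_(beta3).add_(cur, alpha=(1 - beta3))

    accumulate_param_cov(param_cov, torch.randn(10, size), beta3)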
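
For the projection itself, the in-code comment `P = Q^T (Q G Q^T)^{-0.5} Q` (with C == Q^T Q) can be written out directly.  The sketch below uses a plain eigendecomposition rather than the exact factorization in the patch, and omits the eigenvalue smoothing, so it only illustrates the identity P G P^T == C:

    import torch

    def estimate_proj(grad_cov: torch.Tensor, param_cov: torch.Tensor,
                      eps: float = 1.0e-20) -> torch.Tensor:
        G, C = grad_cov, param_cov
        U, S, _ = C.svd()
        Q = (U * S.clamp(min=eps).sqrt()).t()                    # C == Q.t() @ Q
        X = Q @ G @ Q.t()
        U2, S2, _ = X.svd()
        X_inv_sqrt = (U2 * S2.clamp(min=eps) ** -0.5) @ U2.t()   # X^{-0.5}
        return Q.t() @ X_inv_sqrt @ Q                            # P, with P @ G @ P == C

    G = torch.eye(3) * torch.tensor([1.0, 2.0, 3.0])
    C = torch.eye(3)
    P = estimate_proj(G, C)
    print(torch.allclose(P @ G @ P, C, atol=1.0e-5))             # True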