From ffeef4ede46f30e0bb9e4ccd0dc486aee6417615 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 Jul 2022 13:36:48 +0800 Subject: [PATCH] Remove rank-1 dims, meaning where size==numel(), from processing. --- .../ASR/pruned_transducer_stateless7/optim.py | 411 +----------------- 1 file changed, 7 insertions(+), 404 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bbc575ff4..8cda942a7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -178,8 +178,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of for dim in range(p.ndim): size = p.shape[dim] - if size == 1: - # if size == p.numel(), is_one_axis == True above + if size == 1 or size == p.numel(): continue @@ -320,7 +319,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of (step + lr_update_period)) for dim in range(p.ndim): size = p.shape[dim] - if size == 1: + if size == 1 or size == p.numel(): continue # M is the parameter matrix, of shape (-1, size) where the -1 covers @@ -363,7 +362,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of for dim in range(ndim): size = p.shape[dim] - if size == 1: + if size == 1 or size == p.numel(): continue param_cov = state[f"param_cov_{dim}"] U, S, V = _svd(param_cov) @@ -472,7 +471,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # grad_cov_{dim}, which are for periodically updating the # learning-rate matrices. size = grad.shape[dim] - if size == 1: + if size == 1 or size == p.numel(): continue if step % grad_cov_period == 0 or step < grad_cov_period: grad_cov = state[f"grad_cov_{dim}"] @@ -559,8 +558,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of co-ordinates); if False, the reverse. They differ by a transpose, not an inverse, and do not make a round trip. """ + numel = x.numel() for dim in range(x.ndim): - if x.shape[dim] == 1: + size = x.shape[dim] + if size == 1 or size == numel: continue Q = state[f"Q_{dim}"] if forward: @@ -575,236 +576,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of x = x.transpose(-1, dim) return x - def _store_grad_stats(self, - grad: Tensor, - state: dict, - beta2: float) -> None: - """ - Accumulate some stats for each dimension of the tensor, about the covariance of gradients. - The stats are just summed, not decayed, as the scalar factor will be normalized out later. - """ - ndim = grad.ndim - kwargs = {'device': grad.device, 'dtype': grad.dtype} - for dim in range(ndim): - size = grad.shape[dim] - if size == 1: - continue - grad_cov = state[f"grad_cov_{dim}"] - g = grad.transpose(dim, -1) - g = g.reshape(-1, g.shape[-1]) - # We don't care about the scalar factor, we will normalize when we - # use it. - grad_cov.mul_(beta2).add_(torch.matmul(g.t(), g)) - - def _estimate_projections(self, - p: Tensor, - state: dict, - param_eps: float, - param_rel_eps: float, - param_rel_max: float, - param_reverse_cutoff: float, - beta3: float, - param_pow: float, - grad_pow: float, - grad_min_rand: float, - ) -> None: - """ - Estimate the per-dimension projections proj_0, proj_1 and so on. 
- These projections, when applied with forward==True by _change_coordinates(), - first rotate into a space - where the optimal learning-rate-scale is diagonalized (see self._estimate_proj); - then scale by the parameter scale. When applied with forward == False - they first scale by the parameter scale and then rotate back into canonical - co-ordinates. (Thus, applying with forward==True and then forward==False - is not a round trip). - - Args: - p: The parameter for which we are etimating the projections. Supplied - because we need this to estimate parameter covariance matrices in - each of its tensor directions. - state: The state dict for this parameter tensor - param_eps: Small epsilon representing a minimum parameter magnitude; - once the estimated parameter rms for some element of p gets below - this value, we'll treat it as having this value. - param_rel_eps: Another constraint on minimum parameter rms, relative to - sqrt(overall_param_rms), where overall_param_rms is the sqrt of the - parameter variance averaged over the entire tensor. - beta3: discounting parameter for param_cov, applied once every estimate_period. - param_pow: param_pow==1.0 means fully normalizing for parameter scale; - setting to a smaller value closer to 0.0 means partial normalization. - grad_min_rand: minimum proportion of random tensor to mix with the - gradient covariance matrix. - """ - kwargs = {'device':p.device, 'dtype':p.dtype} - - ndim = p.ndim - - for dim in range(ndim): - size = p.shape[dim] - if size == 1: - continue - proj = state[f"proj_{dim}"] - - if proj.ndim == 2: - # This dimension gets the full-covariance treatment. - count = state[f"grad_cov_count_{dim}"] - assert count != 0 - grad_cov = state[f"grad_cov_{dim}"] / count - del state[f"grad_cov_{dim}"] # save memory - - #self._randomize_lowrank_cov(grad_cov, count, p.shape, - # grad_min_rand) - - cur_param_cov = self._get_param_cov(p, dim) - cur_param_cov = cur_param_cov / (cur_param_cov.diag().mean() + 1.0e-20) # normalize size.. - param_cov = state[f"param_cov_{dim}"] - param_cov.mul_(beta3).add_(cur_param_cov, alpha=(1-beta3)) - - # We want the orthogonal matrix U that diagonalizes P, where - # P is the SPD matrix such that P G^{param_pow} P^T == C^{param_pow}, - # where G == grad_cov and C == param_cov. - # .. this is the same as the U that diagonalizes a different P - # such that P G P^T == C^{param_pow/(2-param_pow)}, - # since the P's are related by taking-to-a-power. - - num_samples = p.numel() // size - reverse_cutoff = (param_reverse_cutoff if num_samples > size//4 else 1.0e+10) - P = self._estimate_proj(grad_cov, - param_cov, - param_pow / grad_pow, - param_rel_eps, - param_rel_max, - reverse_cutoff) - - - # The only thing we want from P is the basis that diagonalizes - # it, i.e. if we do the symmetric svd P = U S U^T, we can - # interpret the shape of U as (canonical_coordinate, - # diagonalized_coordinate) - U, S, _ = P.svd() - # `proj` will be of shape (diagonalized_coordinate, canonical_coordinate) - # Later we will modify `proj` to incorporate the parameter scale. - proj[:] = U.t() - else: - # This dimension will be treated diagonally. 
For now just set `proj` to - # all ones, later we will modify it in _estimate_params_scale() - proj.fill_(1.0) - - self._estimate_param_scale(p, state, param_eps, - param_pow, param_rel_eps, param_rel_max, - param_reverse_cutoff) - - - - def _multiply_by_lr(self, - grad: Tensor, - state: dict, - forward: bool) -> Tensor: - """ - Multiply the grad by a positive semi-definite matrix for each dimension, - interpreted as a non-scalar learning-rate factor. - """ - cur_grad = grad - for dim in range(grad.ndim): - size = grad.shape[dim] - if size == 1: - continue - M = state[f"lr_{dim}"] - if M.ndim == 1: - # We are treating this dimension diagonally because it is too big. - M = M.reshape(size, *([1] * (ndim-1-dim))) - cur_grad = cur_grad * M - else: - cur_grad = self._multiply_on_dim(cur_grad, M, dim) - return cur_grad - - - def _multiply_by_grad_cov(self, - grad: Tensor, - state: dict, - forward: bool) -> Tensor: - """ - Multiply the grad by a positive semi-definite matrix for each dimension, - interpreted as a non-scalar learning-rate factor. - """ - cur_grad = grad - for dim in range(grad.ndim): - size = grad.shape[dim] - if size == 1: - continue - M = state[f"grad_cov_{dim}"] - if M.ndim == 1: - # We are treating this dimension diagonally because it is too big. - M = M.reshape(size, *([1] * (ndim-1-dim))) - # Dividing by (M.mean() + 1.0e-20) is just arbitrary normalization, - # to avoid getting very large or very small outputs. - cur_grad = cur_grad * (M / (M.mean() + 1.0e-20)) - else: - # Dividing by (M.diag().mean() + 1.0e-20) is just arbitrary normalization, - # to avoid getting very large or very small outputs. We only need the result - # up to an arbitrary scalar factor. - cur_grad = self._multiply_on_dim(cur_grad, M, dim) / (M.diag().mean() + 1.0e-20) - return cur_grad - - - - def _change_coordinates(self, - grad: Tensor, - state: dict, - forward: bool) -> Tensor: - """ - Multiply the grad by a matrix for each dimension, interpreted as a non-orthogonal - basis change. - If forward == True, `grad` is interpreted as: gradient -> gradient in new basis; - if forward == False, it is interpreted as: - parameter-change in new basis -> parameter-change in canonical basis. - These differ by a transpose (this is not the same as inversion as the matrix is not - orthogonal). Therefore the co-ordinate change applied both forward and backward - does not represent a "round trip". - - We know at this point that `grad` has more than one non-trivial dimension/axis - (i.e. >1 dimension with size != 1). - - Args: - p: the Parameter that the grad is for; may be needed if we re-estimating - the learning-rate matrix - grad: the gradient to precondition (will not be modified by this function) - state: the state-dict where will look up the projections. - Returns: - The co-ordinate-changed grad or parameter change - """ - cur_grad = grad - ndim = grad.ndim - step = state["step"] - for dim in range(ndim): - size = grad.shape[dim] - if size == 1: - continue - proj = state[f"proj_{dim}"] - - if proj.ndim == 1: - # We are treating this dimension diagonally because it is too big. 
- proj = proj.reshape(size, *([1] * (ndim-1-dim))) - cur_grad = cur_grad * proj - else: - cur_grad = self._multiply_on_dim(cur_grad, proj, dim, forward) - return cur_grad - - - - def _multiply_on_dim(self, x: Tensor, M: Tensor, dim: int) -> Tensor: - """ - Matrix-multiplies x by symmetric matrix `M` on x's dimension numbered `dim` - Args: - x: Tensor to multiply, of arbitrary shape - M: symmetric matrix to multiply x by, of shape - (x.shape[dim], x.shape[dim]). - """ - # Note: can use either M or M.t() here because M is symmetric - return torch.matmul(x.transpose(-1, dim), M.t()) - - - def _smooth_param_rms(self, group: dict, rms: Tensor, @@ -836,176 +607,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of return rms / new_mean - def _estimate_proj(self, - grad_cov: Tensor, - param_cov: Tensor, - param_pow: float, - param_rel_eps: float, - param_rel_max: float, - param_reverse_cutoff: float) -> Tensor: - """ - Return a symmetric positive definite matrix P such that - (P grad_cov P^T == param_cov^{param_pow}), - i.e. a descent direction that makes the covariance of steps look - like the covariance of parameter. This will be normalizes so - that the mean of its diagonal is 1.0 (since we'll have a separate - such projection per dim of the tensor, and we don't want to count - the scalar factor many times). - - Args: - grad_cov: Covariance of gradients, a symmetric-positive-definite - matrix [except for roundoff!] of shape (size, size) - param_cov: Covariance of parameters, a symmetric-positive-definite - matrix [except for roundoff!] of shape (size, size) - param_pow: Power that we take "param_cov" to, between 0 and 1. - Normally it will be 1, and not all the comments mention it. - Returns: - A matrix P such that (alpha P grad_cov P^T == param_cov^param_pow) - """ - # symmetrize grad_cov and param_cov. The projection P only cares about the - # relative sizes, so doubling both of them does not matter. - G = grad_cov + grad_cov.t() - C = param_cov + param_cov.t() - U, S, V = C.svd() - - # Using G == grad_cov and C == param_cov, and not writing transpose because - # all these quantities are symmetric, we can use: - # - # P = C^{0.5} (C^{0.5} G C^{0.5})^{-0.5} C^{0.5} - # - # You can verify fairly easily that P G P == C: multiply on left and right - # by C^{-0.5} on each side and you can see that both sides equal I. - # - # We can reduce the amount of roundoff by, after decomposing C - # with SVD as (U S U^T), writing C^{0.5} = U Q = Q^T U^T, with Q = S^0.5 U^T. - # - # Then we can write P as follows: - # P = Q^T U^T (U Q G Q^T U^T)^{-0.5} U Q, - # and we easily show that for orthogonal U, (U X U^T)^{-0.5} = U X^{-0.5} U^T, so we - # can do this and cancel U^T U = U U^T = I, giving: - # P = Q^T (Q G Q^T)^{-0.5} Q. - - # because C is symmetric, C == U S U^T, we can ignore V. - # S_sqrt is S.sqrt() in the limit where param_pow == 1.0, - # param_rel_eps=0, param_rel_max=inf - - S_smoothed = self._smooth_param_diag_var(S, param_pow, - param_rel_eps, param_rel_max, - param_reverse_cutoff) - S_sqrt = S_smoothed ** 0.5 - Q = (U * S_sqrt).t() - - X = torch.matmul(Q, torch.matmul(G, Q.t())) - - U2, S2, _ = X.svd() - # Because X^{-0.5} = U2 S2^{-0.5} U2^T, we have - # P = Q^T U2 S2^{-0.5} U2^T Q - # = Y Y^T, where - # Y = Q^T U2 S2^{-0.25} - - S2_pow = (S2 + 1.0e-35) ** -0.25 - Y = torch.matmul(Q.t(), U2) * S2_pow - - P = torch.matmul(Y, Y.t()) - - if random.random() < 1.0: #0.025: - # TODO: remove this testing code. 
- assert (P - P.t()).abs().mean() < 0.01 # make sure symmetric. - try: - P = 0.5 * (P + P.t()) - _,s,_ = P.svd() - logging.info(f"Min,max eig of P: {s.min()},{s.max()}") - except: - pass - # testing... note, this is only true modulo "eps" - C_check = torch.matmul(torch.matmul(P, G), P) - C_smoothed = torch.matmul(Q.t(), Q) - # C_check should equal C_smoothed - C_diff = C_check - C_smoothed - # Roundoff can cause significant differences, so use a fairly large - # threshold of 0.001. We may increase this later or even remove the check. - if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean(): - logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}") - - return P - - def reset_speedup(self): - """ - Reset the step_period so that we'll do a step each time, until the next - time - """ - for group in self.param_groups: - group["speedup_step"] = -group["speedup_first_step"] - for p in group["params"]: - state = self.state[p] - state["step_period"] = 1 - - def _recalibrate_speedup(self, group: dict): - param_prods = [] - speedup = group["lr_for_speedup"] / group["lr"] - if speedup > 10.0: - speedup = 10.0 - if speedup <= 1.0: - for p in group["params"]: - state = self.state[p] - state["step_period"] = 1 - return - _, beta2, _ = group["betas"] - - for p in group["params"]: - if p.grad is None: - continue - # We ignore the part of the parameter change that comes from the scale - # (see 'scale_deriv' in the code) - state = self.state[p] - step = state["step"] - - # the sqrt() is just a heuristic, it seems to give periods that work better. - param_prods.append(state["grad_times_step"].sqrt()) - - - param_prods = torch.stack(param_prods).to('cpu') - # TEMP - param_prods = param_prods.sqrt() - avg_update_freq = 1.0 / speedup - param_freqs = param_prods * (avg_update_freq / param_prods.mean()) - - # Don't delay more than 1 / (3*speedup) steps for any parameter, to - # ensure that no parameter waits super-long for an update. Note: with - # beta=0.1, this would mean an average delay of about 30*speedup steps - # before the gradient makes it to the parameter. - min_freq = 1.0 / (3 * speedup) - - # get the mean after applying limits of [min_freq..1] for the frequencies of - # update - param_freqs_mean = param_freqs.clamp(min=min_freq, max=1.0).mean() - # renormalize after applying min and max limits - param_freqs = param_freqs * (avg_update_freq / param_freqs_mean) - # ...and apply the limits again. - param_freqs = param_freqs.clamp(min=min_freq, max=1.0) - param_periods = (1.0 / param_freqs).to(torch.int32) # .to(torch.int32) rounds down - - actual_speedup = 1.0 / (1.0 / param_periods).mean() - - param_prods = param_prods.tolist() - param_freqs = param_freqs.tolist() - param_periods = param_periods.tolist() - - logging.info(f"PrAdam._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}") - print_info = random.random() < 0.01 - i = 0 - for p in group["params"]: - if p.grad is None: - continue - state = self.state[p] - state["step_period"] = param_periods[i] - if print_info: - logging.info(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.") - i += 1 - - # Do nothing more, as yet. def _svd(x: Tensor):
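
The functional change in this patch is the extra skip condition: a dimension is now skipped not only when its size is 1, but also when its size equals the tensor's numel(), i.e. the tensor has a single non-trivial axis (the removed comment "if size == p.numel(), is_one_axis == True above" notes that this case is already handled by a separate code path). A minimal standalone sketch of which dims survive that test; the helper name nontrivial_dims is illustrative only and does not exist in optim.py:

import torch

def nontrivial_dims(p: torch.Tensor):
    # Dims the per-dimension machinery still processes after this patch:
    # skip size == 1 (trivial axis) and size == p.numel() (only one
    # non-trivial axis, handled by the is_one_axis code path instead).
    numel = p.numel()
    return [dim for dim in range(p.ndim)
            if p.shape[dim] != 1 and p.shape[dim] != numel]

print(nontrivial_dims(torch.zeros(512, 256)))   # [0, 1]  weight matrix: both dims kept
print(nontrivial_dims(torch.zeros(256)))        # []      bias vector
print(nontrivial_dims(torch.zeros(256, 1, 1)))  # []      single non-trivial axis, now skipped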
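
The retained _change_coordinates loop (and the removed _multiply_on_dim helper) apply a (size, size) matrix along one dimension of a tensor by swapping that dimension to the last position, doing a matmul, and swapping back. A self-contained sketch of that pattern, assuming a general (not necessarily symmetric) matrix; multiply_on_dim is an illustrative name, and unlike the removed helper it transposes the result back to the original layout:

import torch

def multiply_on_dim(x: torch.Tensor, M: torch.Tensor, dim: int) -> torch.Tensor:
    # Left-multiply every vector of x that runs along `dim` by M,
    # where M has shape (x.shape[dim], x.shape[dim]); other dims are untouched.
    return torch.matmul(x.transpose(-1, dim), M.t()).transpose(-1, dim)

x = torch.randn(4, 8, 16)
M = torch.randn(8, 8)
print(multiply_on_dim(x, M, dim=1).shape)  # torch.Size([4, 8, 16])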