From dee496145de2559187993e4c9192de136c99e599 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 23 Jul 2022 08:11:20 +0800 Subject: [PATCH] this version performs way worse but has bugs fixed, can optimize from here. --- .../ASR/pruned_transducer_stateless7/optim.py | 70 ++++++++++++++----- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 902a7d21c..87956ed95 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -143,7 +143,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of betas=(0.9, 0.98), size_lr_scale=0.1, min_lr_factor=(0.05, 0.05, 0.05), - max_lr_factor=(10.0, 10.0, 10.0), + max_lr_factor=(100.0, 100.0, 100.0), param_rms_smooth0=0.75, param_rms_smooth1=0.25, eps=1.0e-08, @@ -617,7 +617,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size) # [batch_index, block_index, diagonalized_coordinate, canonical_coordinate], # so we need to transpose Q as we convert M to the diagonalized co-ordinate. - M = torch.matmul(M, Q.transpose(2, 3)) # (batch_size, num_blocks, x, y, z, block_size) + #M = torch.matmul(M, Q.transpose(2, 3)) # (batch_size, num_blocks, x, y, z, block_size) + M = torch.matmul(M, Q.transpose(-2, -1)) # (batch_size, num_blocks, x, y, z, block_size) M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size) M = M.reshape(*M.shape[:-2], size) # # (batch_size, x, y, z, size) cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z) @@ -638,8 +639,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # spectrum"). We scale p so that it matches the accumulated stats, # the idea is to ensure it doesn't have any too-small eigenvalues # (where the stats permit). + scale = (S / cur_param_var.clamp(min=eps)).sqrt() + if True: + S_tmp = S.reshape(batch_size, size) + cur_tmp = cur_param_var.reshape(batch_size, size) + scale_tmp = scale.reshape(batch_size, size) + skip = 10 if size > 40 else 1 + logging.info(f"cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}") + + if random.random() < 0.01: skip = 10 if size < 20 else 1 logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}") @@ -752,7 +762,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of be solving (eqn:2), and then computing: Z = U Z' U^T. - A solution to (eqn:1) is as follows. We are going to be using a Cholesky-based solution in + A solution to (eqn:2) is as follows. We are going to be using a Cholesky-based solution in favor of one that requires SVD or eigenvalue decomposition, because it is much faster (we first have to be careful that the input is not close to singular, though). @@ -801,14 +811,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime) - C = P_prime.cholesky() # P_prime = torch.matmul(C, C.transpose(2, 3)) + # C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3)) + # C is of shape (batch_size, num_blocks, block_size, block_size). + #def _fake_cholesky(X): + # U_, S_, _ = _svd(X) + # return U_ * S_.sqrt().unsqueeze(-2) + #C = _fake_cholesky(P_prime) + C = P_prime.cholesky() - # CGC = (C^T G' C) which would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C)) + + # CGC = (C^T G' C), it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C)) # if G were formatted as a diagonal matrix, but G is just the diagonal. - CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), - C) + CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C) + U, S, _ = _svd(CGC) # Need SVD to compute CGC^{-0.5} - # next we compute (eqn:3). The thing in the parenthesis is, GCC^{-0.5}, + # next we compute (eqn:3). The thing in the parenthesis is, CGC^{-0.5}, # can be written as U S^{-0.5} U^T, so the whole thing is # (C U) S^{-0.5} (C U)^T CU = torch.matmul(C, U) @@ -817,12 +834,26 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of CU.transpose(2, 3)) if True: + def _check_similar(x, y, name): + ratio = (y-x).abs().sum() / x.abs().sum() + if ratio > 0.0001: + logging.warn(f"Check {name} failed, ratio={ratio.item()}") + + def _check_symmetric(x, x_name): + diff = x - x.transpose(-2, -1) + ratio = diff.abs().sum() / x.abs().sum() + if ratio > 0.0001: + logging.warn(f"{x_name} is not symmetric: ratio={ratio.item()}") + + _check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check") + _check_symmetric(Z_prime, "Z_prime") + _check_symmetric(P_prime, "P_prime") # A check. # Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2) P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime) - diff_ratio = (P_prime - P_prime_check).abs().sum() / P_prime.abs().sum() - if diff_ratio > 0.01: - logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}") + diff_ratio = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt() + if diff_ratio > 0.001: + logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}, size={size}") Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3))) # OK, Z is the SPD transform that maps G to P, as in Z G Z = P. # We just need the basis that diagonalizes this. @@ -844,7 +875,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of U_prod = torch.matmul(U_z.transpose(2, 3), U_g) # this_P_proj shape: (batch_size, num_blocks, block_size) this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3)))) - P_proj[dim] = this_P_proj.reshape(batch_size, size) + P_proj[dim] = this_P_proj.clone().reshape(batch_size, size) + if True: + skip = 10 if P_proj[dim].shape[-1] > 40 else 1 + logging.info(f"Eigs of P_proj are: {P_proj[dim][0,::skip]}") + return P_proj @@ -890,7 +925,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # Now P is as normalized as we can make it... do smoothing baserd on 'rank', # that is intended to compensate for bad estimates of P. batch_size = p_shape[0] - size = P_prime.shape[0] # size of dim we are concerned with right now + size = P_norm.shape[0] # size of dim we are concerned with right now # `rank` is the rank of P_prime if we were to estimate it from just one # parameter tensor. We average it over time, but actually it won't be changing # too much, so `rank` does tell us something. @@ -908,7 +943,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # add rank-dependent smoothing amount to diagonal of P_prime. _diag() returns an aliased tensor. # we don't need to multiply `smooth` by anything, because at this point, P_prime should have # diagonal elements close to 1. - _diag(P_prime).add_(smooth) + _diag(P_norm).add_(smooth) P_norm = self._smooth_cov(P_norm, group["min_lr_factor"][0], @@ -927,7 +962,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of G_prime_rms = G_prime.sqrt() G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2) # P_gnorm is a version of P_prime that is scaled relative to G, i.e. - # scaled in such a way that would make G the unit matrix. + # scaled in a way that would make G the unit matrix. P_gnorm = P_prime / G_prime_scale # Apply another round of smoothing "relative to G" P_gnorm = self._smooth_cov(P_gnorm, @@ -982,6 +1017,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of S = S / _mean(S, exclude_dims=[0], keepdim=True) return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3)) else: + X = X.clone() # may be diag = _diag(X) # Aliased with X mean_eig = _mean(diag, exclude_dims=[0], keepdim=True) eps = 1.0e-10 # prevent division by zero @@ -1335,11 +1371,11 @@ def _diag(x: Tensor): elif x.ndim == 4: (B, C, M, M2) = x.shape assert M == M2 - ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous() + ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])) elif x.ndim == 2: (M, M2) = x.shape assert M == M2 - ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],)).contiguous() + ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],)) return ans