diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 7ca649a71..906ecaddf 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -818,64 +818,53 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           power: power to take eigenvalues to
           X: the batch of symmetric positive definite tensors we are smoothing; of shape
              (batch_size, num_blocks, block_size, block_size)
-        """
+        """
         eps = 1.0e-20
-        if power != 1.0:
-            U, S, _ = _svd(X)
-            def mean(Y):
-                return _mean(Y, exclude_dims=[0], keepdim=True)
-            def rms(Y):
-                return _mean(Y**2, exclude_dims=[0], keepdim=True).sqrt()
-            S = S + min_eig * mean(S) + eps
-            S = S / rms(S)
-            S = 1. / (1./S + 1./max_eig)
-            S = S ** power
-            S = S / rms(S)
-            return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
-        else:
-            X = X.clone()
-            size = X.shape[1] * X.shape[3]
-            def rms_eig(Y):
-                # rms of eigenvalues, or spectral 2-norm.
-                return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
-            diag = _diag(X).unsqueeze(-1)  # Aliased with X
-            diag += (_mean(diag, exclude_dims=[0], keepdim=True) * min_eig + eps)
-            X /= rms_eig(X)
+        X = X.clone()
+        size = X.shape[1] * X.shape[3]
+        def rms_eig(Y):
+            # rms of eigenvalues, or spectral 2-norm.
+            return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
 
-            # eig_ceil is the maximum possible eigenvalue that X could possibly
-            # have at this time, given that the RMS eig is 1, equal to sqrt(num_blocks * block_size)
-            eig_ceil = size ** 0.5
+        X /= rms_eig(X)
+        # eig_ceil is the maximum possible eigenvalue that X could possibly
+        # have at this time, given that the RMS eig is 1, equal to sqrt(num_blocks * block_size)
+        eig_ceil = size ** 0.5
 
-            # the next statement wslightly adjusts the target to be the same as
-            # what the baseline function gives, max_eig -> 1./(1./eig_ceil + 1./max_eig) would
-            # give. so "max_eig" is now the target of the function for arg ==
-            # eig_ceil.
-            max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
+        # the next statement slightly adjusts the target to be the same as
+        # what the baseline function gives, max_eig -> 1./(1./eig_ceil + 1./max_eig) would
+        # give. so "max_eig" is now the target of the function for arg ==
+        # eig_ceil.
+        max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
 
-            while max_eig <= eig_ceil * 0.5:
-                # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
-                # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
-                # increasing from 0..eig_ceil.
-                # the transpose on the 2nd X is to try to stop small asymmetries from
-                # propagating.
-                X = X - 0.5/eig_ceil * torch.matmul(X, X.transpose(2, 3))
-                eig_ceil = 0.5 * eig_ceil
+        while max_eig <= eig_ceil * 0.5:
+            # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
+            # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
+            # increasing from 0..eig_ceil.
+            # the transpose on the 2nd X is to try to stop small asymmetries from
+            # propagating.
+            X = X - 0.5/eig_ceil * torch.matmul(X, X.transpose(2, 3))
+            eig_ceil = 0.5 * eig_ceil
 
-            # max_eig > eig_ceil * 0.5
-            if max_eig < eig_ceil:
-                # map l -> l - coeff*l*l, if l==eig_ceil, this
-                # takes us to:
-                # eig_ceil - (eig_ceil-max_eig/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
-                # == max_eig
-                # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig>eig_ceil*0.5]
-                # means that the function is monotonic on inputs from 0 to eig_ceil.
-                coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
-                X = X - coeff * torch.matmul(X, X.transpose(2, 3))
+        # max_eig > eig_ceil * 0.5
+        if max_eig < eig_ceil:
+            # map l -> l - coeff*l*l, if l==eig_ceil, this
+            # takes us to:
+            # eig_ceil - (eig_ceil-max_eig/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
+            # == max_eig
+            # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig>eig_ceil*0.5]
+            # means that the function is monotonic on inputs from 0 to eig_ceil.
+            coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
+            X = X - coeff * torch.matmul(X, X.transpose(2, 3))
 
-            # normalize to have rms eig == 1.
-            X /= rms_eig(X)
-            X = 0.5 * (X + X.transpose(2, 3))  # make sure exactly symmetric.
-            return X
+        # normalize to have rms eig == 1.
+        X = 0.5 * (X + X.transpose(2, 3))  # make sure exactly symmetric.
+
+        diag = _diag(X).unsqueeze(-1)  # Aliased with X
+        diag += (_mean(diag, exclude_dims=[0], keepdim=True) * min_eig + eps)
+        X /= rms_eig(X)
+
+        return X
 
     def _apply_min_max_with_metric(self,
@@ -910,12 +899,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max.
         X /= rms_eig(torch.matmul(X, M))
 
-        if min_eig != 0.0:
-            X = X * (1.0-min_eig) + min_eig * M.inverse()
-            X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
-
-            X /= rms_eig(torch.matmul(X, M))
-
         # eig_ceil is the maximum possible eigenvalue that X could possibly
         # have at this time, equal to num_blocks * block_size.
         eig_ceil = size ** 0.5
@@ -945,9 +928,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
            X = X - coeff * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))
 
-        # Normalize again.
-        X /= rms_eig(X)
         X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
+
+        if min_eig != 0.0:
+            X /= mean_eig(X)
+            X = X * (1.0-min_eig) + min_eig * M.inverse()
+            X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
+
         return X
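For anyone reviewing the patch, the block below is a minimal, standalone sketch (not part of optim.py; the helper name `cap_max_eig` is hypothetical, and it handles a single matrix rather than the batched `(batch_size, num_blocks, block_size, block_size)` layout) of the matrix-polynomial trick the new comments describe: repeatedly applying `l -> l - coeff*l*l` to the eigenvalues, via matrix products, caps the largest eigenvalue of a symmetric positive semi-definite matrix without the SVD that the deleted `power != 1.0` branch relied on.

```python
# Standalone sketch of the eigenvalue-capping idea, assuming a single
# symmetric PSD matrix X; `cap_max_eig` is a hypothetical helper name.
import torch


def cap_max_eig(X: torch.Tensor, max_eig: float) -> torch.Tensor:
    size = X.shape[-1]
    # Normalize so the eigenvalues have RMS value 1 (for a symmetric matrix the
    # sum of squared entries equals the sum of squared eigenvalues).
    X = X / (X.pow(2).sum() / size).sqrt()
    # With RMS eig == 1, no eigenvalue can exceed sqrt(size).
    eig_ceil = size ** 0.5
    # Same target adjustment as in the patch: max_eig -> 1/(1/max_eig + 1/eig_ceil).
    max_eig = 1.0 / (1.0 / max_eig + 1.0 / eig_ceil)
    # l -> l - 0.5/eig_ceil * l*l maps eig_ceil to 0.5*eig_ceil and is
    # monotonically increasing on [0, eig_ceil], so the ceiling halves each pass.
    while max_eig <= eig_ceil * 0.5:
        X = X - 0.5 / eig_ceil * (X @ X.T)
        eig_ceil = 0.5 * eig_ceil
    # One final step whose coefficient maps eig_ceil exactly onto max_eig.
    if max_eig < eig_ceil:
        coeff = (eig_ceil - max_eig) / (eig_ceil * eig_ceil)
        X = X - coeff * (X @ X.T)
    return 0.5 * (X + X.T)  # keep the result exactly symmetric


A = torch.randn(8, 8)
A = A @ A.T  # symmetric positive semi-definite test matrix
B = cap_max_eig(A, max_eig=2.0)
print(torch.linalg.eigvalsh(B).max())  # now at most the adjusted target
```

The apparent benefit of this formulation is that the whole smoothing step reduces to a handful of batched matmuls, which are generally cheaper and more numerically robust than a per-step batched SVD.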