mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Some cleanups
This commit is contained in:
parent
a5b9b7b974
commit
64f7166545
@ -889,12 +889,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
# rms of eigenvalues, or spectral 2-norm.
|
# rms of eigenvalues, or spectral 2-norm.
|
||||||
return _mean(_diag(Y), exclude_dims=[0], keepdim=True).unsqueeze(-1)
|
return _mean(_diag(Y), exclude_dims=[0], keepdim=True).unsqueeze(-1)
|
||||||
|
|
||||||
def rms_eig(Y):
|
def rms_eig(Y, M):
|
||||||
# rms of eigenvalues, or spectral 2-norm.
|
# rms of eigenvalues, or spectral 2-norm, of A = M^{0.5} Y M^{0.5}.
|
||||||
return (_sum(Y*Y.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt()
|
# We can compute this as (tr(A A) / size).sqrt(), since A is symmetric.
|
||||||
|
# tr(A ) == tr(M^{0.5} Y M^{0.5} M^{0.5} Y M^{0.5}) = tr(M Y M Y) = tr(Z Z) if Z == M Y.
|
||||||
|
# tr(Z Z) can be computed as sum(Z*Z.t())
|
||||||
|
Z = torch.matmul(Y, M)
|
||||||
|
return (_sum(Z*Z.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt()
|
||||||
|
|
||||||
# make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max.
|
# make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes a
|
||||||
X /= rms_eig(torch.matmul(X, M))
|
# limit on the max.
|
||||||
|
X /= rms_eig(X, M)
|
||||||
|
|
||||||
# eig_ceil is the maximum possible eigenvalue that X could possibly
|
# eig_ceil is the maximum possible eigenvalue that X could possibly
|
||||||
# have at this time, equal to num_blocks * block_size.
|
# have at this time, equal to num_blocks * block_size.
|
||||||
@ -928,6 +933,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
|
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
|
||||||
|
|
||||||
if min_eig != 0.0:
|
if min_eig != 0.0:
|
||||||
|
# conceptually we want tr(M^{0.5} X M^{0.5}), but this is the same as tr(X M).
|
||||||
|
# we can compute this as the sum of the elements of X * M^T, because if Y=matmul(X, M),
|
||||||
|
# Y_{ii} is the sum of X_{ij} * M_{ji}.
|
||||||
mean_XM_eig = _sum(X * M.transpose(2, 3), exclude_dims=[0], keepdim=True) / size
|
mean_XM_eig = _sum(X * M.transpose(2, 3), exclude_dims=[0], keepdim=True) / size
|
||||||
XM_over_X = mean_XM_eig / mean_eig(X)
|
XM_over_X = mean_XM_eig / mean_eig(X)
|
||||||
X /= mean_XM_eig
|
X /= mean_XM_eig
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user