Mirror of https://github.com/k2-fsa/icefall.git

commit dee496145d (parent: dd10eb140f)

    This version performs much worse, but has the bugs fixed; we can optimize from here.
@@ -143,7 +143,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             betas=(0.9, 0.98),
             size_lr_scale=0.1,
             min_lr_factor=(0.05, 0.05, 0.05),
-            max_lr_factor=(10.0, 10.0, 10.0),
+            max_lr_factor=(100.0, 100.0, 100.0),
             param_rms_smooth0=0.75,
             param_rms_smooth1=0.25,
             eps=1.0e-08,
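The only change in this hunk raises the default max_lr_factor from 10 to 100. These factors bound how far the eigenvalues of the smoothed covariances (and hence the per-direction learning-rate scales) may deviate from their mean. A minimal sketch of that kind of clamping; the function below is illustrative, an assumption about roughly how the `_smooth_cov`-style code consumes these bounds, not code from the file:

    import torch

    def clamp_cov_eigs(X, min_lr_factor=0.05, max_lr_factor=100.0):
        # X: (batch_size, size, size), symmetric positive definite.
        # Clamp eigenvalues to a window around their mean, so no direction's
        # learning-rate factor vanishes or explodes.
        S, U = torch.linalg.eigh(X)
        mean_eig = S.mean(dim=-1, keepdim=True)
        S = S.clamp(min=min_lr_factor * mean_eig, max=max_lr_factor * mean_eig)
        return U @ (S.unsqueeze(-1) * U.transpose(-2, -1))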
@@ -617,7 +617,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
         # [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
         # so we need to transpose Q as we convert M to the diagonalized co-ordinate.
-        M = torch.matmul(M, Q.transpose(2, 3))  # (batch_size, num_blocks, x, y, z, block_size)
+        #M = torch.matmul(M, Q.transpose(2, 3))  # (batch_size, num_blocks, x, y, z, block_size)
+        M = torch.matmul(M, Q.transpose(-2, -1))  # (batch_size, num_blocks, x, y, z, block_size)
         M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
         M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
         cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
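The fix here is subtle: Q has six dimensions, (batch_size, num_blocks, 1, 1, block_size, block_size), so `Q.transpose(2, 3)` swaps the two singleton broadcast dims and is a no-op; M was effectively multiplied by Q rather than Q^T. `transpose(-2, -1)` swaps the actual matrix dims. A quick demonstration (shapes mine):

    import torch

    Q = torch.randn(2, 3, 1, 1, 4, 4)
    # Swapping two size-1 dims leaves the tensor unchanged:
    assert torch.equal(Q.transpose(2, 3), Q)
    # Swapping the last two dims is the intended matrix transpose:
    assert not torch.equal(Q.transpose(-2, -1), Q)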
@@ -638,8 +639,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # spectrum").  We scale p so that it matches the accumulated stats,
         # the idea is to ensure it doesn't have any too-small eigenvalues
         # (where the stats permit).

         scale = (S / cur_param_var.clamp(min=eps)).sqrt()
+
+        if True:
+            S_tmp = S.reshape(batch_size, size)
+            cur_tmp = cur_param_var.reshape(batch_size, size)
+            scale_tmp = scale.reshape(batch_size, size)
+            skip = 10 if size > 40 else 1
+            logging.info(f"cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
+
+
         if random.random() < 0.01:
             skip = 10 if size < 20 else 1
             logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")
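For reference, the `scale` line matches the parameter's per-coordinate variance to the accumulated stats S, with the clamp guarding against division by near-zero variances. A toy example (values mine, not from the optimizer):

    import torch

    eps = 1.0e-08
    cur_param_var = torch.tensor([4.0, 1.0, 1e-12])  # current second moments
    S = torch.tensor([1.0, 1.0, 1.0])                # accumulated target stats
    scale = (S / cur_param_var.clamp(min=eps)).sqrt()
    p = torch.tensor([2.0, 1.0, 1e-6])
    print((p * scale) ** 2)  # first two coords now match S; the third is capped by eps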
@@ -752,7 +762,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    be solving (eqn:2), and then computing:
          Z = U Z' U^T.

-   A solution to (eqn:1) is as follows.  We are going to be using a Cholesky-based solution in
+   A solution to (eqn:2) is as follows.  We are going to be using a Cholesky-based solution in
    favor of one that requires SVD or eigenvalue decomposition, because it is much faster (we first
    have to be careful that the input is not close to singular, though).
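The recipe the text describes can be checked end-to-end: factor P' = C C^T, form CGC = C^T G' C, and build Z' = (C U) S^{-0.5} (C U)^T from its eigendecomposition CGC = U S U^T; then Z' G' Z' = P'. A self-contained, unbatched sketch of that verification (my code, using torch.linalg in place of the file's `_svd` helper):

    import torch

    torch.manual_seed(0)
    n = 6
    A = torch.randn(n, n)
    P = A @ A.T + 0.1 * torch.eye(n)   # SPD stand-in for P'
    G = torch.rand(n) + 0.1            # diagonal entries of G'

    C = torch.linalg.cholesky(P)       # P == C @ C.T
    CGC = (C.T * G) @ C                # C^T diag(G) C without materializing diag(G)
    U, S, _ = torch.linalg.svd(CGC)    # CGC is SPD, so this is its eigendecomposition
    CU = C @ U
    Z = (CU * S.rsqrt()) @ CU.T        # (C U) S^{-0.5} (C U)^T

    print(torch.allclose(Z @ torch.diag(G) @ Z, P, atol=1e-3))  # True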
@@ -801,14 +811,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)

-        C = P_prime.cholesky()  # P_prime = torch.matmul(C, C.transpose(2, 3))
+        # C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3))
+        # C is of shape (batch_size, num_blocks, block_size, block_size).
+        #def _fake_cholesky(X):
+        #    U_, S_, _ = _svd(X)
+        #    return U_ * S_.sqrt().unsqueeze(-2)
+        #C = _fake_cholesky(P_prime)
+        C = P_prime.cholesky()

-        # CGC = (C^T G' C) which would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
+        # CGC = (C^T G' C); it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
         # if G were formatted as a diagonal matrix, but G is just the diagonal.
-        CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2),
-                           C)
+        CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)

         U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{-0.5}
-        # next we compute (eqn:3).  The thing in the parenthesis is, GCC^{-0.5},
+        # next we compute (eqn:3).  The thing in the parenthesis, CGC^{-0.5},
         # can be written as U S^{-0.5} U^T, so the whole thing is
         # (C U) S^{-0.5} (C U)^T
         CU = torch.matmul(C, U)
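The commented-out `_fake_cholesky` kept above is a valid, if slower, substitute: for SPD X with SVD X = U S U^T, the factor F = U S^{0.5} satisfies F F^T = X just like a Cholesky factor, only without being triangular. An unbatched sketch (mine) of that identity:

    import torch

    X0 = torch.randn(5, 5)
    X = X0 @ X0.T + torch.eye(5)       # SPD
    U, S, _ = torch.linalg.svd(X)
    F = U * S.sqrt()                   # scale column j of U by sqrt(S_j)
    print(torch.allclose(F @ F.T, X, atol=1e-4))  # True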
@@ -817,12 +834,26 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                               CU.transpose(2, 3))

+        if True:
+            def _check_similar(x, y, name):
+                ratio = (y - x).abs().sum() / x.abs().sum()
+                if ratio > 0.0001:
+                    logging.warn(f"Check {name} failed, ratio={ratio.item()}")
+
+            def _check_symmetric(x, x_name):
+                diff = x - x.transpose(-2, -1)
+                ratio = diff.abs().sum() / x.abs().sum()
+                if ratio > 0.0001:
+                    logging.warn(f"{x_name} is not symmetric: ratio={ratio.item()}")
+
+            _check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
+            _check_symmetric(Z_prime, "Z_prime")
+            _check_symmetric(P_prime, "P_prime")
         # A check.
         # Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime   (eqn:2)
         P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
-        diff_ratio = (P_prime - P_prime_check).abs().sum() / P_prime.abs().sum()
-        if diff_ratio > 0.01:
-            logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}")
+        diff_ratio = ((P_prime - P_prime_check) ** 2).sum().sqrt() / (P_prime ** 2).sum().sqrt()
+        if diff_ratio > 0.001:
+            logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}, size={size}")
         Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
         # OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
         # We just need the basis that diagonalizes this.
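Both `CGC` earlier and `P_prime_check` here use the same shortcut: right-multiplying by a diagonal matrix is just a column scaling, so the diagonal never needs to be materialized. A small equivalence check (shapes illustrative, not the optimizer's):

    import torch

    A = torch.randn(3, 4, 4)
    d = torch.randn(3, 4)
    lhs = A @ torch.diag_embed(d)    # materializes diag(d)
    rhs = A * d.unsqueeze(-2)        # column scaling via broadcasting
    print(torch.allclose(lhs, rhs))  # True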
@@ -844,7 +875,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
             # this_P_proj shape: (batch_size, num_blocks, block_size)
             this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
-            P_proj[dim] = this_P_proj.reshape(batch_size, size)
+            P_proj[dim] = this_P_proj.clone().reshape(batch_size, size)
+            if True:
+                skip = 10 if P_proj[dim].shape[-1] > 40 else 1
+                logging.info(f"Eigs of P_proj are: {P_proj[dim][0,::skip]}")

         return P_proj
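A side note on the `this_P_proj` line: the diagonal of U M U^T can be formed without the second full matmul, since row i of (U @ M) dotted with row i of U is the i-th diagonal entry. A possible micro-optimization (my sketch, not in the commit):

    import torch

    U = torch.randn(4, 6, 6)
    M = torch.randn(4, 6, 6)
    full = torch.diagonal(U @ M @ U.transpose(-2, -1), dim1=-2, dim2=-1)
    fast = ((U @ M) * U).sum(dim=-1)
    print(torch.allclose(full, fast, atol=1e-4))  # True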
@@ -890,7 +925,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Now P is as normalized as we can make it... do smoothing based on 'rank',
         # that is intended to compensate for bad estimates of P.
         batch_size = p_shape[0]
-        size = P_prime.shape[0]  # size of dim we are concerned with right now
+        size = P_norm.shape[0]  # size of dim we are concerned with right now
         # `rank` is the rank of P_prime if we were to estimate it from just one
         # parameter tensor.  We average it over time, but actually it won't be changing
         # too much, so `rank` does tell us something.
@@ -908,7 +943,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # add rank-dependent smoothing amount to diagonal of P_prime.  _diag() returns an aliased tensor.
         # we don't need to multiply `smooth` by anything, because at this point, P_prime should have
         # diagonal elements close to 1.
-        _diag(P_prime).add_(smooth)
+        _diag(P_norm).add_(smooth)

         P_norm = self._smooth_cov(P_norm,
                                   group["min_lr_factor"][0],
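The in-place `add_` works because `_diag()` returns a view aliased with its argument (see the `_diag` hunk at the end of this commit), so bumping the view bumps P_norm's diagonal directly. The same idiom with PyTorch's built-in diagonal view:

    import torch

    P = torch.zeros(3, 3)
    P.diagonal().add_(0.25)   # same idiom as _diag(P_norm).add_(smooth)
    print(P)                  # 0.25 * identity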
@@ -927,7 +962,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         G_prime_rms = G_prime.sqrt()
         G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
         # P_gnorm is a version of P_prime that is scaled relative to G, i.e.
-        # scaled in such a way that would make G the unit matrix.
+        # scaled in a way that would make G the unit matrix.
         P_gnorm = P_prime / G_prime_scale
         # Apply another round of smoothing "relative to G"
         P_gnorm = self._smooth_cov(P_gnorm,
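Dividing by the outer product of the rms values is exactly the congruence P -> D^{-1/2} P D^{-1/2} with D = diag(G'), which maps G' itself to the unit matrix. A small equivalence check (shapes mine):

    import torch

    P0 = torch.randn(5, 5)
    P = P0 @ P0.T             # symmetric stand-in for P'
    G = torch.rand(5) + 0.1   # diagonal of G'
    rms = G.sqrt()
    P_gnorm = P / (rms.unsqueeze(-1) * rms.unsqueeze(-2))
    D_inv_sqrt = torch.diag(G.rsqrt())
    print(torch.allclose(P_gnorm, D_inv_sqrt @ P @ D_inv_sqrt, atol=1e-5))  # True
    # G itself maps to the unit matrix under the same rescaling:
    print(torch.allclose(torch.diag(G) / (rms.unsqueeze(-1) * rms.unsqueeze(-2)),
                         torch.eye(5), atol=1e-6))  # True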
@@ -982,6 +1017,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             S = S / _mean(S, exclude_dims=[0], keepdim=True)
             return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
         else:
+            X = X.clone()  # may be aliased
             diag = _diag(X)  # Aliased with X
             mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
             eps = 1.0e-10  # prevent division by zero
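`_mean` is a helper defined elsewhere in this file; judging from its call sites, it plausibly averages over every dimension except the excluded ones (an assumption on my part, not the file's actual definition):

    import torch

    def _mean(x: torch.Tensor, exclude_dims=(), keepdim=False):
        # Average over all dims not listed in exclude_dims.
        dims = [d for d in range(x.ndim) if d not in exclude_dims]
        return x.mean(dim=dims, keepdim=keepdim)

    S = torch.rand(2, 3, 4)
    print(_mean(S, exclude_dims=[0], keepdim=True).shape)  # torch.Size([2, 1, 1])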
@@ -1335,11 +1371,11 @@ def _diag(x: Tensor):
     elif x.ndim == 4:
         (B, C, M, M2) = x.shape
         assert M == M2
-        ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
+        ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3]))
     elif x.ndim == 2:
         (M, M2) = x.shape
         assert M == M2
-        ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],)).contiguous()
+        ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],))
     return ans
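What dropping `.contiguous()` changes: `as_strided` already returns a view into x's storage, but `.contiguous()` copied it, so in-place edits to the result never reached x. Without the copy, `_diag(x).add_(...)` mutates x's diagonal directly, which the smoothing hunks above now rely on. A 2-D demonstration (mine), mirroring the ndim == 2 branch:

    import torch

    def _diag_view(x: torch.Tensor) -> torch.Tensor:
        (M, M2) = x.shape
        assert M == M2
        stride = x.stride()
        return x.as_strided(size=(M,), stride=(stride[0] + stride[1],))

    x = torch.zeros(3, 3)
    _diag_view(x).add_(1.0)   # aliased view: this mutates x itself
    print(x)                  # identity matrix
    # With .contiguous() appended, add_ would modify a copy and x would stay zero.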