Mirror of https://github.com/k2-fsa/icefall.git
Commit cb67540cdc (parent 790e8c4ba9): "this version not working great"
@@ -655,55 +655,51 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# is the same as param_cov and grad_cov.
#
# G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
U_g, G_prime, _ = _svd(grad_cov)
# Recompute G_prime via matrix multiplication, as svd does not seem to produce zeros on
# the singular values in places where it should be exactly zero.  This is to keep
# dimensions with zero grad separate from those with nonzero grad.
G_prime = _diag(torch.matmul(U_g.transpose(2, 3), torch.matmul(grad_cov, U_g)))

G_prime_noeps = G_prime.clone()
G = grad_cov.clone()

# Use the form of the diagonalized gradient matrix that we get after
# we add the Adam-type smoothing with epsilon.
G_prime += (_mean(G_prime, exclude_dims=[0], keepdim=True) * (denom_rel_eps * denom_rel_eps) +
            (eps * eps))

# P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
# It is not smoothed yet.
P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
G_diag = _diag(G)  # aliased with G
G_diag += (_mean(G_diag, exclude_dims=[0], keepdim=True) * (denom_rel_eps * denom_rel_eps) +
           (eps * eps))
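
# [Illustrative sketch, not part of this commit]  The statements above apply the
# Adam-style epsilon smoothing to the diagonal of the gradient covariance.  The same
# update written with plain torch calls (torch.Tensor.diagonal / .mean stand in for
# this file's _diag / _mean helpers; the shapes and epsilon values are invented):
import torch

def smooth_grad_cov_diag(grad_cov: torch.Tensor,
                         denom_rel_eps: float = 1.0e-05,
                         eps: float = 1.0e-08) -> torch.Tensor:
    # grad_cov: (batch_size, num_blocks, block_size, block_size)
    G = grad_cov.clone()
    G_diag = G.diagonal(dim1=-2, dim2=-1)           # a view, so += below writes into G
    mean = G_diag.mean(dim=(1, 2), keepdim=True)    # mean over all dims except batch
    G_diag += mean * (denom_rel_eps * denom_rel_eps) + (eps * eps)
    return G

smoothed = smooth_grad_cov_diag(torch.randn(2, 3, 4, 4))
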
P_prime_unsmoothed = P_prime
P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)
P_unsmoothed = param_cov
P = param_cov.clone()
P = self._smooth_param_cov(group, p_shape, P, G)

# C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3))
if True:
    _, S, _ = _svd(P)
    logging.info(f"Eigs of P are {S}")

# C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
# C is of shape (batch_size, num_blocks, block_size, block_size).
#def _fake_cholesky(X):
#    U_, S_, _ = _svd(X)
#    return U_ * S_.sqrt().unsqueeze(-2)
#C = _fake_cholesky(P_prime)
C = P_prime.cholesky()
#C = _fake_cholesky(P)
C = P.cholesky()
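
# [Illustrative sketch, not part of this commit]  Both the commented-out SVD-based
# _fake_cholesky and .cholesky() give a factor C with C @ C^T == P for a symmetric
# positive-definite P; they simply pick different factors.  Standalone check with
# invented shapes (torch.linalg.svd stands in for this file's _svd helper):
import torch

def fake_cholesky(X: torch.Tensor) -> torch.Tensor:
    U_, S_, _ = torch.linalg.svd(X)
    return U_ * S_.sqrt().unsqueeze(-2)

P = torch.randn(2, 3, 4, 4)
P = torch.matmul(P, P.transpose(2, 3)) + 0.1 * torch.eye(4)   # symmetric positive definite

for C in (fake_cholesky(P), torch.linalg.cholesky(P)):
    assert torch.allclose(torch.matmul(C, C.transpose(2, 3)), P, atol=1.0e-03)
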

# OK, P == U C C^T U^T.
# A matrix that takes normally distributed data to P would
# be U_g C, because C I C^T = C C^T = P.  We can actually use *any* matrix
# be C, because C I C^T = C C^T = P.  We can actually use *any* matrix
# that takes normally distributed data to P, so we can use
# U_g C U for any orthogonal U, since U_g C U I U^T C^T U_g^T == P.
# C U for any orthogonal U, since C U I U^T C^T == P.
# So there is no harm in choosing a matrix U that diagonalizes the
# projected grad_cov.  grad_cov gets projected by

# this projects P; its transpose projects the gradient.
UC = torch.matmul(U_g, C)
grad_cov_proj = torch.matmul(C.transpose(2, 3),
                             torch.matmul(grad_cov, C))

# instead of projecting grad_cov, we can just use its diagonal, forget the
# U_g part of the transform, and project with C.
grad_cov_proj = torch.matmul(C.transpose(2, 3) * G_prime_noeps.unsqueeze(-1), C)

# OK, grad_cov is diagonalized by U^T C^T U_g^T.  So the projection that we
# apply to the param cov is U_g C U
# OK, grad_cov is diagonalized by U^T C^T.  So the projection that we
# apply to the param cov is C U
U, S, _ = _svd(grad_cov_proj)

# proj is indexed [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate],
# so we need to transpose to get Q_{dim}.
proj = torch.matmul(UC, U)
proj = torch.matmul(C, U)

Q[:] = proj.transpose(2, 3)
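
# [Illustrative sketch, not part of this commit]  Numerical check of the reasoning in
# the comments above: with C C^T == P and U the singular vectors of the projected
# gradient covariance C^T G C, the projection proj = C @ U still factors P
# (proj @ proj^T == P) while diagonalizing G.  Unbatched, with invented shapes:
import torch

block_size = 4
A = torch.randn(block_size, block_size)
P = A @ A.T + 0.1 * torch.eye(block_size)       # parameter covariance (positive definite)
B = torch.randn(block_size, block_size)
G = B @ B.T                                     # gradient covariance

C = torch.linalg.cholesky(P)
U, S, _ = torch.linalg.svd(C.T @ G @ C)         # C^T G C is symmetric PSD

proj = C @ U
assert torch.allclose(proj @ proj.T, P, atol=1.0e-03)             # still factors P
D = proj.T @ G @ proj
assert (D - torch.diag(D.diagonal())).abs().max() < 1.0e-03       # G is diagonalized
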

@@ -713,20 +709,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

def _smooth_param_cov(self,
                      group: dict,
                      p_shape: torch.Size,
                      P_prime: Tensor,
                      G_prime: Tensor) -> Tensor:
                      P: Tensor,
                      G: Tensor) -> Tensor:
    """
    This function returns a modified/smoothed version of the parameter covariance
    P_prime.
    P.

    Args:
       group: dict to look up config values
       p_shape: The shape of the parameter we are optimizing
       P_prime: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
          containing the parameter covariance in a basis that diagonalizes the
          gradient covariance.
       G_prime: the diagonalized gradient covariance, of shape (batch_size, num_blocks,
          block_size)
       P: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
          containing the parameter covariance
       G: the gradient covariance, of shape (batch_size, num_blocks,
          block_size, block_size)

    state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter
@@ -739,7 +734,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    """
    # do smoothing based on 'rank',
    # that is intended to compensate for bad estimates of P.
    (batch_size, num_blocks, block_size, block_size) = P_prime.shape
    (batch_size, num_blocks, block_size, block_size) = P.shape
    # `rank_per_block` is the rank of each block of P_prime if we were to estimate it from just one
    # parameter tensor.  We average it over time, but actually it won't be changing
    # too much, so `rank` does tell us something.
@@ -763,48 +758,60 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    # we don't need to multiply `smooth` by anything, because at this point, P_prime should have
    # diagonal elements close to 1.

    P_prime_diag = _diag(P_prime)
    P_prime_diag_mean = _mean(P_prime_diag, exclude_dims=[0], keepdim=True)
    P_prime_diag += smooth * P_prime_diag_mean
    P = P.clone()
    P_diag = _diag(P)
    P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
    P_diag += smooth * P_diag_mean

    G = G.clone()
    G_diag = _diag(G)  # aliased
    G_diag *= 1.01  # improve its condition, for numerical reasons.
    G = self._smooth_cov(G,
                         group["cov_min"][3],
                         group["cov_max"][3],
                         group["cov_pow"][3])

    if True:
        # This block smooths G_prime.
        # Make sure G_prime has unit mean and no eigenvalue is super small.  Note, G_prime
        # is already diagonalized, the variable G_prime is just the tensor of eigenvalues.
        G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
        G_prime_min = group["cov_min"][3]
        # make sure G_prime has no zero eigs, and is unit mean.
        G_prime = G_prime + G_prime_min * G_prime_mean + 1.0e-20
        G_prime /= _mean(G_prime, exclude_dims=[0], keepdim=True)
        # it now has unit mean..
        G_prime_max = group["cov_max"][3]
        G_prime = 1. / (1./G_prime + 1./G_prime_max)  # apply max
        G_prime_pow = group["cov_pow"][3]
        if G_prime_pow != 1.0:
            G_prime = G_prime ** G_prime_pow
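
# [Illustrative sketch, not part of this commit]  The block above applies a
# floor / soft-cap / power pattern to a vector of eigenvalues: add cov_min times the
# mean so nothing is exactly zero, renormalize to unit mean, cap softly at cov_max via
# a harmonic combination, then optionally raise to a power.  Standalone version on one
# made-up eigenvalue vector (the real code takes means over everything except the batch
# dimension, and the cov_min/cov_max/cov_pow values here are invented):
import torch

def smooth_eigs(eigs: torch.Tensor,
                cov_min: float = 0.025,
                cov_max: float = 10.0,
                cov_pow: float = 1.0) -> torch.Tensor:
    eigs = eigs + cov_min * eigs.mean() + 1.0e-20    # no exactly-zero eigenvalues
    eigs = eigs / eigs.mean()                        # unit mean
    eigs = 1.0 / (1.0 / eigs + 1.0 / cov_max)        # soft cap at cov_max
    if cov_pow != 1.0:
        eigs = eigs ** cov_pow
    return eigs

print(smooth_eigs(torch.tensor([0.0, 0.5, 2.0, 100.0])))
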

    # C C^T == G.
    C = G.cholesky()

    # treat the last dim of C as being in an arbitrary space, its next-to-last dim
    # is the "canonical" one that we need to sum with the dims of P.
    P_gnorm = torch.matmul(C.transpose(2, 3),
                           torch.matmul(P, C))

    G_prime_rms = G_prime.sqrt()
    G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
    # P_gnorm is a version of P_prime that is multiplied by G (actually
    # G^{0.5} P_prime G^{0.5}), so that it reflects the amount of
    # loss-function change in each dimension.  We also tried smoothing
    # a version of P_prime divided by G, but it seemed not to be helpful.
    P_gnorm = P_prime * G_prime_scale
    # Apply another round of smoothing "relative to G"
    P_gnorm = self._smooth_cov(P_gnorm,
                               group["cov_min"][1],
                               group["cov_max"][1],
                               group["cov_pow"][1])
    # Undo the scaling relative to G, so we have a stage-2-smoothed version of P_prime.
    P_prime = P_gnorm / G_prime_scale
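
# [Illustrative sketch, not part of this commit]  The elementwise scaling above relies
# on: for a diagonal G with eigenvalues g, P * (sqrt(g) outer sqrt(g)) equals the
# matrix product G^{0.5} P G^{0.5}, and dividing by the same factor undoes it.
# In 2-D, with invented values:
import torch

block_size = 4
P = torch.randn(block_size, block_size)
P = P @ P.T                                     # stands in for a block of P_prime
g = torch.rand(block_size) + 0.1                # positive eigenvalues of the diagonal G

scale = g.sqrt().unsqueeze(-1) * g.sqrt().unsqueeze(-2)
G_half = torch.diag(g.sqrt())
assert torch.allclose(P * scale, G_half @ P @ G_half, atol=1.0e-05)
assert torch.allclose((P * scale) / scale, P, atol=1.0e-05)
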

    #torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)
    #In symbols, it solves AX = b, ... X = A^{-1} b
    #and assumes A is square upper-triangular (or lower-triangular if upper=False) and does not have zeros on the diagonal.
    #
    # We want to invert the normalization by C to recover P (if there weren't smoothing), so we want
    # P = torch.matmul(C_inv.transpose(2, 3),
    #                  torch.matmul(P_gnorm, C_inv))

    # the following is equivalent to: P_temp = torch.matmul(C_inv.transpose(2, 3), P_gnorm),
    # where C_inv = C.inverse()
    P_temp = torch.triangular_solve(P_gnorm, C, upper=False, transpose=True)[0]
    # .. now we want to do the same on the other axis, need transpose.
    P = torch.triangular_solve(P_temp.transpose(2, 3), C,
                               upper=False, transpose=True)[0].transpose(2, 3)

    if True:  # TEMP
        C_inv = C.inverse()
        P2 = torch.matmul(C_inv.transpose(2, 3),
                          torch.matmul(P_gnorm, C_inv))
        assert (P - P2).norm() <= 0.001 * P.norm()
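
# [Illustrative sketch, not part of this commit]  The same two-sided un-normalization
# C^{-T} @ M @ C^{-1} done with triangular solves, written standalone in 2-D (the code
# above does the identical thing per (batch, block)):
import torch

n = 4
M = torch.randn(n, n)
M = M @ M.T                                          # symmetric matrix to un-normalize
L = torch.linalg.cholesky(M + n * torch.eye(n))      # stands in for the lower-triangular C
tmp = torch.triangular_solve(M, L, upper=False, transpose=True)[0]          # L^{-T} M
out = torch.triangular_solve(tmp.T, L, upper=False, transpose=True)[0].T    # L^{-T} M L^{-1}
ref = L.inverse().T @ M @ L.inverse()
assert (out - ref).norm() <= 1.0e-04 * ref.norm()
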

    # Apply a 3rd round of smoothing in the canonical basis.
    P_prime = self._smooth_cov(P_prime,
                               group["cov_min"][2],
                               group["cov_max"][2],
                               group["cov_pow"][2])
    return P_prime
    P = self._smooth_cov(P,
                         group["cov_min"][2],
                         group["cov_max"][2],
                         group["cov_pow"][2])
    return P

def _smooth_cov(self,
                X: Tensor,