Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Fix bug in smooth_cov, for power==1.0
parent cc388675a9
commit b47433b77a
@@ -144,6 +144,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size_lr_scale=0.1,
             min_lr_factor=(0.01, 0.01, 0.01),
             max_lr_factor=(10.0, 10.0, 10.0),
+            #param_pow=(0.99999, 0.99999, 0.99999),
+            param_pow=(1.0, 1.0, 1.0),
             param_rms_smooth0=0.75,
             param_rms_smooth1=0.25,
             eps=1.0e-08,
@@ -163,6 +165,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                        size_lr_scale=size_lr_scale,
                        min_lr_factor=min_lr_factor,
                        max_lr_factor=max_lr_factor,
+                       param_pow=param_pow,
                        param_rms_smooth0=param_rms_smooth0,
                        param_rms_smooth1=param_rms_smooth1,
                        betas=betas,
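
For context (this sketch is not part of the commit): the defaults above appear to be keyword arguments of the PrAdam constructor, since the second hunk forwards them and the test further down builds the optimizer as PrAdam(m.parameters(), lr=0.03, max_block_size=100). A hedged usage sketch with the new per-stage param_pow made explicit; the import path and the assumption that these keywords are accepted directly by PrAdam are taken from this diff, not from any documented API:

    import torch
    from optim import PrAdam  # assumed: PrAdam lives in this branch's optim.py

    model = torch.nn.Linear(100, 100)
    optimizer = PrAdam(
        model.parameters(),
        lr=0.03,
        max_block_size=100,                # as in _test_eve_cain() below
        min_lr_factor=(0.01, 0.01, 0.01),  # per-stage tuples: one value per _smooth_cov call
        max_lr_factor=(10.0, 10.0, 10.0),
        param_pow=(1.0, 1.0, 1.0),         # new default; exercises the fixed power == 1.0 path
        param_rms_smooth0=0.75,
        param_rms_smooth1=0.25,
    )
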
@@ -678,7 +681,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # cur_scales for the other dims.
         cur_scales = [None] * ndim

-        debug = (random.random() < 0.001)
+        debug = (random.random() < 0.1)
         for i in range(4):  # for 4 iterations (this is quite arbitrary)
             for dim in range(1, ndim):
                 size = p.shape[dim]
@@ -949,7 +952,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         P_norm = self._smooth_cov(P_norm,
                                   group["min_lr_factor"][0],
-                                  group["max_lr_factor"][0])
+                                  group["max_lr_factor"][0],
+                                  group["param_pow"][0])
         # Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
         # version of P_prime.
         P_prime = P_norm * P_prime_scale
@@ -969,14 +973,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Apply another round of smoothing "relative to G"
         P_gnorm = self._smooth_cov(P_gnorm,
                                    group["min_lr_factor"][1],
-                                   group["max_lr_factor"][1])
+                                   group["max_lr_factor"][1],
+                                   group["param_pow"][1])
         # Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
         P_prime = P_gnorm * G_prime_scale

         # Apply a 3rd round of smoothing
         P_prime = self._smooth_cov(P_prime,
                                    group["min_lr_factor"][2],
-                                   group["max_lr_factor"][2])
+                                   group["max_lr_factor"][2],
+                                   group["param_pow"][2])
         return P_prime

     def _smooth_cov(self,
@@ -1007,9 +1013,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           X: the batch of symmetric positive definite tensors we are smoothing;
              of shape (batch_size, num_blocks, block_size, block_size)
         """
+        eps = 1.0e-10
         if power != 1.0:
             U, S, _ = _svd(X)
-            eps = 1.0e-10
             S_mean = _mean(S, exclude_dims=[0], keepdim=True)
             S = S + min_eig * S_mean + eps
             S_mean = S_mean * (1 + min_eig) + eps
@@ -1019,17 +1025,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             S = S / _mean(S, exclude_dims=[0], keepdim=True)
             return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
         else:
-            X = X.clone()  # may be
+            X = X.clone()
             diag = _diag(X)  # Aliased with X
             mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
-            eps = 1.0e-10  # prevent division by zero
             diag += (mean_eig * min_eig + eps)
             cur_diag_mean = mean_eig * (1 + min_eig) + eps
             # The following 2 statements will be equivalent to:
             # L /= L.mean()
             # L = 1 / (1/L + 1/max_eig)  # soft-min between L and max_eig
             # if L is the eigenvalues
-            X = (X.inverse() + 1/(max_eig * cur_diag_mean.unsqueeze(-1))).inverse()
+            X_inv = X.inverse()
+            _diag(X_inv).add_(1. / (max_eig * cur_diag_mean))
+            X = X_inv.inverse()
             X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
             return X

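
The fix itself is the diagonal-only add in the power == 1.0 branch above: the soft-min on the eigenvalues described in the comment, L -> 1/(1/L + 1/max_eig), corresponds to adding a multiple of the identity to X^-1, whereas the old line broadcast 1/(max_eig * cur_diag_mean) over every element of X^-1. The following standalone sketch (plain PyTorch, not this file's _diag/_mean helpers, and omitting the min_eig floor and mean normalization that _smooth_cov also applies) illustrates the difference for a single matrix:

    import torch

    torch.manual_seed(0)
    max_eig = 4.0

    A = torch.randn(6, 6)
    X = A @ A.t() + 0.1 * torch.eye(6)     # symmetric positive definite

    # Fixed behavior: add the scalar only to the diagonal of X^-1, i.e. add (1/max_eig) * I.
    X_inv = X.inverse()
    X_inv.diagonal().add_(1.0 / max_eig)   # diagonal() is a view, so this is in-place
    X_fixed = X_inv.inverse()

    # Reference: apply the soft-min directly to the eigenvalues of X.
    L, U = torch.linalg.eigh(X)
    X_ref = (U * (1.0 / (1.0 / L + 1.0 / max_eig))) @ U.t()

    # Old behavior: the scalar was added to every element of X^-1, which is not the same operation.
    X_buggy = (X.inverse() + 1.0 / max_eig).inverse()

    print((X_fixed - X_ref).abs().max())   # round-off only: matches the eigenvalue soft-min
    print((X_buggy - X_ref).abs().max())   # substantially larger
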
@@ -2107,6 +2114,17 @@ def _test_eve_cain():
         elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

+        #TEMP
+        if iter == 3:
+            a = torch.randn(5, 10, 5, 5)
+            a = torch.matmul(a, a.transpose(2, 3))
+            b1 = optim._smooth_cov(a, 0.1, 4.0, 0.999999999)
+            b2 = optim._smooth_cov(a, 0.1, 4.0, 1.0)
+            diff = (b1 - b2)
+            ratio = (diff**2).sqrt().mean() / (b1**2).sqrt().mean()
+            logging.info(f"ratio = {ratio}")
+            assert ratio < 0.01
+
         start = timeit.default_timer()
         avg_loss = 0.0
         for epoch in range(150):
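
The temporary check added above acts as a regression test for the fix: it builds random symmetric positive definite blocks and asserts that the SVD-based branch (param_pow just below 1.0) and the closed-form power == 1.0 branch agree to within roughly 1% mean relative difference; with the old broadcast add, the power == 1.0 path computed a different smoothing and this assertion would not be expected to hold.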