Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Bug fix RE rank
This commit is contained in:
parent dee496145d
commit cc388675a9
@@ -142,8 +142,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr=3e-02,
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
-        min_lr_factor=(0.05, 0.05, 0.05),
-        max_lr_factor=(100.0, 100.0, 100.0),
+        min_lr_factor=(0.01, 0.01, 0.01),
+        max_lr_factor=(10.0, 10.0, 10.0),
         param_rms_smooth0=0.75,
         param_rms_smooth1=0.25,
         eps=1.0e-08,
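The functional change in this hunk is tightening the learning-rate-factor bounds from 0.05/100.0 to 0.01/10.0. As an illustration only (the commit does not show how the optimizer consumes these tuples, so treating them as clamp bounds on a per-parameter scale is an assumption), the new range would constrain a scale like this:

import torch

# Illustrative assumption: treat min/max_lr_factor as clamp bounds on a per-parameter scale.
min_lr_factor = 0.01   # new default (was 0.05)
max_lr_factor = 10.0   # new default (was 100.0)
lr_factor = torch.tensor([0.001, 0.5, 50.0])
bounded = lr_factor.clamp(min=min_lr_factor, max=max_lr_factor)
print(bounded)  # tensor([ 0.0100,  0.5000, 10.0000])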
@@ -924,21 +924,23 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         P_norm = P_prime / P_prime_scale
         # Now P is as normalized as we can make it... do smoothing based on 'rank',
         # that is intended to compensate for bad estimates of P.
-        batch_size = p_shape[0]
-        size = P_norm.shape[0]  # size of dim we are concerned with right now
-        # `rank` is the rank of P_prime if we were to estimate it from just one
+        (batch_size, num_blocks, block_size, block_size) = P_norm.shape
+        # `rank_per_block` is the rank of each block of P_prime if we were to estimate it from just one
         # parameter tensor.  We average it over time, but actually it won't be changing
         # too much, so `rank` does tell us something.
-        rank = p_shape.numel() // (size * batch_size)
+        size = num_blocks * block_size
+        rank = p_shape.numel() // (size * batch_size)  # actually the rank of each block
         smooth0 = group["param_rms_smooth0"]
         smooth1 = group["param_rms_smooth1"]
         # We want expr for smoothing amount to be of the form: smooth = alpha * size / (beta*rank + size)
+        # for "size" here, we actually want to use block_size, since we are concerned about the
+        # robustness of the covariance within these blocks.
         # param_rms_smooth{0,1} represents the user-specified desired amount of smoothing
         # when rank==0*size and rank==1*size, respectively.
         # from rank==0*size, we get smooth0 = alpha * size/size, so alpha = smooth0.
         # from setting rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
         # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta = smooth0/smooth1 - 1
-        smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
+        smooth = smooth0 * block_size / ((smooth0/smooth1 - 1) * rank + block_size)

         # add rank-dependent smoothing amount to diagonal of P_prime.  _diag() returns an aliased tensor.
         # we don't need to multiply `smooth` by anything, because at this point, P_prime should have
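The comment block above derives alpha = smooth0 and beta = smooth0/smooth1 - 1 from the two boundary conditions, and the bug fix is to evaluate the expression with block_size rather than the full size = num_blocks * block_size. A short standalone check of the corrected formula (the block_size value is illustrative; only smooth0 and smooth1 match the defaults from the first hunk):

smooth0, smooth1 = 0.75, 0.25        # param_rms_smooth{0,1} defaults from the first hunk
block_size = 64                      # illustrative block size
beta = smooth0 / smooth1 - 1.0       # chosen so that smooth1 == smooth0 / (1 + beta)

def smoothing(rank):
    # smooth = alpha * block_size / (beta * rank + block_size), with alpha = smooth0
    return smooth0 * block_size / (beta * rank + block_size)

print(smoothing(0))           # 0.75 == smooth0 (rank == 0 boundary)
print(smoothing(block_size))  # 0.25 == smooth1 (rank == block_size boundary)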
@@ -2084,7 +2086,8 @@ def _test_eve_cain():
     input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
     output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

-    for iter in [3, 2, 1, 0]:
+    # for iter in [3, 2, 1, 0]:  # will restore 1,0 later
+    for iter in [3, 2]:
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
         # TODO: find out why this is not converging...