mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Bug fix to reproduce past results with max_block_size unset.
This commit is contained in:
parent
075a2e27d8
commit
4f0e219523
@ -567,14 +567,14 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
|
||||
M = cur_p.transpose(dim, -1)
|
||||
# if p were of shape (batch_size, x, size, y, z),
|
||||
# after the next line M will be of shape
|
||||
# after the next line M would be of shape
|
||||
# (batch_size, x, y, z, num_blocks, block_size)
|
||||
M = M.reshape(batch_size, *M.shape[1:-1],
|
||||
num_blocks, block_size)
|
||||
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
|
||||
while U.ndim < M.ndim:
|
||||
U = U.unsqueeze(1)
|
||||
U = U.unsqueeze(2)
|
||||
# Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
|
||||
|
||||
M = torch.matmul(M, U) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
@ -598,9 +598,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
# the idea is to ensure it doesn't have any too-small eigenvalues
|
||||
# (where the stats permit).
|
||||
scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
|
||||
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}")
|
||||
|
||||
# scale shape: (batch_size, 1, size, 1, 1) if dim==2
|
||||
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
|
||||
|
||||
# scale shape: (batch_size, 1, size, 1, 1)
|
||||
cur_p *= scale
|
||||
|
||||
# OK, at this point we have a matrix cur_p that is (somewhat)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user