Bug fix to reproduce past results with max_block_size unset.

Daniel Povey 2022-07-11 17:03:32 -07:00
parent 075a2e27d8
commit 4f0e219523


@@ -567,14 +567,14 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 M = cur_p.transpose(dim, -1)
 # if p were of shape (batch_size, x, size, y, z),
-# after the next line M will be of shape
+# after the next line M would be of shape
 # (batch_size, x, y, z, num_blocks, block_size)
 M = M.reshape(batch_size, *M.shape[1:-1],
               num_blocks, block_size)
 M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
 while U.ndim < M.ndim:
-    U = U.unsqueeze(1)
+    U = U.unsqueeze(2)
 # Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
 M = torch.matmul(M, U) # (batch_size, num_blocks, x, y, z, block_size)
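
For reference, a minimal standalone sketch (all shape values invented here, not taken from the repo) of why the unsqueeze axis matters: unsqueeze(2) puts the singleton dims after num_blocks, so each block's U matrix multiplies only that block's slice of M, matching the shape comment above; unsqueeze(1) puts them before num_blocks, so the block dim drifts rightward and lines up with the wrong dim of M.

import torch

batch_size, num_blocks, block_size = 2, 3, 4
x, y, z = 5, 6, 7

# M as it stands after _move_dim: (batch_size, num_blocks, x, y, z, block_size)
M = torch.randn(batch_size, num_blocks, x, y, z, block_size)
# One block_size x block_size matrix per (batch element, block).
U = torch.randn(batch_size, num_blocks, block_size, block_size)

# Fixed version: unsqueeze(2) inserts singleton dims after num_blocks,
# giving (batch_size, num_blocks, 1, 1, block_size, block_size); matmul
# then broadcasts each block's U against that same block's slice of M.
U_fixed = U
while U_fixed.ndim < M.ndim:
    U_fixed = U_fixed.unsqueeze(2)
assert U_fixed.shape == (batch_size, num_blocks, 1, 1, block_size, block_size)
out = torch.matmul(M, U_fixed)  # (batch_size, num_blocks, x, y, z, block_size)
assert out.shape == M.shape

# Buggy version: unsqueeze(1) inserts singletons before num_blocks, giving
# (batch_size, 1, 1, num_blocks, block_size, block_size); num_blocks now
# aligns with M's z dim, so matmul errors (or silently mis-broadcasts if
# z happened to equal num_blocks).
U_buggy = U
while U_buggy.ndim < M.ndim:
    U_buggy = U_buggy.unsqueeze(1)
assert U_buggy.shape == (batch_size, 1, 1, num_blocks, block_size, block_size)
# torch.matmul(M, U_buggy)  # RuntimeError: z (7) vs num_blocks (3) mismatch

Note that when num_blocks == 1 the two placements produce identical shapes, which may be why the mismatch with the shape comment went unnoticed.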
@@ -598,9 +598,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # the idea is to ensure it doesn't have any too-small eigenvalues
 # (where the stats permit).
 scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
-logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}")
-# scale shape: (batch_size, 1, size, 1, 1) if dim==2
+logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
+# scale shape: (batch_size, 1, size, 1, 1)
 cur_p *= scale
 # OK, at this point we have a matrix cur_p that is (somewhat)
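
To make the flooring behaviour of the scale line concrete, here is a toy example with invented variance values (eps, and reading S as the target variance and cur_param_var as the current one, are assumptions based on the surrounding comments): directions whose current variance falls short of the target get scaled up, while directions where both statistics are below eps keep scale 1.0, i.e. the stats don't permit a correction.

import torch

eps = 1.0e-20  # assumed floor; the real value comes from the optimizer config
S = torch.tensor([1.0, 0.5, 1.0e-30])              # target variance per direction
cur_param_var = torch.tensor([1.0, 0.1, 1.0e-30])  # current variance per direction

scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
print(scale)  # tensor([1.0000, 2.2361, 1.0000])

The second direction is boosted by sqrt(0.5 / 0.1) because its variance is too small, while the third is left alone: both of its statistics clamp to eps, so their ratio is 1.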