Cosmetic fixes

This commit is contained in:
Daniel Povey 2022-07-24 04:45:57 +08:00
parent 966ac36cde
commit 8a9bbb93bc

View File

@ -1181,13 +1181,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
g = grad.transpose(-1, dim)
# if grad were of shape (batch_size, x, size, y, z),
# and g of shape (batch_size, x, y, z, size),
# after the next line g will be of shape
# (batch_size, x, y, z, num_blocks, block_size)
g = g.reshape(batch_size, *g.shape[1:-1],
num_blocks, block_size)
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
# g would be of shape (batch_size, x, z, y, size); and
# after the next line, g will be of shape
# (batch_size, x, z, y, num_blocks, block_size)
g = g.reshape(*g.shape[:-1], num_blocks, block_size)
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, z, y, block_size)
g = g.reshape(batch_size, num_blocks, -1, block_size)
# now g is of shape (batch_size, num_blocks, x*z*y, block_size)
# this_grad_cov: (batch_size, num_blocks, block_size, block_size)
this_grad_cov = torch.matmul(g.transpose(-2, -1), g)