Fixes to comments

Daniel Povey 2022-07-24 04:36:41 +08:00
parent 33ffd17515
commit 966ac36cde

@@ -638,20 +638,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     M = cur_p.transpose(dim, -1)
     # if p were of shape (batch_size, x, size, y, z),
     # after the next line M would be of shape
-    # (batch_size, x, y, z, num_blocks, block_size)
-    M = M.reshape(batch_size, *M.shape[1:-1],
-                  num_blocks, block_size)
-    M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
+    # (batch_size, x, z, y, num_blocks, block_size)
+    M = M.reshape(*M.shape[:-1], num_blocks, block_size)
+    M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, z, y, block_size)
     while Q.ndim < M.ndim:
         Q = Q.unsqueeze(2)
     # Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
-    # [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
+    # with indexes [batch_index, block_index, 1, 1, diagonalized_coordinate, canonical_coordinate],
     # so we need to transpose Q as we convert M to the diagonalized co-ordinate.
     #M = torch.matmul(M, Q.transpose(2, 3))  # (batch_size, num_blocks, x, y, z, block_size)
-    M = torch.matmul(M, Q.transpose(-2, -1))  # (batch_size, num_blocks, x, y, z, block_size)
-    M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
-    M = M.reshape(*M.shape[:-2], size)  # # (batch_size, x, y, z, size)
+    M = torch.matmul(M, Q.transpose(-2, -1))  # (batch_size, num_blocks, x, z, y, block_size)
+    M = _move_dim(M, 1, -2)  # (batch_size, x, z, y, num_blocks, block_size)
+    M = M.reshape(*M.shape[:-2], size)  # # (batch_size, x, z, y, size)
     cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
     # cur_param_var is a diagonal parameter variance over dimension `dim`,
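
For readers following the shape comments above: the hunk applies a per-block change of basis, reshaping the `size` axis into (num_blocks, block_size), rotating each block by an orthogonal matrix Q, and reshaping back. The listing below is a minimal, self-contained sketch of that pattern under stated assumptions, not the repository's code: the example shapes are invented, Q here is just a random orthogonal basis, and `_move_dim` is assumed to behave like `torch.movedim` (the real helper in optim.py may differ).

import torch

def _move_dim(t: torch.Tensor, src: int, dst: int) -> torch.Tensor:
    # Assumed semantics: move dimension `src` of `t` to position `dst`,
    # keeping the relative order of the other dimensions (like torch.movedim).
    return torch.movedim(t, src, dst)

batch_size, x, size, y, z = 2, 3, 8, 4, 5   # invented example shapes
num_blocks, block_size = 2, 4               # size == num_blocks * block_size
dim = 2                                     # the axis of p that Q acts on (the `size` axis)

p = torch.randn(batch_size, x, size, y, z)
# One random orthogonal basis per block: (batch_size, num_blocks, block_size, block_size).
Q, _ = torch.linalg.qr(torch.randn(batch_size, num_blocks, block_size, block_size))

M = p.transpose(dim, -1)                              # (batch_size, x, z, y, size)
M = M.reshape(*M.shape[:-1], num_blocks, block_size)  # (batch_size, x, z, y, num_blocks, block_size)
M = _move_dim(M, -2, 1)                               # (batch_size, num_blocks, x, z, y, block_size)
while Q.ndim < M.ndim:
    Q = Q.unsqueeze(2)                                # (batch_size, num_blocks, 1, 1, block_size, block_size)
M = torch.matmul(M, Q.transpose(-2, -1))              # rotate each length-block_size slice into Q's basis
M = _move_dim(M, 1, -2)                               # (batch_size, x, z, y, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size)                    # (batch_size, x, z, y, size)
cur_p = M.transpose(dim, -1)                          # back to (batch_size, x, size, y, z)
assert cur_p.shape == p.shape

With these example shapes the intermediate sizes match the corrected shape comments in the commit, and the final assert passes.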