mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fixes to comments
This commit is contained in:
parent
33ffd17515
commit
966ac36cde
@ -638,20 +638,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
M = cur_p.transpose(dim, -1)
|
M = cur_p.transpose(dim, -1)
|
||||||
# if p were of shape (batch_size, x, size, y, z),
|
# if p were of shape (batch_size, x, size, y, z),
|
||||||
# after the next line M would be of shape
|
# after the next line M would be of shape
|
||||||
# (batch_size, x, y, z, num_blocks, block_size)
|
# (batch_size, x, z, y, num_blocks, block_size)
|
||||||
M = M.reshape(batch_size, *M.shape[1:-1],
|
M = M.reshape(*M.shape[:-1], num_blocks, block_size)
|
||||||
num_blocks, block_size)
|
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, z, y, block_size)
|
||||||
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
|
|
||||||
|
|
||||||
while Q.ndim < M.ndim:
|
while Q.ndim < M.ndim:
|
||||||
Q = Q.unsqueeze(2)
|
Q = Q.unsqueeze(2)
|
||||||
# Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
|
# Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
|
||||||
# [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
|
# with indexes [batch_index, block_index, 1, 1, diagonalized_coordinate, canonical_coordinate],
|
||||||
# so we need to transpose Q as we convert M to the diagonalized co-ordinate.
|
# so we need to transpose Q as we convert M to the diagonalized co-ordinate.
|
||||||
#M = torch.matmul(M, Q.transpose(2, 3)) # (batch_size, num_blocks, x, y, z, block_size)
|
M = torch.matmul(M, Q.transpose(-2, -1)) # (batch_size, num_blocks, x, z, y, block_size)
|
||||||
M = torch.matmul(M, Q.transpose(-2, -1)) # (batch_size, num_blocks, x, y, z, block_size)
|
M = _move_dim(M, 1, -2) # (batch_size, x, z, y, num_blocks, block_size)
|
||||||
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
|
M = M.reshape(*M.shape[:-2], size) # # (batch_size, x, z, y, size)
|
||||||
M = M.reshape(*M.shape[:-2], size) # # (batch_size, x, y, z, size)
|
|
||||||
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
|
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
|
||||||
|
|
||||||
# cur_param_var is a diagonal parameter variance over dimension `dim`,
|
# cur_param_var is a diagonal parameter variance over dimension `dim`,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user