mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Cosmetic fixes
This commit is contained in:
parent
966ac36cde
commit
8a9bbb93bc
@ -1181,13 +1181,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
|
||||
g = grad.transpose(-1, dim)
|
||||
# if grad were of shape (batch_size, x, size, y, z),
|
||||
# and g of shape (batch_size, x, y, z, size),
|
||||
# after the next line g will be of shape
|
||||
# (batch_size, x, y, z, num_blocks, block_size)
|
||||
g = g.reshape(batch_size, *g.shape[1:-1],
|
||||
num_blocks, block_size)
|
||||
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
# g would be of shape (batch_size, x, z, y, size); and
|
||||
# after the next line, g will be of shape
|
||||
# (batch_size, x, z, y, num_blocks, block_size)
|
||||
g = g.reshape(*g.shape[:-1], num_blocks, block_size)
|
||||
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, z, y, block_size)
|
||||
g = g.reshape(batch_size, num_blocks, -1, block_size)
|
||||
# now g is of shape (batch_size, num_blocks, x*z*y, block_size)
|
||||
|
||||
# this_grad_cov: (batch_size, num_blocks, block_size, block_size)
|
||||
this_grad_cov = torch.matmul(g.transpose(-2, -1), g)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user