diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index f0bfbc354..c2e8379be 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -1181,13 +1181,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of g = grad.transpose(-1, dim) # if grad were of shape (batch_size, x, size, y, z), - # and g of shape (batch_size, x, y, z, size), - # after the next line g will be of shape - # (batch_size, x, y, z, num_blocks, block_size) - g = g.reshape(batch_size, *g.shape[1:-1], - num_blocks, block_size) - g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, y, z, block_size) + # g would be of shape (batch_size, x, z, y, size); and + # after the next line, g will be of shape + # (batch_size, x, z, y, num_blocks, block_size) + g = g.reshape(*g.shape[:-1], num_blocks, block_size) + g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, z, y, block_size) g = g.reshape(batch_size, num_blocks, -1, block_size) + # now g is of shape (batch_size, num_blocks, x*z*y, block_size) # this_grad_cov: (batch_size, num_blocks, block_size, block_size) this_grad_cov = torch.matmul(g.transpose(-2, -1), g)