diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 65cfeedc3..f0bfbc354 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -638,20 +638,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of M = cur_p.transpose(dim, -1) # if p were of shape (batch_size, x, size, y, z), # after the next line M would be of shape - # (batch_size, x, y, z, num_blocks, block_size) - M = M.reshape(batch_size, *M.shape[1:-1], - num_blocks, block_size) - M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size) + # (batch_size, x, z, y, num_blocks, block_size) + M = M.reshape(*M.shape[:-1], num_blocks, block_size) + M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, z, y, block_size) while Q.ndim < M.ndim: Q = Q.unsqueeze(2) # Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size) - # [batch_index, block_index, diagonalized_coordinate, canonical_coordinate], + # with indexes [batch_index, block_index, 1, 1, diagonalized_coordinate, canonical_coordinate], # so we need to transpose Q as we convert M to the diagonalized co-ordinate. - #M = torch.matmul(M, Q.transpose(2, 3)) # (batch_size, num_blocks, x, y, z, block_size) - M = torch.matmul(M, Q.transpose(-2, -1)) # (batch_size, num_blocks, x, y, z, block_size) - M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size) - M = M.reshape(*M.shape[:-2], size) # # (batch_size, x, y, z, size) + M = torch.matmul(M, Q.transpose(-2, -1)) # (batch_size, num_blocks, x, z, y, block_size) + M = _move_dim(M, 1, -2) # (batch_size, x, z, y, num_blocks, block_size) + M = M.reshape(*M.shape[:-2], size) # (batch_size, x, z, y, size) cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z) # cur_param_var is a diagonal parameter variance over dimension `dim`,