From 8a9bbb93bc077a4ce0b63f92af4e9b4aac7ff946 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 Jul 2022 04:45:57 +0800 Subject: [PATCH] Cosmetic fixes --- .../ASR/pruned_transducer_stateless7/optim.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index f0bfbc354..c2e8379be 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -1181,13 +1181,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of g = grad.transpose(-1, dim) # if grad were of shape (batch_size, x, size, y, z), - # and g of shape (batch_size, x, y, z, size), - # after the next line g will be of shape - # (batch_size, x, y, z, num_blocks, block_size) - g = g.reshape(batch_size, *g.shape[1:-1], - num_blocks, block_size) - g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, y, z, block_size) + # g would be of shape (batch_size, x, z, y, size); and + # after the next line, g will be of shape + # (batch_size, x, z, y, num_blocks, block_size) + g = g.reshape(*g.shape[:-1], num_blocks, block_size) + g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, z, y, block_size) g = g.reshape(batch_size, num_blocks, -1, block_size) + # now g is of shape (batch_size, num_blocks, x*z*y, block_size) # this_grad_cov: (batch_size, num_blocks, block_size, block_size) this_grad_cov = torch.matmul(g.transpose(-2, -1), g)