This commit is contained in:
Daniel Povey 2022-08-02 14:31:37 +08:00
parent e9f4ada1c0
commit e44ab25e99

View File

@ -677,7 +677,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# value) to be 1, and since the singular values (==eigenvalues) of P
# are the squares of the singular values of C, we want them to have
# mean of 1.
P /= _mean(_diag(P), exclude_dims=[0]).unsqueeze(-1)
P /= _mean(_diag(P), exclude_dims=[0], keepdim=True).unsqueeze(-1)
# C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
# C is of shape (batch_size, num_blocks, block_size, block_size).