diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 25e3fdccb..7fbbe55c6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -668,11 +668,11 @@ class RandProjDrop(torch.nn.Module): else: self._randomize_U() x = x.transpose(self.channel_dim, -1) # (..., num_channels) - - x = torch.matmul(x, self.U) + U = self.U.clone() + x = torch.matmul(x, U) x = torch.nn.functional.dropout(x, self.dropout_rate, training=True) - x = torch.matmul(x, self.U.t()) + x = torch.matmul(x, U.t()) x = x.transpose(self.channel_dim, -1) # (..., num_channels) return x