From 56d6dd55ae7f55ce7296ef180f57a8a66f3ac7a3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Jun 2022 12:06:35 +0800 Subject: [PATCH] Bug fixes --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 6662c4edd..5eb301c89 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -809,14 +809,14 @@ class Decorrelate(torch.nn.Module): self.register_buffer('cov', torch.zeros(num_channels, num_channels)) # step_buf is a copy of step, included so it will be loaded/saved with # the model. - self.register_buffer('step_buf', torch.tensor(0)) + self.register_buffer('step_buf', torch.tensor(0.0)) self.step = 0 def load_state_dict(self, *args, **kwargs): super(Decorrelate, self).load_state_dict(*args, **kwargs) - self.step = self.step_buf.item() + self.step = int(self.step_buf.item()) def forward(self, x: Tensor) -> Tensor: @@ -825,7 +825,7 @@ class Decorrelate(torch.nn.Module): else: apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay) self.step += 1 - self.step_buf.fill_(self.step) + self.step_buf.fill_(float(self.step)) if random.random() > apply_prob: return x with torch.cuda.amp.autocast(enabled=False):