Bug fixes

This commit is contained in:
Daniel Povey 2022-06-09 12:06:35 +08:00
parent 1669e21c0c
commit 56d6dd55ae

View File

@ -809,14 +809,14 @@ class Decorrelate(torch.nn.Module):
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
# step_buf is a copy of step, included so it will be loaded/saved with
# the model.
self.register_buffer('step_buf', torch.tensor(0))
self.register_buffer('step_buf', torch.tensor(0.0))
self.step = 0
def load_state_dict(self, *args, **kwargs):
super(Decorrelate, self).load_state_dict(*args, **kwargs)
self.step = self.step_buf.item()
self.step = int(self.step_buf.item())
def forward(self, x: Tensor) -> Tensor:
@ -825,7 +825,7 @@ class Decorrelate(torch.nn.Module):
else:
apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
self.step += 1
self.step_buf.fill_(self.step)
self.step_buf.fill_(float(self.step))
if random.random() > apply_prob:
return x
with torch.cuda.amp.autocast(enabled=False):