Bug fixes

This commit is contained in:
Daniel Povey 2022-06-09 12:06:35 +08:00
parent 1669e21c0c
commit 56d6dd55ae

View File

@ -809,14 +809,14 @@ class Decorrelate(torch.nn.Module):
self.register_buffer('cov', torch.zeros(num_channels, num_channels)) self.register_buffer('cov', torch.zeros(num_channels, num_channels))
# step_buf is a copy of step, included so it will be loaded/saved with # step_buf is a copy of step, included so it will be loaded/saved with
# the model. # the model.
self.register_buffer('step_buf', torch.tensor(0)) self.register_buffer('step_buf', torch.tensor(0.0))
self.step = 0 self.step = 0
def load_state_dict(self, *args, **kwargs): def load_state_dict(self, *args, **kwargs):
super(Decorrelate, self).load_state_dict(*args, **kwargs) super(Decorrelate, self).load_state_dict(*args, **kwargs)
self.step = self.step_buf.item() self.step = int(self.step_buf.item())
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
@ -825,7 +825,7 @@ class Decorrelate(torch.nn.Module):
else: else:
apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay) apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
self.step += 1 self.step += 1
self.step_buf.fill_(self.step) self.step_buf.fill_(float(self.step))
if random.random() > apply_prob: if random.random() > apply_prob:
return x return x
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):