mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Bug fixes
This commit is contained in:
parent
1669e21c0c
commit
56d6dd55ae
@ -809,14 +809,14 @@ class Decorrelate(torch.nn.Module):
|
||||
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
|
||||
# step_buf is a copy of step, included so it will be loaded/saved with
|
||||
# the model.
|
||||
self.register_buffer('step_buf', torch.tensor(0))
|
||||
self.register_buffer('step_buf', torch.tensor(0.0))
|
||||
self.step = 0
|
||||
|
||||
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
super(Decorrelate, self).load_state_dict(*args, **kwargs)
|
||||
self.step = self.step_buf.item()
|
||||
self.step = int(self.step_buf.item())
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
@ -825,7 +825,7 @@ class Decorrelate(torch.nn.Module):
|
||||
else:
|
||||
apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
|
||||
self.step += 1
|
||||
self.step_buf.fill_(self.step)
|
||||
self.step_buf.fill_(float(self.step))
|
||||
if random.random() > apply_prob:
|
||||
return x
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user