mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00
Bug fixes
This commit is contained in:
parent
1669e21c0c
commit
56d6dd55ae
@ -809,14 +809,14 @@ class Decorrelate(torch.nn.Module):
|
|||||||
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
|
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
|
||||||
# step_buf is a copy of step, included so it will be loaded/saved with
|
# step_buf is a copy of step, included so it will be loaded/saved with
|
||||||
# the model.
|
# the model.
|
||||||
self.register_buffer('step_buf', torch.tensor(0))
|
self.register_buffer('step_buf', torch.tensor(0.0))
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(self, *args, **kwargs):
|
def load_state_dict(self, *args, **kwargs):
|
||||||
super(Decorrelate, self).load_state_dict(*args, **kwargs)
|
super(Decorrelate, self).load_state_dict(*args, **kwargs)
|
||||||
self.step = self.step_buf.item()
|
self.step = int(self.step_buf.item())
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
@ -825,7 +825,7 @@ class Decorrelate(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
|
apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.step_buf.fill_(self.step)
|
self.step_buf.fill_(float(self.step))
|
||||||
if random.random() > apply_prob:
|
if random.random() > apply_prob:
|
||||||
return x
|
return x
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user