mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Try to resolve graph-freed problem
This commit is contained in:
parent
46ca1cd4c4
commit
8e56445c70
@ -738,6 +738,7 @@ class DecorrelateFunction(torch.autograd.Function):
|
||||
full_shape = x.shape
|
||||
x = x.reshape(-1, num_channels)
|
||||
x = x.detach()
|
||||
old_cov = old_cov.detach()
|
||||
x.requires_grad = True
|
||||
x_grad = x_grad.reshape(-1, num_channels)
|
||||
|
||||
@ -804,8 +805,9 @@ class Decorrelate(torch.nn.Module):
|
||||
|
||||
x = x.transpose(self.channel_dim, -1)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
cov = torch.matmul(x.t(), x) / x.shape[0]
|
||||
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
|
||||
cov = torch.matmul(x.t(), x)
|
||||
with torch.no_grad():
|
||||
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
|
||||
self.step += 1
|
||||
|
||||
return ans # ans == x.
|
||||
|
Loading…
x
Reference in New Issue
Block a user