Try to resolve graph-freed problem

This commit is contained in:
Daniel Povey 2022-06-08 20:07:35 +08:00
parent 46ca1cd4c4
commit 8e56445c70

View File

@ -738,6 +738,7 @@ class DecorrelateFunction(torch.autograd.Function):
full_shape = x.shape
x = x.reshape(-1, num_channels)
x = x.detach()
old_cov = old_cov.detach()
x.requires_grad = True
x_grad = x_grad.reshape(-1, num_channels)
@ -804,8 +805,9 @@ class Decorrelate(torch.nn.Module):
x = x.transpose(self.channel_dim, -1)
x = x.reshape(-1, x.shape[-1])
cov = torch.matmul(x.t(), x) / x.shape[0]
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
cov = torch.matmul(x.t(), x)
with torch.no_grad():
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
self.step += 1
return ans # ans == x.