mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Use torch.no_grad() for stats
This commit is contained in:
parent
07d3369234
commit
9133b57808
@ -221,6 +221,7 @@ class OrthogonalTransformation(nn.Module):
|
|||||||
"""
|
"""
|
||||||
x = torch.matmul(x, self.weight.t())
|
x = torch.matmul(x, self.weight.t())
|
||||||
if self.step % 10 == 0 and self.train():
|
if self.step % 10 == 0 and self.train():
|
||||||
|
with torch.no_grad():
|
||||||
# store covariance after input transform.
|
# store covariance after input transform.
|
||||||
# Update covariance stats every 10 batches (in training mode)
|
# Update covariance stats every 10 batches (in training mode)
|
||||||
f = x.reshape(-1, x.shape[-1])
|
f = x.reshape(-1, x.shape[-1])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user