mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Revert "Remove memory-cutoff from ActivationBalancer."
This reverts commit 5afe0e78556e2e76750cae64008c9dd5c1931c5c.
This commit is contained in:
parent
829e4bd4db
commit
11f68afa1f
@ -871,6 +871,9 @@ class ActivationBalancer(torch.nn.Module):
|
||||
if prob is None:
|
||||
prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
|
||||
self.prob = prob
|
||||
# 5% of the time we will return and do nothing because memory usage is
|
||||
# too high.
|
||||
self.mem_cutoff = CutoffEstimator(0.05)
|
||||
|
||||
# actually self.num_channels is no longer needed except for an assertion.
|
||||
self.num_channels = num_channels
|
||||
@ -885,7 +888,8 @@ class ActivationBalancer(torch.nn.Module):
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not x.requires_grad:
|
||||
if (torch.jit.is_scripting() or not x.requires_grad or
|
||||
(x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
|
||||
return _no_op(x)
|
||||
|
||||
prob = float(self.prob)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user