From 11f68afa1fea01e6411f6a7b05d82b259710146a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 21 Dec 2022 18:39:16 +0800 Subject: [PATCH] Revert "Remove memory-cutoff from ActivationBalancer." This reverts commit 5afe0e78556e2e76750cae64008c9dd5c1931c5c. --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 20d06329d..04a2822ee 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -871,6 +871,9 @@ class ActivationBalancer(torch.nn.Module): if prob is None: prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) # actually self.num_channels is no longer needed except for an assertion. self.num_channels = num_channels @@ -885,7 +888,8 @@ class ActivationBalancer(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad: + if (torch.jit.is_scripting() or not x.requires_grad or + (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): return _no_op(x) prob = float(self.prob)