From c097c13720855f0882ae7bc27c70904a04c137e8 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 21 Dec 2022 11:24:47 +0800
Subject: [PATCH] Change memory cutoff for ActivationBalancer; remove it for
 Whiten

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index b450999ad..161a5aa4a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -871,9 +871,9 @@ class ActivationBalancer(torch.nn.Module):
         if prob is None:
             prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
         self.prob = prob
-        # 20% of the time we will return and do nothing because memory usage is
+        # 5% of the time we will return and do nothing because memory usage is
         # too high.
-        self.mem_cutoff = CutoffEstimator(0.2)
+        self.mem_cutoff = CutoffEstimator(0.05)
 
         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -1070,10 +1070,6 @@ class Whiten(nn.Module):
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale
 
-        # 20% of the time we will return and do nothing because memory usage
-        # is too high.
-        self.mem_cutoff = CutoffEstimator(0.2)
-
         if isinstance(prob, float):
             assert 0 < prob <= 1
         self.prob = prob
@@ -1102,8 +1098,7 @@ class Whiten(nn.Module):
         and nothing will happen in backprop.
         """
         grad_scale = float(self.grad_scale)
-        if (not x.requires_grad or random.random() > self.prob or grad_scale == 0
-                or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
+        if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
             return _no_op(x)
         else:
             whitening_limit = float(self.whitening_limit)
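
Note on the mechanism (not part of the patch): both modules gate their extra work on
self.mem_cutoff, which the removed Whiten lines show being called as
`x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())`. With
CutoffEstimator(0.05), ActivationBalancer now skips its work only on roughly the 5% of
calls where allocated GPU memory is highest (down from 20%), while Whiten drops the
check entirely. The sketch below only illustrates how such a quantile-style cutoff
could behave; the class body and the helper should_skip are assumptions for
illustration, not the actual CutoffEstimator defined in scaling.py.

    from collections import deque

    import torch


    class CutoffEstimator:
        """Illustrative sketch (not the icefall implementation): returns True
        when the observed value is in the top `quantile` of recently seen
        values."""

        def __init__(self, quantile: float, history: int = 200):
            self.quantile = quantile              # e.g. 0.05 -> trigger ~5% of the time
            self.recent = deque(maxlen=history)   # recent observations

        def __call__(self, x: float) -> bool:
            self.recent.append(x)
            ordered = sorted(self.recent)
            # Values strictly above this index are in the top `quantile`.
            idx = int((1.0 - self.quantile) * (len(ordered) - 1))
            return x > ordered[idx]


    # Hypothetical usage mirroring the lines removed from Whiten.forward():
    # skip the expensive path when allocated GPU memory is currently high.
    mem_cutoff = CutoffEstimator(0.05)


    def should_skip(x: torch.Tensor) -> bool:
        return x.is_cuda and mem_cutoff(torch.cuda.memory_allocated())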