From 5c0957d9501cd6d46f58a15b02364043304cef5e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 9 Dec 2022 18:11:27 +0800 Subject: [PATCH] Fix memory issue in ActivationBalancer --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 36aa5f660..e1d220f5f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -690,7 +690,7 @@ class ActivationBalancer(torch.nn.Module): assert x.shape[self.channel_dim] == self.num_channels sign_gain_factor = 0.5 if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, + sign_factor = _compute_sign_factor(x.detach(), self.channel_dim, float(self.min_positive), float(self.max_positive), gain_factor=float(self.sign_gain_factor) / prob, @@ -699,7 +699,7 @@ class ActivationBalancer(torch.nn.Module): sign_factor = None - scale_factor, mean = _compute_scale_factor(x, self.channel_dim, + scale_factor, mean = _compute_scale_factor(x.detach(), self.channel_dim, min_abs=float(self.min_abs), max_abs=float(self.max_abs), gain_factor=float(self.scale_gain_factor) / prob,