Fix memory issue in ActivationBalancer

Daniel Povey 2022-12-09 18:11:27 +08:00
parent 2ef0228db0
commit 5c0957d950


@@ -690,7 +690,7 @@ class ActivationBalancer(torch.nn.Module):
assert x.shape[self.channel_dim] == self.num_channels
sign_gain_factor = 0.5
if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
-            sign_factor = _compute_sign_factor(x, self.channel_dim,
+            sign_factor = _compute_sign_factor(x.detach(), self.channel_dim,
float(self.min_positive),
float(self.max_positive),
gain_factor=float(self.sign_gain_factor) / prob,
@@ -699,7 +699,7 @@ class ActivationBalancer(torch.nn.Module):
sign_factor = None
-        scale_factor, mean = _compute_scale_factor(x, self.channel_dim,
+        scale_factor, mean = _compute_scale_factor(x.detach(), self.channel_dim,
min_abs=float(self.min_abs),
max_abs=float(self.max_abs),
gain_factor=float(self.scale_gain_factor) / prob,
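
The change passes `x.detach()` into the two factor-computation helpers so that the activation statistics are computed outside the autograd graph. Without the detach, every intermediate tensor created inside `_compute_sign_factor` and `_compute_scale_factor` is recorded for the backward pass and kept alive until backward runs, which is the memory issue the commit title refers to. A minimal sketch of the pattern, with a hypothetical stand-in helper (`_compute_scale_factor_sketch` is not the actual icefall implementation):

```python
import torch

def _compute_scale_factor_sketch(x: torch.Tensor, channel_dim: int,
                                 min_abs: float,
                                 gain_factor: float) -> torch.Tensor:
    # Hypothetical stand-in: derive a per-channel scale factor from
    # activation statistics.  Because the caller passes x.detach(),
    # none of the ops below are recorded by autograd, so their
    # intermediates can be freed immediately instead of being retained
    # for the backward pass.
    dims = [d for d in range(x.ndim) if d != channel_dim]
    mean_abs = x.abs().mean(dim=dims, keepdim=True)
    below = (mean_abs < min_abs).to(x.dtype)
    return below * gain_factor

x = torch.randn(4, 16, 100, requires_grad=True)

# Without detach, scale_factor would be attached to x's autograd graph,
# forcing the helper's intermediates to be stored until backward.
# With detach, the statistics are treated as constants w.r.t. x.
scale_factor = _compute_scale_factor_sketch(x.detach(), channel_dim=1,
                                            min_abs=0.2, gain_factor=0.02)
assert not scale_factor.requires_grad
```

The factors only steer how gradients are modified later in the balancer, rather than needing gradients of their own, so computing them from a detached copy reduces memory without changing the result.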