Change memory cutoff for ActivationBalancer; remove it for Whiten

Daniel Povey 2022-12-21 11:24:47 +08:00
parent 4d61d39d36
commit c097c13720


@@ -871,9 +871,9 @@ class ActivationBalancer(torch.nn.Module):
         if prob is None:
             prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
         self.prob = prob
-        # 20% of the time we will return and do nothing because memory usage is
+        # 5% of the time we will return and do nothing because memory usage is
         # too high.
-        self.mem_cutoff = CutoffEstimator(0.2)
+        self.mem_cutoff = CutoffEstimator(0.05)
         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
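
The mem_cutoff object tracks the quantity passed to it (here, allocated CUDA memory) and answers True for roughly the top-p fraction of readings, so this hunk lowers ActivationBalancer's "skip because memory is high" fraction from about 20% of calls to about 5%. Below is a minimal, hypothetical sketch of how such an adaptive cutoff can behave; the real CutoffEstimator in this file may be implemented differently, and the class and variable names here are illustrative only.

    import random

    class AdaptiveCutoff:
        """Hypothetical stand-in for CutoffEstimator: keeps a running threshold so
        that, over time, roughly a proportion `p` of observed values fall above it."""

        def __init__(self, p: float):
            self.p = p              # target proportion of calls that return True
            self.count = 0
            self.count_above = 0
            self.cutoff = 0.0

        def __call__(self, x: float) -> bool:
            ans = x > self.cutoff
            self.count += 1
            if ans:
                self.count_above += 1
            # If the cumulative proportion of True answers has drifted away from p,
            # nudge the cutoff toward the current value to pull it back.
            delta_p = self.count_above / self.count - self.p
            if (delta_p > 0) == ans:
                q = abs(delta_p)
                self.cutoff = q * x + (1.0 - q) * self.cutoff
            return ans

    # With p=0.05, the estimator flags roughly the top 5% of readings as
    # "too high", matching the new comment in ActivationBalancer.__init__.
    est = AdaptiveCutoff(0.05)
    flags = sum(est(random.random()) for _ in range(100_000))
    print(flags / 100_000)  # roughly 0.05 once the cutoff has adapted
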
@@ -1070,10 +1070,6 @@ class Whiten(nn.Module):
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale
-        # 20% of the time we will return and do nothing because memory usage
-        # is too high.
-        self.mem_cutoff = CutoffEstimator(0.2)
         if isinstance(prob, float):
             assert 0 < prob <= 1
             self.prob = prob
@@ -1102,8 +1098,7 @@ class Whiten(nn.Module):
         and nothing will happen in backprop.
         """
         grad_scale = float(self.grad_scale)
-        if (not x.requires_grad or random.random() > self.prob or grad_scale == 0
-            or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
+        if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
             return _no_op(x)
         else:
             whitening_limit = float(self.whitening_limit)
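
After this hunk, Whiten decides whether to act using only requires_grad, the (possibly scheduled) prob, and grad_scale; the allocated-memory back-off no longer plays any role. A small illustrative check of the new condition follows; the helper name is made up for the example and is not part of the file.

    import random
    import torch

    def should_skip_whitening(x: torch.Tensor, prob: float, grad_scale: float) -> bool:
        # Mirrors the simplified gate above: whitening is a no-op unless gradients
        # are needed, the random draw lands under `prob`, and grad_scale is nonzero.
        # The CUDA-memory check has been removed.
        return (not x.requires_grad) or random.random() > prob or grad_scale == 0.0

    x = torch.randn(4, 256, requires_grad=True)
    print(should_skip_whitening(x, prob=0.25, grad_scale=0.02))  # True ~75% of the time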