Add memory cutoff on ActivationBalancer and Whiten

Daniel Povey 2022-12-17 16:20:15 +08:00
parent 96daf7a00f
commit 29df07ba2c


@@ -230,6 +230,41 @@ def random_cast_to_half(x: Tensor,
     return torch.where(is_too_small, random_val, x).to(torch.float16)

+class CutoffEstimator:
+    """
+    Estimates cutoffs of an arbitrary numerical quantity such that a specified
+    proportion of items will be above the cutoff on average.
+
+    p is the proportion of items that should be above the cutoff.
+    """
+    def __init__(self, p: float):
+        self.p = p
+        # total count of items
+        self.count = 0
+        # total count of items that were above the cutoff
+        self.count_above = 0
+        # initial cutoff value
+        self.cutoff = 0
+
+    def __call__(self, x: float) -> bool:
+        """
+        Returns true if x is above the cutoff.
+        """
+        ans = (x > self.cutoff)
+        self.count += 1
+        if ans:
+            self.count_above += 1
+        cur_p = self.count_above / self.count
+        print(f"cur_p = {cur_p}, cutoff = {self.cutoff}")
+        delta_p = cur_p - self.p
+        if (delta_p > 0) == ans:
+            q = abs(delta_p)
+            self.cutoff = x * q + self.cutoff * (1-q)
+        return ans
+
 class CachingEvalFunction(torch.autograd.Function):
     # @custom_fwd and @custom_bwd are related to automatic mixed precision (amp) and ensure
     # that the backward pass runs with the same autocast context as the forward pass.
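For intuition, here is a brief usage sketch of the CutoffEstimator added above (not part of the commit). Fed a stream of scalar readings, its cutoff adapts until, on average, roughly the requested proportion p of calls return True. The uniform inputs and variable names below are illustrative only, and note that the debug print inside __call__ fires on every call.

    import random

    est = CutoffEstimator(0.1)                 # aim for ~10% of readings above the cutoff
    n_above = 0
    for _ in range(1000):
        reading = random.uniform(0.0, 1.0)     # stand-in for e.g. a memory measurement
        if est(reading):
            n_above += 1
    # The cutoff drifts toward roughly the 90th percentile of the readings,
    # so n_above / 1000 should end up in the vicinity of 0.1.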
@@ -605,6 +640,9 @@ class ActivationBalancer(torch.nn.Module):
         if prob is None:
             prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
         self.prob = prob
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)

         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -618,11 +656,9 @@ class ActivationBalancer(torch.nn.Module):
         self.scale_gain_factor = scale_gain_factor

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not x.requires_grad:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)

         prob = float(self.prob)
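The hunk above is the whole memory-cutoff mechanism for ActivationBalancer: on CUDA tensors, torch.cuda.memory_allocated() is fed to the estimator, and the roughly 10% of forward calls that see the highest readings become no-ops. A self-contained sketch of the same gating pattern follows; the module name and the pass-through body are illustrative stand-ins, not the real ActivationBalancer.

    import torch

    class MemoryGatedModule(torch.nn.Module):
        """Skips its extra (memory-hungry) work on the ~10% of calls where
        allocated CUDA memory is highest, mirroring the gating in the diff."""
        def __init__(self):
            super().__init__()
            self.mem_cutoff = CutoffEstimator(0.1)   # class added earlier in this commit

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if (torch.jit.is_scripting() or not x.requires_grad or
                    (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
                return x   # stand-in for _no_op(x): do nothing on this call
            # ... otherwise do the regularizing work here ...
            return x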
@@ -776,7 +812,7 @@ class Whiten(nn.Module):
                  num_groups: int,
                  whitening_limit: FloatLike,
                  prob: Union[float, Tuple[float,float]],
-                 grad_scale: float):
+                 grad_scale: FloatLike):
         """
         Args:
           num_groups: the number of groups to divide the channel dim into before
@@ -801,6 +837,12 @@ class Whiten(nn.Module):
         assert grad_scale >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
+        self.grad_scale = grad_scale
+
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)
+
         if isinstance(prob, float):
             assert 0 < prob <= 1
             self.prob = prob
@@ -809,7 +851,6 @@ class Whiten(nn.Module):
             assert 0 < self.min_prob < self.max_prob <= 1
             self.prob = self.max_prob
         self.name = None  # will be set in training loop
-        self.grad_scale = grad_scale

     def forward(self,
                 x: Tensor) -> Tensor:
@@ -829,7 +870,9 @@ class Whiten(nn.Module):
        you use the returned value, or the graph will be freed
        and nothing will happen in backprop.
        """
-        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+        grad_scale = float(self.grad_scale)
+        if (not x.requires_grad or random.random() > self.prob or grad_scale == 0
+            or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)
         else:
             whitening_limit = float(self.whitening_limit)
@@ -845,7 +888,7 @@ class Whiten(nn.Module):
             return WhiteningPenaltyFunction.apply(x,
                                                   self.num_groups,
                                                   whitening_limit,
-                                                  self.grad_scale,
+                                                  grad_scale,
                                                   self.name)
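A related change in the hunks above is that Whiten's grad_scale is now typed as FloatLike and re-read with float(self.grad_scale) on every forward call, which suggests a schedule (like the ScheduledFloat used for prob elsewhere in this file) could be passed instead of a fixed constant, and a value of zero at a given step simply disables the penalty for that call. A hedged construction sketch follows; the numeric values and tensor shape are illustrative, not taken from the commit.

    import torch

    whiten = Whiten(num_groups=1,          # treat the whole channel dim as one group
                    whitening_limit=5.0,   # illustrative limit on the whitening metric
                    prob=0.25,             # apply the penalty on ~25% of forward calls
                    grad_scale=0.02)       # plain float here; the FloatLike type and the
                                           # per-call float() read suggest a schedule
                                           # could also be passed (assumption)

    x = torch.randn(4, 100, 256, requires_grad=True)   # illustrative (batch, time, channel)
    y = whiten(x)   # use y in place of x: the whitening penalty acts through backprop on y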