diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 208d15735..842233ef3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -230,6 +230,41 @@ def random_cast_to_half(x: Tensor,
     return torch.where(is_too_small, random_val, x).to(torch.float16)
 
 
+
+class CutoffEstimator:
+    """
+    Estimates cutoffs of an arbitrary numerical quantity such that a specified
+    proportion of items will be above the cutoff on average.
+    p is the proportion of items that should be above the cutoff.
+
+    """
+    def __init__(self, p: float):
+        self.p = p
+        # total count of items
+        self.count = 0
+        # total count of items that were above the cutoff
+        self.count_above = 0
+        # initial cutoff value
+        self.cutoff = 0
+
+
+    def __call__(self, x: float) -> bool:
+        """
+        Returns true if x is above the cutoff.
+        """
+        ans = (x > self.cutoff)
+        self.count += 1
+        if ans:
+            self.count_above += 1
+        cur_p = self.count_above / self.count
+        print(f"cur_p = {cur_p}, cutoff = {self.cutoff}")
+        delta_p = cur_p - self.p
+        if (delta_p > 0) == ans:
+            q = abs(delta_p)
+            self.cutoff = x * q + self.cutoff * (1-q)
+        return ans
+
+
 class CachingEvalFunction(torch.autograd.Function):
     # @custom_fwd and @custom_bwd related to automatic mixed precision (amp) an ensure
     # that the backward path runs with the same autocast context as the forward pass.
@@ -605,6 +640,9 @@ class ActivationBalancer(torch.nn.Module):
         if prob is None:
             prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
         self.prob = prob
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)
 
         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -618,11 +656,9 @@ class ActivationBalancer(torch.nn.Module):
         self.scale_gain_factor = scale_gain_factor
 
 
-
-
-
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not x.requires_grad:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)
 
         prob = float(self.prob)
@@ -776,7 +812,7 @@ class Whiten(nn.Module):
                  num_groups: int,
                  whitening_limit: FloatLike,
                  prob: Union[float, Tuple[float,float]],
-                 grad_scale: float):
+                 grad_scale: FloatLike):
         """
         Args:
           num_groups: the number of groups to divide the channel dim into before
@@ -801,6 +837,12 @@ class Whiten(nn.Module):
         assert grad_scale >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
+        self.grad_scale = grad_scale
+
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)
+
         if isinstance(prob, float):
             assert 0 < prob <= 1
             self.prob = prob
@@ -809,7 +851,6 @@ class Whiten(nn.Module):
             assert 0 < self.min_prob < self.max_prob <= 1
             self.prob = self.max_prob
         self.name = None  # will be set in training loop
-        self.grad_scale = grad_scale
 
 
     def forward(self, x: Tensor) -> Tensor:
@@ -829,7 +870,9 @@ class Whiten(nn.Module):
           you use the returned value, or the graph will be freed and nothing
           will happen in backprop.
         """
-        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+        grad_scale = float(self.grad_scale)
+        if (not x.requires_grad or random.random() > self.prob or grad_scale == 0
+            or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)
         else:
             whitening_limit = float(self.whitening_limit)
@@ -845,7 +888,7 @@ class Whiten(nn.Module):
             return WhiteningPenaltyFunction.apply(x,
                                                   self.num_groups,
                                                   whitening_limit,
-                                                  self.grad_scale,
+                                                  grad_scale,
                                                   self.name)
 
 