mirror of https://github.com/k2-fsa/icefall.git

Add memory cutoff on ActivationBalancer and Whiten

commit 29df07ba2c (parent 96daf7a00f)
@@ -230,6 +230,41 @@ def random_cast_to_half(x: Tensor,
     return torch.where(is_too_small, random_val, x).to(torch.float16)
 
 
+class CutoffEstimator:
+    """
+    Estimates cutoffs of an arbitrary numerical quantity such that a specified
+    proportion of items will be above the cutoff on average.
+
+    p is the proportion of items that should be above the cutoff.
+    """
+
+    def __init__(self, p: float):
+        self.p = p
+        # total count of items
+        self.count = 0
+        # total count of items that were above the cutoff
+        self.count_above = 0
+        # initial cutoff value
+        self.cutoff = 0
+
+    def __call__(self, x: float) -> bool:
+        """
+        Returns true if x is above the cutoff.
+        """
+        ans = (x > self.cutoff)
+        self.count += 1
+        if ans:
+            self.count_above += 1
+        cur_p = self.count_above / self.count
+        print(f"cur_p = {cur_p}, cutoff = {self.cutoff}")
+        delta_p = cur_p - self.p
+        if (delta_p > 0) == ans:
+            q = abs(delta_p)
+            self.cutoff = x * q + self.cutoff * (1-q)
+        return ans
+
+
 class CachingEvalFunction(torch.autograd.Function):
     # @custom_fwd and @custom_bwd relate to automatic mixed precision (amp) and ensure
     # that the backward pass runs with the same autocast context as the forward pass.
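For orientation (an illustration, not part of the commit): CutoffEstimator is a streaming estimator. Each call compares x against the current cutoff, updates running counts, and nudges the cutoff toward x whenever the observed above-cutoff proportion has drifted past p in the direction that this sample can correct. A minimal usage sketch, assuming a uniform random input stream purely for illustration:

    import random

    est = CutoffEstimator(0.1)          # aim for ~10% of items above the cutoff
    above = 0
    n = 2000                            # note: the class prints its state on every call
    for _ in range(n):
        x = random.uniform(0.0, 100.0)  # assumed toy input stream
        if est(x):                      # True if x exceeded the current cutoff
            above += 1
    print(above / n)                    # should settle near 0.1
    print(est.cutoff)                   # should drift toward roughly the 90th percentile
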
@@ -605,6 +640,9 @@ class ActivationBalancer(torch.nn.Module):
         if prob is None:
             prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
         self.prob = prob
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)
 
         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -618,11 +656,9 @@ class ActivationBalancer(torch.nn.Module):
         self.scale_gain_factor = scale_gain_factor
 
 
-
-
-
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not x.requires_grad:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)
 
         prob = float(self.prob)
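In ActivationBalancer.forward the estimator is fed torch.cuda.memory_allocated(), the number of bytes currently held by PyTorch's caching allocator on the current device, so the balancer becomes a no-op on roughly the 10% of CUDA calls that arrive under the highest memory pressure. A sketch of the same gating pattern in isolation (a hypothetical toy module, reusing CutoffEstimator from the first hunk above):

    import torch

    class MemoryGatedIdentity(torch.nn.Module):  # hypothetical name, illustration only
        def __init__(self):
            super().__init__()
            # skip work on ~10% of calls: those with the highest memory readings
            self.mem_cutoff = CutoffEstimator(0.1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # torch.cuda.memory_allocated() returns bytes currently allocated
            # on the current CUDA device; for CPU tensors the check is skipped.
            if x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()):
                return x  # high memory pressure: behave as a no-op this call
            # ... the module's real (memory-hungry) work would go here
            return x
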
@@ -776,7 +812,7 @@ class Whiten(nn.Module):
                  num_groups: int,
                  whitening_limit: FloatLike,
                  prob: Union[float, Tuple[float,float]],
-                 grad_scale: float):
+                 grad_scale: FloatLike):
         """
         Args:
           num_groups: the number of groups to divide the channel dim into before
@@ -801,6 +837,12 @@ class Whiten(nn.Module):
         assert grad_scale >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
+        self.grad_scale = grad_scale
+
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)
+
         if isinstance(prob, float):
             assert 0 < prob <= 1
             self.prob = prob
@@ -809,7 +851,6 @@ class Whiten(nn.Module):
             assert 0 < self.min_prob < self.max_prob <= 1
             self.prob = self.max_prob
         self.name = None # will be set in training loop
-        self.grad_scale = grad_scale
 
     def forward(self,
                 x: Tensor) -> Tensor:
@@ -829,7 +870,9 @@ class Whiten(nn.Module):
         you use the returned value, or the graph will be freed
         and nothing will happen in backprop.
         """
-        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+        grad_scale = float(self.grad_scale)
+        if (not x.requires_grad or random.random() > self.prob or grad_scale == 0
+            or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)
         else:
             whitening_limit = float(self.whitening_limit)
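Since grad_scale is now a FloatLike, it may be a schedule rather than a plain float; converting it once per forward call with float(self.grad_scale) evaluates the schedule at the current point in training, and the same evaluated value is used both for the early-exit check and (in the hunk below) for WhiteningPenaltyFunction.apply. A small sketch of that pattern, with schedule breakpoints that are assumptions for illustration:

    # Assumed values for illustration; ScheduledFloat is the scheduler already
    # used in this file (e.g. for ActivationBalancer's prob above).
    grad_scale_sched = ScheduledFloat((0.0, 0.02), (8000.0, 0.005), default=0.02)
    grad_scale = float(grad_scale_sched)  # evaluate at the current point in training
    if grad_scale == 0:
        pass                              # whitening penalty disabled for this call
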
@@ -845,7 +888,7 @@ class Whiten(nn.Module):
             return WhiteningPenaltyFunction.apply(x,
                                                   self.num_groups,
                                                   whitening_limit,
-                                                  self.grad_scale,
+                                                  grad_scale,
                                                   self.name)