mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Introduce schedules for whitening.
This commit is contained in:
parent
a6657e6b40
commit
ee61ec63b3
@ -518,6 +518,75 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class ScheduledFloat(torch.nn.Module):
|
||||
"""
|
||||
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
|
||||
it does not have a working forward() function. You are supposed to cast it to float, as
|
||||
in, float(parent_module.whatever), and use it as something like a dropout prob.
|
||||
|
||||
It is a floating point value whose value changes depending on the batch count of the
|
||||
training loop. It is a piecewise linear function where you specifiy the (x,y) pairs
|
||||
in sorted order on x; x corresponds to the batch index. For batch-index values before the
|
||||
first x or after the last x, we just use the first or last y value.
|
||||
|
||||
Example:
|
||||
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
|
||||
|
||||
`default` is used when self.batch_count is not set or in training or mode or in
|
||||
torch.jit scripting mode.
|
||||
"""
|
||||
def __init__(self,
|
||||
*args,
|
||||
default: float = 0.0):
|
||||
super().__init__()
|
||||
# self.batch_count and self.name will be written to in the training loop.
|
||||
self.batch_count = None
|
||||
self.name = None
|
||||
self.default = default
|
||||
assert len(args) >= 1
|
||||
for (x,y) in args:
|
||||
assert x >= 0
|
||||
for i in range(len(args) - 1):
|
||||
assert args[i + 1] > args[i], args
|
||||
self.schedule = args
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'batch_count={}, schedule={}'.format(self.batch_count,
|
||||
self.schedule)
|
||||
|
||||
def __float__(self):
|
||||
print_prob = 0.0002
|
||||
def maybe_print(ans):
|
||||
if random.random() < print_prob:
|
||||
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
|
||||
batch_count = self.batch_count
|
||||
if batch_count is None or not self.training or torch.jit.is_scripting():
|
||||
return float(self.default)
|
||||
if batch_count <= self.schedule[0][0]:
|
||||
ans = self.schedule[0][1]
|
||||
maybe_print(ans)
|
||||
return float(ans)
|
||||
elif batch_count >= self.schedule[-1][0]:
|
||||
ans = self.schedule[-1][1]
|
||||
maybe_print(ans)
|
||||
return float(ans)
|
||||
else:
|
||||
cur_x, cur_y = self.schedule[0]
|
||||
for i in range(1, len(self.schedule)):
|
||||
next_x, next_y = self.schedule[i]
|
||||
if batch_count >= cur_x and batch_count <= next_x:
|
||||
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
|
||||
maybe_print(ans)
|
||||
return float(ans)
|
||||
cur_x, cur_y = next_x, next_y
|
||||
assert False
|
||||
|
||||
|
||||
FloatLike = Union[float, ScheduledFloat]
|
||||
|
||||
|
||||
def _whitening_metric(x: Tensor,
|
||||
num_groups: int):
|
||||
"""
|
||||
@ -593,12 +662,11 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
|
||||
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
|
||||
|
||||
|
||||
|
||||
class Whiten(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_groups: int,
|
||||
whitening_limit: float,
|
||||
whitening_limit: FloatLike,
|
||||
prob: Union[float, Tuple[float,float]],
|
||||
grad_scale: float):
|
||||
"""
|
||||
@ -621,7 +689,7 @@ class Whiten(nn.Module):
|
||||
"""
|
||||
super(Whiten, self).__init__()
|
||||
assert num_groups >= 1
|
||||
assert whitening_limit >= 1
|
||||
assert float(whitening_limit) >= 1
|
||||
assert grad_scale >= 0
|
||||
self.num_groups = num_groups
|
||||
self.whitening_limit = whitening_limit
|
||||
@ -656,10 +724,11 @@ class Whiten(nn.Module):
|
||||
if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
|
||||
return _no_op(x)
|
||||
else:
|
||||
whitening_limit = float(self.whitening_limit)
|
||||
if hasattr(self, 'min_prob') and random.random() < 0.25:
|
||||
# occasionally switch between min_prob and max_prob, based on whether
|
||||
# we are above or below the threshold.
|
||||
if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit:
|
||||
if _whitening_metric(x.to(torch.float32), self.num_groups) > whitening_limit:
|
||||
# there would be a change to the grad.
|
||||
self.prob = self.max_prob
|
||||
else:
|
||||
@ -667,7 +736,7 @@ class Whiten(nn.Module):
|
||||
|
||||
return WhiteningPenaltyFunction.apply(x,
|
||||
self.num_groups,
|
||||
self.whitening_limit,
|
||||
whitening_limit,
|
||||
self.grad_scale,
|
||||
self.name)
|
||||
|
||||
@ -1003,72 +1072,6 @@ class TanSwish(torch.nn.Module):
|
||||
return TanSwishFunction.apply(x)
|
||||
|
||||
|
||||
class ScheduledFloat(torch.nn.Module):
|
||||
"""
|
||||
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
|
||||
it does not have a working forward() function. You are supposed to cast it to float, as
|
||||
in, float(parent_module.whatever), and use it as something like a dropout prob.
|
||||
|
||||
It is a floating point value whose value changes depending on the batch count of the
|
||||
training loop. It is a piecewise linear function where you specifiy the (x,y) pairs
|
||||
in sorted order on x; x corresponds to the batch index. For batch-index values before the
|
||||
first x or after the last x, we just use the first or last y value.
|
||||
|
||||
Example:
|
||||
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
|
||||
|
||||
`default` is used when self.batch_count is not set or in training or mode or in
|
||||
torch.jit scripting mode.
|
||||
"""
|
||||
def __init__(self,
|
||||
*args,
|
||||
default: float = 0.0):
|
||||
super().__init__()
|
||||
# self.batch_count and self.name will be written to in the training loop.
|
||||
self.batch_count = None
|
||||
self.name = None
|
||||
self.default = default
|
||||
assert len(args) >= 1
|
||||
for (x,y) in args:
|
||||
assert x >= 0
|
||||
for i in range(len(args) - 1):
|
||||
assert args[i + 1] > args[i], args
|
||||
self.schedule = args
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'batch_count={}, schedule={}'.format(self.batch_count,
|
||||
self.schedule)
|
||||
|
||||
def __float__(self):
|
||||
print_prob = 0.0002
|
||||
def maybe_print(ans):
|
||||
if random.random() < print_prob:
|
||||
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
|
||||
batch_count = self.batch_count
|
||||
if batch_count is None or not self.training or torch.jit.is_scripting():
|
||||
return float(self.default)
|
||||
if batch_count <= self.schedule[0][0]:
|
||||
ans = self.schedule[0][1]
|
||||
maybe_print(ans)
|
||||
return float(ans)
|
||||
elif batch_count >= self.schedule[-1][0]:
|
||||
ans = self.schedule[-1][1]
|
||||
maybe_print(ans)
|
||||
return float(ans)
|
||||
else:
|
||||
cur_x, cur_y = self.schedule[0]
|
||||
for i in range(1, len(self.schedule)):
|
||||
next_x, next_y = self.schedule[i]
|
||||
if batch_count >= cur_x and batch_count <= next_x:
|
||||
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
|
||||
maybe_print(ans)
|
||||
return float(ans)
|
||||
cur_x, cur_y = next_x, next_y
|
||||
assert False
|
||||
|
||||
FloatLike = Union[float, ScheduledFloat]
|
||||
|
||||
|
||||
def _test_max_eig():
|
||||
for proportion in [0.1, 0.5, 10.0]:
|
||||
logging.info(f"proportion = {proportion}")
|
||||
|
||||
@ -336,6 +336,11 @@ class Zipformer(EncoderInterface):
|
||||
return x, lengths
|
||||
|
||||
|
||||
def _whitening_schedule(x: float) -> ScheduledFloat:
|
||||
return ScheduledFloat((0.0, x),
|
||||
(12000.0, 2.0 * x),
|
||||
default=x)
|
||||
|
||||
class ZipformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
@ -424,7 +429,7 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=5.0,
|
||||
whitening_limit=_whitening_schedule(4.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1048,9 +1053,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
|
||||
initial_scale=query_head_dim**-0.25)
|
||||
|
||||
# .. TODO: tune this limit? whitening_limit.
|
||||
self.whiten_keys = Whiten(num_groups=num_heads,
|
||||
whitening_limit=2.0,
|
||||
whitening_limit=_whitening_schedule(2.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.025)
|
||||
|
||||
@ -1227,7 +1231,7 @@ class SelfAttention(nn.Module):
|
||||
initial_scale=0.05)
|
||||
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=15.0,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1331,7 +1335,7 @@ class AttentionSqueeze(nn.Module):
|
||||
bias=False, initial_scale=0.05)
|
||||
|
||||
self.out_whiten = Whiten(num_groups=1,
|
||||
whitening_limit=15.0,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.01, 0.1),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1388,7 +1392,7 @@ class FeedforwardModule(nn.Module):
|
||||
self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
|
||||
initial_scale=0.01)
|
||||
self.out_whiten = Whiten(num_groups=1,
|
||||
whitening_limit=15.0,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1433,7 +1437,7 @@ class NonlinAttentionModule(nn.Module):
|
||||
initial_scale=0.05)
|
||||
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=15.0,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1555,7 +1559,7 @@ class ConvolutionModule(nn.Module):
|
||||
)
|
||||
|
||||
self.out_whiten = Whiten(num_groups=1,
|
||||
whitening_limit=15.0,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.01, 0.1),
|
||||
grad_scale=0.01)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user