Introduce schedules for whitening.

Daniel Povey 2022-11-23 19:49:34 +08:00
parent a6657e6b40
commit ee61ec63b3
2 changed files with 86 additions and 79 deletions

View File

@@ -518,6 +518,75 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
return x
class ScheduledFloat(torch.nn.Module):
"""
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
it does not have a working forward() function. You are supposed to cast it to float, as
in, float(parent_module.whatever), and use it as something like a dropout prob.
It is a floating point value whose value changes depending on the batch count of the
training loop. It is a piecewise linear function where you specify the (x,y) pairs
in sorted order on x; x corresponds to the batch index. For batch-index values before the
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
`default` is used when self.batch_count is not set, when we are not in training mode, or in
torch.jit scripting mode.
"""
def __init__(self,
*args,
default: float = 0.0):
super().__init__()
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = None
self.name = None
self.default = default
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
for i in range(len(args) - 1):
assert args[i + 1] > args[i], args
self.schedule = args
def extra_repr(self) -> str:
return 'batch_count={}, schedule={}'.format(self.batch_count,
self.schedule)
def __float__(self):
print_prob = 0.0002
def maybe_print(ans):
if random.random() < print_prob:
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
return float(self.default)
if batch_count <= self.schedule[0][0]:
ans = self.schedule[0][1]
maybe_print(ans)
return float(ans)
elif batch_count >= self.schedule[-1][0]:
ans = self.schedule[-1][1]
maybe_print(ans)
return float(ans)
else:
cur_x, cur_y = self.schedule[0]
for i in range(1, len(self.schedule)):
next_x, next_y = self.schedule[i]
if batch_count >= cur_x and batch_count <= next_x:
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
maybe_print(ans)
return float(ans)
cur_x, cur_y = next_x, next_y
assert False
FloatLike = Union[float, ScheduledFloat]
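As a quick sanity check of the interpolation, a minimal usage sketch (assuming the ScheduledFloat class above is in scope; the variable name and values are illustrative only):
    # Halfway along the (0.0, 0.2) -> (4000.0, 0.0) ramp the value is 0.1.
    # nn.Module defaults to training mode, so __float__ follows the schedule
    # once batch_count has been set (normally done by the training loop).
    dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
    dropout.batch_count = 2000
    assert abs(float(dropout) - 0.1) < 1e-6
    dropout.batch_count = 10000  # past the last x: clamps to the last y
    assert float(dropout) == 0.0
    dropout.eval()               # not in training mode: falls back to default
    assert float(dropout) == 0.0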
def _whitening_metric(x: Tensor,
num_groups: int):
"""
@@ -593,12 +662,11 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
class Whiten(nn.Module):
def __init__(
self,
num_groups: int,
whitening_limit: float,
whitening_limit: FloatLike,
prob: Union[float, Tuple[float,float]],
grad_scale: float):
"""
@@ -621,7 +689,7 @@ class Whiten(nn.Module):
"""
super(Whiten, self).__init__()
assert num_groups >= 1
assert whitening_limit >= 1
assert float(whitening_limit) >= 1
assert grad_scale >= 0
self.num_groups = num_groups
self.whitening_limit = whitening_limit
@@ -656,10 +724,11 @@ class Whiten(nn.Module):
if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
return _no_op(x)
else:
whitening_limit = float(self.whitening_limit)
if hasattr(self, 'min_prob') and random.random() < 0.25:
# occasionally switch between min_prob and max_prob, based on whether
# we are above or below the threshold.
if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit:
if _whitening_metric(x.to(torch.float32), self.num_groups) > whitening_limit:
# there would be a change to the grad.
self.prob = self.max_prob
else:
@@ -667,7 +736,7 @@ class Whiten(nn.Module):
return WhiteningPenaltyFunction.apply(x,
self.num_groups,
self.whitening_limit,
whitening_limit,
self.grad_scale,
self.name)
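The shape of this change is worth noting: the FloatLike limit is resolved to a plain float once per forward call, so a ScheduledFloat is re-evaluated at the current batch count on every step, while a plain float passes through unchanged. A hedged sketch of the pattern (the class below is hypothetical, not part of the commit):
    class UsesLimit(nn.Module):
        # Illustrative module showing how Whiten now consumes a FloatLike.
        # Assigning a ScheduledFloat registers it as a child module, which is
        # why ScheduledFloat subclasses nn.Module in the first place.
        def __init__(self, limit: FloatLike):
            super().__init__()
            self.limit = limit  # either a float or a ScheduledFloat
        def forward(self, x: Tensor) -> Tensor:
            # float() reads a ScheduledFloat at its current batch_count;
            # for a plain float it is a no-op, so both cases look the same.
            limit = float(self.limit)
            return x.clamp(max=limit)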
@@ -1003,72 +1072,6 @@ class TanSwish(torch.nn.Module):
return TanSwishFunction.apply(x)
class ScheduledFloat(torch.nn.Module):
"""
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
it does not have a working forward() function. You are supposed to cast it to float, as
in, float(parent_module.whatever), and use it as something like a dropout prob.
It is a floating point value whose value changes depending on the batch count of the
training loop. It is a piecewise linear function where you specify the (x,y) pairs
in sorted order on x; x corresponds to the batch index. For batch-index values before the
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
`default` is used when self.batch_count is not set, when we are not in training mode, or in
torch.jit scripting mode.
"""
def __init__(self,
*args,
default: float = 0.0):
super().__init__()
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = None
self.name = None
self.default = default
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
for i in range(len(args) - 1):
assert args[i + 1] > args[i], args
self.schedule = args
def extra_repr(self) -> str:
return 'batch_count={}, schedule={}'.format(self.batch_count,
self.schedule)
def __float__(self):
print_prob = 0.0002
def maybe_print(ans):
if random.random() < print_prob:
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
return float(self.default)
if batch_count <= self.schedule[0][0]:
ans = self.schedule[0][1]
maybe_print(ans)
return float(ans)
elif batch_count >= self.schedule[-1][0]:
ans = self.schedule[-1][1]
maybe_print(ans)
return float(ans)
else:
cur_x, cur_y = self.schedule[0]
for i in range(1, len(self.schedule)):
next_x, next_y = self.schedule[i]
if batch_count >= cur_x and batch_count <= next_x:
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
maybe_print(ans)
return float(ans)
cur_x, cur_y = next_x, next_y
assert False
FloatLike = Union[float, ScheduledFloat]
def _test_max_eig():
for proportion in [0.1, 0.5, 10.0]:
logging.info(f"proportion = {proportion}")

View File

@@ -336,6 +336,11 @@ class Zipformer(EncoderInterface):
return x, lengths
def _whitening_schedule(x: float) -> ScheduledFloat:
return ScheduledFloat((0.0, x),
(12000.0, 2.0 * x),
default=x)
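For example, _whitening_schedule(4.0) holds the limit at 4.0 at batch 0, ramps it linearly to 8.0 by batch 12000, and stays at 8.0 after that; in eval or scripting mode the default of x applies. A quick hedged check of those numbers (assuming training mode, with batch_count set by hand the way the training loop would):
    limit = _whitening_schedule(4.0)
    limit.batch_count = 6000    # halfway along the ramp
    assert float(limit) == 6.0  # 4.0 + (8.0 - 4.0) * 6000 / 12000
    limit.batch_count = 20000   # beyond the last schedule point: clamped
    assert float(limit) == 8.0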
class ZipformerEncoderLayer(nn.Module):
"""
Args:
@@ -424,7 +429,7 @@ class ZipformerEncoderLayer(nn.Module):
max_abs=6.0,
)
self.whiten = Whiten(num_groups=1,
whitening_limit=5.0,
whitening_limit=_whitening_schedule(4.0),
prob=(0.025, 0.25),
grad_scale=0.01)
@@ -1048,9 +1053,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
initial_scale=query_head_dim**-0.25)
# TODO: tune this limit? whitening_limit.
self.whiten_keys = Whiten(num_groups=num_heads,
whitening_limit=2.0,
whitening_limit=_whitening_schedule(2.0),
prob=(0.025, 0.25),
grad_scale=0.025)
@@ -1227,7 +1231,7 @@ class SelfAttention(nn.Module):
initial_scale=0.05)
self.whiten = Whiten(num_groups=1,
whitening_limit=15.0,
whitening_limit=_whitening_schedule(7.5),
prob=(0.025, 0.25),
grad_scale=0.01)
@@ -1331,7 +1335,7 @@ class AttentionSqueeze(nn.Module):
bias=False, initial_scale=0.05)
self.out_whiten = Whiten(num_groups=1,
whitening_limit=15.0,
whitening_limit=_whitening_schedule(7.5),
prob=(0.01, 0.1),
grad_scale=0.01)
@@ -1388,7 +1392,7 @@ class FeedforwardModule(nn.Module):
self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
initial_scale=0.01)
self.out_whiten = Whiten(num_groups=1,
whitening_limit=15.0,
whitening_limit=_whitening_schedule(7.5),
prob=(0.025, 0.25),
grad_scale=0.01)
@@ -1433,7 +1437,7 @@ class NonlinAttentionModule(nn.Module):
initial_scale=0.05)
self.whiten = Whiten(num_groups=1,
whitening_limit=15.0,
whitening_limit=_whitening_schedule(7.5),
prob=(0.025, 0.25),
grad_scale=0.01)
@@ -1555,7 +1559,7 @@ class ConvolutionModule(nn.Module):
)
self.out_whiten = Whiten(num_groups=1,
whitening_limit=15.0,
whitening_limit=_whitening_schedule(7.5),
prob=(0.01, 0.1),
grad_scale=0.01)