mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

Introduce schedules for whitening.

This commit is contained in:
parent a6657e6b40
commit ee61ec63b3
@@ -518,6 +518,75 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
     return x
 
 
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function.  You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop.  It is a piecewise linear function where you specify the (x,y) pairs
+    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, or when not in training mode, or in
+    torch.jit scripting mode.
+    """
+    def __init__(self,
+                 *args,
+                 default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        assert len(args) >= 1
+        for (x,y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        print_prob = 0.0002
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+        batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
+        if batch_count <= self.schedule[0][0]:
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return float(ans)
+        elif batch_count >= self.schedule[-1][0]:
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return float(ans)
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return float(ans)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
 def _whitening_metric(x: Tensor,
                       num_groups: int):
     """
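The schedule is resolved lazily: every float() cast re-reads self.batch_count, which the training loop is expected to keep up to date (together with self.name). Below is a minimal sketch, not part of the patch, of how the interpolation behaves; the import path is an assumption and should point at wherever the class above is defined in your checkout.

```python
# Sketch only (not from this commit): exercising ScheduledFloat directly.
from scaling import ScheduledFloat  # assumed import path

dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
dropout.name = "feedforward.dropout"   # normally written by the training loop

for step in (0, 1000, 2000, 4000, 8000):
    dropout.batch_count = step          # normally written by the training loop
    print(step, float(dropout))         # 0.2, 0.15, 0.1, 0.0, 0.0
```

Note that nn.Module instances start in training mode, so float() follows the schedule here; in eval mode or under torch.jit scripting it would return `default` instead.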
@@ -593,12 +662,11 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
         return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
 
 
-
 class Whiten(nn.Module):
     def __init__(
             self,
             num_groups: int,
-            whitening_limit: float,
+            whitening_limit: FloatLike,
             prob: Union[float, Tuple[float,float]],
             grad_scale: float):
         """
@@ -621,7 +689,7 @@ class Whiten(nn.Module):
         """
         super(Whiten, self).__init__()
         assert num_groups >= 1
-        assert whitening_limit >= 1
+        assert float(whitening_limit) >= 1
         assert grad_scale >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
@@ -656,10 +724,11 @@ class Whiten(nn.Module):
         if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
             return _no_op(x)
         else:
+            whitening_limit = float(self.whitening_limit)
            if hasattr(self, 'min_prob') and random.random() < 0.25:
                 # occasionally switch between min_prob and max_prob, based on whether
                 # we are above or below the threshold.
-                if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit:
+                if _whitening_metric(x.to(torch.float32), self.num_groups) > whitening_limit:
                     # there would be a change to the grad.
                     self.prob = self.max_prob
                 else:
@@ -667,7 +736,7 @@ class Whiten(nn.Module):
 
            return WhiteningPenaltyFunction.apply(x,
                                                  self.num_groups,
-                                                 self.whitening_limit,
+                                                 whitening_limit,
                                                  self.grad_scale,
                                                  self.name)
 
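Because the stored limit is now a FloatLike, forward() resolves it once per call via float(self.whitening_limit). Python's float() accepts a plain float as well as a ScheduledFloat (through its __float__ method), so call sites that still pass a constant keep working unchanged. A small sketch of that duck typing, not part of the patch:

```python
# Sketch only: both FloatLike variants resolve through float().
static_limit = 5.0
scheduled_limit = ScheduledFloat((0.0, 4.0), (12000.0, 8.0), default=4.0)
scheduled_limit.batch_count = 6000      # normally written by the training loop

print(float(static_limit))              # 5.0
print(float(scheduled_limit))           # 6.0, halfway along the schedule
```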
@@ -1003,72 +1072,6 @@ class TanSwish(torch.nn.Module):
         return TanSwishFunction.apply(x)
 
 
-class ScheduledFloat(torch.nn.Module):
-    """
-    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-    it does not have a working forward() function.  You are supposed to cast it to float, as
-    in, float(parent_module.whatever), and use it as something like a dropout prob.
-
-    It is a floating point value whose value changes depending on the batch count of the
-    training loop.  It is a piecewise linear function where you specify the (x,y) pairs
-    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
-    first x or after the last x, we just use the first or last y value.
-
-    Example:
-       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
-    `default` is used when self.batch_count is not set, or when not in training mode, or in
-    torch.jit scripting mode.
-    """
-    def __init__(self,
-                 *args,
-                 default: float = 0.0):
-        super().__init__()
-        # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = None
-        self.name = None
-        self.default = default
-        assert len(args) >= 1
-        for (x,y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
-
-    def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
-
-    def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-        batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
-            return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
-        else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
-
-
-FloatLike = Union[float, ScheduledFloat]
-
-
 def _test_max_eig():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
@@ -336,6 +336,11 @@ class Zipformer(EncoderInterface):
         return x, lengths
 
 
+def _whitening_schedule(x: float) -> ScheduledFloat:
+    return ScheduledFloat((0.0, x),
+                          (12000.0, 2.0 * x),
+                          default=x)
+
 class ZipformerEncoderLayer(nn.Module):
     """
     Args:
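The helper returns a schedule that starts at the given limit and rises linearly to twice that limit by batch 12000, holding the final value afterwards; `default=x` means the starting limit is also used whenever no batch count is available (eval mode or torch.jit scripting). A quick sketch of the values, not part of the patch:

```python
# Sketch only: the limit produced by _whitening_schedule(4.0) over training.
limit = _whitening_schedule(4.0)   # ScheduledFloat((0.0, 4.0), (12000.0, 8.0), default=4.0)
for step in (0, 3000, 6000, 12000, 20000):
    limit.batch_count = step       # normally written by the training loop
    print(step, float(limit))      # 4.0, 5.0, 6.0, 8.0, 8.0
```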
@@ -424,7 +429,7 @@ class ZipformerEncoderLayer(nn.Module):
             max_abs=6.0,
         )
         self.whiten = Whiten(num_groups=1,
-                             whitening_limit=5.0,
+                             whitening_limit=_whitening_schedule(4.0),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
@@ -1048,9 +1053,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
                                     initial_scale=query_head_dim**-0.25)
 
-        # .. TODO: tune this limit? whitening_limit.
         self.whiten_keys = Whiten(num_groups=num_heads,
-                                  whitening_limit=2.0,
+                                  whitening_limit=_whitening_schedule(2.0),
                                   prob=(0.025, 0.25),
                                   grad_scale=0.025)
 
@@ -1227,7 +1231,7 @@ class SelfAttention(nn.Module):
                                    initial_scale=0.05)
 
         self.whiten = Whiten(num_groups=1,
-                             whitening_limit=15.0,
+                             whitening_limit=_whitening_schedule(7.5),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
@@ -1331,7 +1335,7 @@ class AttentionSqueeze(nn.Module):
                                     bias=False, initial_scale=0.05)
 
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=15.0,
+                                 whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.01, 0.1),
                                  grad_scale=0.01)
 
@@ -1388,7 +1392,7 @@ class FeedforwardModule(nn.Module):
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=15.0,
+                                 whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)
 
@@ -1433,7 +1437,7 @@ class NonlinAttentionModule(nn.Module):
                                    initial_scale=0.05)
 
         self.whiten = Whiten(num_groups=1,
-                             whitening_limit=15.0,
+                             whitening_limit=_whitening_schedule(7.5),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
@@ -1555,7 +1559,7 @@ class ConvolutionModule(nn.Module):
         )
 
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=15.0,
+                                 whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.01, 0.1),
                                  grad_scale=0.01)
 