Merge branch 'scaled_adam_exp466' into scaled_adam_exp472.

Below is a more complete list of the changes I am making, although some of
these may already be counted in the last merge.

  The numbers XXX below correspond to branches numbered scaled_adam_expXXX.
    - from 412/413 (cherry-picked): dropout for attention in the attention_squeeze and nonlin_attention modules,
      but simplified a little to use the same dropout schedule and drop them all out together
      (see the schedule sketch after this list); also have all 3 submodules use separate heads.
    - from 460->461, which is in the history of 464: revert the part about balancing the output of the attention_squeeze module.
    - merge from 462->467: use TanSwish instead of tanh.
    - merge 462->465: remove whitening in the self-attention module.
    - merge the part of 465->466 that was about diagnostics (the name attribute in the Whiten module).
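
For reference, the rates mentioned above are FloatLike schedules such as ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), which the diff below uses for const_attention_rate; the (x, y) pairs look like (batch_count, value) points. The sketch below is only a stand-alone approximation of that idea under that assumption, not the actual ScheduledFloat class from this recipe.

# Minimal sketch of a piecewise-linear schedule in the spirit of ScheduledFloat.
# Assumption: each (x, y) pair means "value y at batch_count x", with linear
# interpolation in between and clamping outside the given range.
from bisect import bisect_right
from typing import Sequence, Tuple


def scheduled_value(points: Sequence[Tuple[float, float]], batch_count: float) -> float:
    """Piecewise-linear value at `batch_count`, given sorted (batch_count, value) points."""
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    if batch_count <= xs[0]:
        return ys[0]
    if batch_count >= xs[-1]:
        return ys[-1]
    i = bisect_right(xs, batch_count)  # index of the first point to the right
    x0, x1 = xs[i - 1], xs[i]
    y0, y1 = ys[i - 1], ys[i]
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


# E.g. the const_attention_rate schedule in the diff below:
# 0.25 at batch 0, decaying linearly to 0.025 by batch 4000.
assert abs(scheduled_value([(0.0, 0.25), (4000.0, 0.025)], 2000.0) - 0.1375) < 1e-9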
Daniel Povey 2022-11-23 14:39:08 +08:00
commit 1d0252d420
3 changed files with 22 additions and 23 deletions


@@ -561,11 +561,13 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 x: Tensor,
                 num_groups: int,
                 whitening_limit: float,
-                grad_scale: float) -> Tensor:
+                grad_scale: float,
+                name: Optional[str]) -> Tensor:
         ctx.save_for_backward(x)
         ctx.num_groups = num_groups
         ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
+        ctx.name = name
         return x

     @staticmethod
@@ -580,7 +582,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
             metric = _whitening_metric(x_detached, ctx.num_groups)
             if random.random() < 0.005 or __name__ == "__main__":
-                logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+                logging.info(f"Whitening: name={ctx.name}, num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                              f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
             (metric - ctx.whitening_limit).relu().backward()
@@ -588,7 +590,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
             scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
                                       (penalty_grad.norm() + 1.0e-20))
             penalty_grad = penalty_grad * scale
-        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
@@ -630,7 +632,7 @@ class Whiten(nn.Module):
         (self.min_prob, self.max_prob) = prob
         assert 0 < self.min_prob < self.max_prob <= 1
         self.prob = self.max_prob
+        self.name = None  # will be set in training loop
         self.grad_scale = grad_scale

     def forward(self,
@@ -666,7 +668,8 @@ class Whiten(nn.Module):
         return WhiteningPenaltyFunction.apply(x,
                                               self.num_groups,
                                               self.whitening_limit,
-                                              self.grad_scale)
+                                              self.grad_scale,
+                                              self.name)

 class WithLoss(torch.autograd.Function):
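
The extra None at the end of the backward return above follows from how torch.autograd.Function works: backward must return one gradient (or None) per argument of forward, so adding the name argument to forward means adding one more None. Below is a minimal, self-contained illustration of that pattern; it is a toy function with a made-up name, not the real WhiteningPenaltyFunction.

import torch
from torch import Tensor
from typing import Optional


class IdentityWithName(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, scale: float, name: Optional[str]) -> Tensor:
        # Non-tensor arguments can simply be stashed on ctx for use in backward.
        ctx.scale = scale
        ctx.name = name
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        # forward() took (x, scale, name), so backward() must return exactly
        # three values: a gradient for x, and None for each non-tensor argument.
        return x_grad * ctx.scale, None, None


x = torch.randn(4, requires_grad=True)
y = IdentityWithName.apply(x, 2.0, "encoder.whiten")
y.sum().backward()
print(x.grad)  # every element is 2.0; ctx.name is available inside backward for logging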


@@ -98,8 +98,8 @@ def set_batch_count(
     for name, module in model.named_modules():
         if hasattr(module, 'batch_count'):
             module.batch_count = batch_count
         if hasattr(module, 'name'):
             module.name = name

 def add_model_arguments(parser: argparse.ArgumentParser):
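
This set_batch_count() pass is also what fills in the self.name slot added to Whiten above: walking model.named_modules() copies each submodule's qualified path onto any module that exposes a name attribute, so the whitening log line can report which module exceeded its limit. A toy demonstration of that pattern (the Probe module is hypothetical, not part of the recipe):

import torch.nn as nn


class Probe(nn.Module):
    """Hypothetical module exposing a `name` slot, like Whiten does above."""

    def __init__(self):
        super().__init__()
        self.name = None  # will be filled in by the naming pass below

    def forward(self, x):
        return x


model = nn.Sequential(nn.Linear(4, 4), Probe())

# Same pattern as set_batch_count(): copy the qualified module path onto any
# submodule that has a `name` attribute.
for name, module in model.named_modules():
    if hasattr(module, 'name'):
        module.name = name

print(model[1].name)  # -> '1'; the Whiten log message would report this name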


@@ -367,7 +367,7 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
-            squeeze_const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
+            const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
             bypass_max: FloatLike = 1.0,
     ) -> None:
@@ -382,7 +382,7 @@ class ZipformerEncoderLayer(nn.Module):
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
         self.bypass_max = copy.deepcopy(bypass_max)
-        self.squeeze_const_attention_rate = copy.deepcopy(squeeze_const_attention_rate)
+        self.const_attention_rate = copy.deepcopy(const_attention_rate)

         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
@@ -480,27 +480,23 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )

-        squeeze_weights = attn_weights[1:2]
-        if random.random() < float(self.squeeze_const_attention_rate):
-            # this form of dropout makes the attention-weights used for the
-            # squeeze-excite modules constant wherever they are not masked. The intention
-            # is to encourage these modules to do something similar to an averaging-over-time
-            # operation.
-            squeeze_weights = (squeeze_weights > 0.0).to(squeeze_weights.dtype)
-            # make sure they sum to 1 over the last axis.
-            squeeze_weights = squeeze_weights * (1.0 / squeeze_weights.sum(dim=-1, keepdim=True))
+        first_attn_weights = attn_weights[0:3]
+        if random.random() < float(self.const_attention_rate):
+            # Make attention weights constant. The intention is to
+            # encourage these modules to do something similar to an
+            # averaging-over-time operation.
+            first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
+            first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
-                                                     attn_weights[0:1])
+                                                     first_attn_weights[0:1])

         src = src + self.feed_forward1(src)

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, squeeze_weights)
+            src = src + self.attention_squeeze1(src, first_attn_weights[1:2])

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
@@ -513,7 +509,7 @@ class ZipformerEncoderLayer(nn.Module):

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, squeeze_weights)
+            src = src + self.attention_squeeze2(src, first_attn_weights[2:3])

         src = self.norm_final(self.balancer(src))
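
The rewritten block above is the constant-attention form of dropout described in the commit message: with probability const_attention_rate, each attention distribution is flattened into a uniform average over the positions it was allowed to attend to (masked positions have weight zero and stay zero), and the three submodules now consume separate head slices [0:1], [1:2] and [2:3]. Below is a small stand-alone check of that transform; the tensor shape is made up, assuming the leading dimension indexes heads as the slicing attn_weights[0:3] suggests.

import torch

# Made-up shape: (num_heads, batch, query_len, key_len); positions with weight
# 0.0 are treated as masked and remain zero after the transform.
attn_weights = torch.tensor([[[[0.7, 0.2, 0.1, 0.0],
                               [0.1, 0.6, 0.3, 0.0]]]])

w = (attn_weights > 0.0).to(attn_weights.dtype)
w = w * (1.0 / w.sum(dim=-1, keepdim=True))

print(w)
# tensor([[[[0.3333, 0.3333, 0.3333, 0.0000],
#           [0.3333, 0.3333, 0.3333, 0.0000]]]])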