Implement subtracted momentum [0.33,0.66], and print name in Whiten module.

Daniel Povey 2022-11-22 21:38:05 +08:00
parent 1a2632d0a2
commit 6c5763fbb3
3 changed files with 36 additions and 17 deletions

View File

@@ -561,11 +561,13 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
x: Tensor,
num_groups: int,
whitening_limit: float,
grad_scale: float) -> Tensor:
grad_scale: float,
name: Optional[str]) -> Tensor:
ctx.save_for_backward(x)
ctx.num_groups = num_groups
ctx.whitening_limit = whitening_limit
ctx.grad_scale = grad_scale
ctx.name = name
return x
@staticmethod
@@ -580,7 +582,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
metric = _whitening_metric(x_detached, ctx.num_groups)
if random.random() < 0.005 or __name__ == "__main__":
logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
logging.info(f"Whitening: name={ctx.name}, num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
(metric - ctx.whitening_limit).relu().backward()
@@ -588,7 +590,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
(penalty_grad.norm() + 1.0e-20))
penalty_grad = penalty_grad * scale
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
@@ -630,7 +632,7 @@ class Whiten(nn.Module):
(self.min_prob, self.max_prob) = prob
assert 0 < self.min_prob < self.max_prob <= 1
self.prob = self.max_prob
self.name = None # will be set in training loop
self.grad_scale = grad_scale
def forward(self,
@@ -666,7 +668,8 @@ class Whiten(nn.Module):
return WhiteningPenaltyFunction.apply(x,
self.num_groups,
self.whitening_limit,
self.grad_scale)
self.grad_scale,
self.name)
class WithLoss(torch.autograd.Function):
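
The first file's change threads a human-readable name through WhiteningPenaltyFunction so that the occasional diagnostic line logged in backward() says which Whiten module it came from; the new "self.name = None  # will be set in training loop" attribute indicates the training loop is expected to fill it in. Below is a minimal, self-contained sketch of that pattern only, not the real whitening code: the NamedPenaltyFunction class, its arguments, and the example module name are illustrative stand-ins.

import logging
from typing import Optional

import torch
from torch import Tensor


class NamedPenaltyFunction(torch.autograd.Function):
    """Toy stand-in for WhiteningPenaltyFunction: identity in forward(),
    name-tagged logging in backward()."""

    @staticmethod
    def forward(ctx, x: Tensor, grad_scale: float, name: Optional[str]) -> Tensor:
        ctx.save_for_backward(x)
        ctx.grad_scale = grad_scale
        ctx.name = name          # non-tensor state is simply stored on ctx
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x,) = ctx.saved_tensors
        # the real code computes a whitening metric here; this sketch only
        # shows the name-tagged logging and the pass-through gradient
        logging.info(f"penalty: name={ctx.name}, num_channels={x.shape[-1]}")
        # one return value per forward() input: x, grad_scale, name
        return x_grad, None, None


x = torch.randn(4, 8, requires_grad=True)
y = NamedPenaltyFunction.apply(x, 0.02, "encoder.layer3.whiten")
y.sum().backward()

The extra None appended to the real backward()'s return tuple follows the same rule as the two Nones here: autograd expects one gradient slot per forward() argument, and non-tensor arguments get None.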

View File

@@ -464,6 +464,12 @@ class ZipformerEncoderLayer(nn.Module):
src_orig = src
momentum_alpha = 0.66
# the -0.5 below is "how strong" to make the negative momentum.
momentum_rate = -0.5 * (1.0 / (1 - momentum_alpha))
momentum = 0.0
# dropout rate for non-feedforward submodules
dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
# multi-headed self-attention module
@@ -478,30 +484,40 @@
key_padding_mask=src_key_padding_mask,
)
def add_to_src(src, momentum, x):
src = src + x + momentum_rate * momentum
momentum = (momentum * momentum_alpha) + x
return src, momentum
if torch.jit.is_scripting() or use_self_attn:
src = src + self.nonlin_attention_module(src,
attn_weights[0:1])
src, momentum = add_to_src(src, momentum,
self.nonlin_attention_module(src, attn_weights[0:1]))
src = src + self.feed_forward1(src)
src, momentum = add_to_src(src, momentum,
self.feed_forward1(src))
# pooling module
if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze1(src, attn_weights[1:2])
src, momentum = add_to_src(src, momentum,
self.attention_squeeze1(src, attn_weights[1:2]))
if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn(
src, attn_weights)
src, momentum = add_to_src(src, momentum, self.self_attn(
src, attn_weights))
if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
src, momentum = add_to_src(src, momentum,
self.conv_module(src, src_key_padding_mask=src_key_padding_mask))
src = src + self.feed_forward2(src)
src, momentum = add_to_src(src, momentum,
self.feed_forward2(src))
# pooling module
if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze2(src, attn_weights[2:3])
src, momentum = add_to_src(src, momentum,
self.attention_squeeze2(src, attn_weights[2:3]))
src = self.norm_final(self.balancer(src))
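
The second file implements the commit title's "subtracted momentum": instead of plain residual additions src = src + module(src), each submodule output is folded in through add_to_src(), which also adds momentum_rate * momentum, where momentum is a decayed running sum of the earlier submodule outputs in the same layer. Because momentum_rate is negative, a scaled copy of that running sum is subtracted at every step; the 1.0 / (1 - momentum_alpha) factor is the geometric-series sum 1 + a + a^2 + ... for the chosen decay, so the subtraction strengthens as momentum_alpha approaches 1, and the -0.5 sets the overall strength per the in-diff comment. The [0.33,0.66] in the commit title presumably refers to momentum_alpha settings, though this hunk hard-codes 0.66. The following is a standalone toy of just the update rule; the nn.Linear modules are placeholders for the real attention / feed-forward / convolution submodules, not the actual Zipformer code.

import torch
from torch import nn, Tensor

momentum_alpha = 0.66
# 1 / (1 - momentum_alpha) is the geometric-series sum 1 + a + a^2 + ...;
# the -0.5 sets how strong the negative momentum is (per the in-diff comment)
momentum_rate = -0.5 * (1.0 / (1 - momentum_alpha))   # ~= -1.47


def add_to_src(src: Tensor, momentum, x: Tensor):
    # add this submodule's output, minus a scaled sum of the earlier ones
    src = src + x + momentum_rate * momentum
    # decay the running sum and fold the current output into it
    momentum = (momentum * momentum_alpha) + x
    return src, momentum


embed_dim = 16
submodules = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(7)])

src = torch.randn(10, embed_dim)   # toy (seq_len, embed_dim) residual stream
momentum = 0.0                     # reset at the start of every layer forward
for m in submodules:
    src, momentum = add_to_src(src, momentum, m(src))
print(src.shape)                   # torch.Size([10, 16])

Note that in the diff the add_to_src() calls sit behind the same use_self_attn and dynamic_skip_rate guards as the old additions, so a skipped submodule contributes neither to src nor to momentum.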