From 6c5763fbb32a5ebde0f5156bb3f9bcc73c79bfdd Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 22 Nov 2022 21:38:05 +0800
Subject: [PATCH 1/2] Implement subtracted momentum [0.33,0.66], and print name in Whiten module.

---
 .../pruned_transducer_stateless7/scaling.py   | 13 ++++---
 .../ASR/pruned_transducer_stateless7/train.py |  4 +--
 .../pruned_transducer_stateless7/zipformer.py | 36 +++++++++++++------
 3 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index ed3784a78..838018653 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -561,11 +561,13 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 x: Tensor,
                 num_groups: int,
                 whitening_limit: float,
-                grad_scale: float) -> Tensor:
+                grad_scale: float,
+                name: Optional[str]) -> Tensor:
         ctx.save_for_backward(x)
         ctx.num_groups = num_groups
         ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
+        ctx.name = name
         return x
 
     @staticmethod
@@ -580,7 +582,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 metric = _whitening_metric(x_detached, ctx.num_groups)
 
                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+                    logging.info(f"Whitening: name={ctx.name}, num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                                  f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
 
                 (metric - ctx.whitening_limit).relu().backward()
@@ -588,7 +590,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
             scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
                                       (penalty_grad.norm() + 1.0e-20))
             penalty_grad = penalty_grad * scale
-        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
 
 
 
@@ -630,7 +632,7 @@ class Whiten(nn.Module):
         (self.min_prob, self.max_prob) = prob
         assert 0 < self.min_prob < self.max_prob <= 1
         self.prob = self.max_prob
-
+        self.name = None  # will be set in training loop
         self.grad_scale = grad_scale
 
     def forward(self,
@@ -666,7 +668,8 @@ class Whiten(nn.Module):
            return WhiteningPenaltyFunction.apply(x,
                                                  self.num_groups,
                                                  self.whitening_limit,
-                                                 self.grad_scale)
+                                                 self.grad_scale,
+                                                 self.name)
 
 
 class WithLoss(torch.autograd.Function):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 45513ef5c..8adf65cc9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -98,8 +98,8 @@ def set_batch_count(
     for name, module in model.named_modules():
         if hasattr(module, 'batch_count'):
             module.batch_count = batch_count
-        if hasattr(module, 'name'):
-            module.name = name
+        if hasattr(module, 'name'):
+            module.name = name
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e5d2e58ec..9247171a1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -464,6 +464,12 @@ class ZipformerEncoderLayer(nn.Module):
 
         src_orig = src
 
+        momentum_alpha = 0.66
+        # the -0.5 below is "how strong" to make the negative momentum.
+        momentum_rate = -0.5 * (1.0 / (1 - momentum_alpha))
+        momentum = 0.0
+
+        # dropout rate for non-feedforward submodules
         dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
 
         # multi-headed self-attention module
@@ -478,30 +484,40 @@ class ZipformerEncoderLayer(nn.Module):
                 key_padding_mask=src_key_padding_mask,
             )
 
+        def add_to_src(src, momentum, x):
+            src = src + x + momentum_rate * momentum
+            momentum = (momentum * momentum_alpha) + x
+            return src, momentum
+
+
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.nonlin_attention_module(src,
-                                                     attn_weights[0:1])
+            src, momentum = add_to_src(src, momentum,
+                                       self.nonlin_attention_module(src, attn_weights[0:1]))
 
-        src = src + self.feed_forward1(src)
+        src, momentum = add_to_src(src, momentum,
+                                   self.feed_forward1(src))
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, attn_weights[1:2])
+            src, momentum = add_to_src(src, momentum,
+                                       self.attention_squeeze1(src, attn_weights[1:2]))
 
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn(
-                src, attn_weights)
+            src, momentum = add_to_src(src, momentum, self.self_attn(
+                src, attn_weights))
 
         if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
-            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+            src, momentum = add_to_src(src, momentum,
+                                       self.conv_module(src, src_key_padding_mask=src_key_padding_mask))
 
-        src = src + self.feed_forward2(src)
+        src, momentum = add_to_src(src, momentum,
+                                   self.feed_forward2(src))
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, attn_weights[2:3])
-
+            src, momentum = add_to_src(src, momentum,
+                                       self.attention_squeeze2(src, attn_weights[2:3]))
 
         src = self.norm_final(self.balancer(src))

From 1826648dde6e9473384f49ea1f6ca5e175aa43e9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 22 Nov 2022 22:54:05 +0800
Subject: [PATCH 2/2] Fix formulas and constants

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 9247171a1..5fb643f32 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -465,8 +465,10 @@ class ZipformerEncoderLayer(nn.Module):
         src_orig = src
 
         momentum_alpha = 0.66
-        # the -0.5 below is "how strong" to make the negative momentum.
-        momentum_rate = -0.5 * (1.0 / (1 - momentum_alpha))
+        # the -0.333 below is "how strong" to make the negative momentum.
+        # the (1-momentum_alpha) cancels out the 1/(1-momentum_alpha) factor from
+        # adding up powers of momentum_alpha
+        momentum_rate = -0.333 * (1 - momentum_alpha)
         momentum = 0.0
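
Note: the snippet below is an illustrative sketch, not part of the patch above. It shows, in standalone form, the update that add_to_src applies across a layer's sub-module outputs, using the constants from the second commit. The function and variable names (subtracted_momentum_residual, branch_outputs) are invented for this example; only PyTorch is assumed.

import torch

def subtracted_momentum_residual(src, branch_outputs, momentum_alpha=0.66):
    # The (1 - momentum_alpha) factor cancels the geometric-series sum
    # 1 + alpha + alpha^2 + ... = 1 / (1 - momentum_alpha), so each earlier
    # branch output ends up subtracted with a total weight of roughly 0.333.
    momentum_rate = -0.333 * (1 - momentum_alpha)
    momentum = torch.zeros_like(src)
    for x in branch_outputs:
        # usual residual add, minus a decayed running sum of earlier outputs
        src = src + x + momentum_rate * momentum
        momentum = momentum * momentum_alpha + x
    return src

# toy usage: three tensors standing in for the layer's sub-module outputs
src = torch.randn(10, 4, 512)
outputs = [torch.randn_like(src) for _ in range(3)]
out = subtracted_momentum_residual(src, outputs)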