From 4fe91ce67ce1bf112bb914b9cc6ab4e27fb61ce1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 14 Jan 2023 17:19:34 +0800
Subject: [PATCH 1/3] Double hidden_channels in NonlinAttention from
 embed_dim//4 to embed_dim//2.

---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 1595d0544..2da05f5b1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -489,7 +489,7 @@ class ZipformerEncoderLayer(nn.Module):
                                                        dropout)
 
         self.nonlin_attention = NonlinAttention(embed_dim,
-                                                hidden_channels=embed_dim // 4)
+                                                hidden_channels=embed_dim // 2)
 
         self.small_conv_module = SmallConvolutionModule(embed_dim)
 

From eeadc3b0ccac77cab01d96541fe4688436b430cf Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 14 Jan 2023 20:41:30 +0800
Subject: [PATCH 2/3] Add a multiplication to NonlinAttentionModule

---
 .../pruned_transducer_stateless7/zipformer.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 1595d0544..789fe9489 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1602,7 +1602,7 @@ class NonlinAttention(nn.Module):
 
         self.hidden_channels = hidden_channels
 
-        self.in_proj = nn.Linear(channels, hidden_channels * 2, bias=True)
+        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
 
         # balancer that goes before the sigmoid.  Have quite a large min_abs value, at 2.0,
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
@@ -1617,7 +1617,10 @@ class NonlinAttention(nn.Module):
         )
 
         self.tanh = nn.Tanh()
-        self.activation = Identity()  # for diagnostics.
+        self.identity1 = Identity()  # for diagnostics.
+        self.identity2 = Identity()  # for diagnostics.
+        self.identity3 = Identity()  # for diagnostics.
+
         self.out_proj = ScaledLinear(hidden_channels, channels,
                                      bias=True,
                                      initial_scale=0.05)
@@ -1652,16 +1655,17 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s = x[..., hidden_channels:]
-        x = x[..., :hidden_channels]
+        s, x, y = x.chunk(3, dim=-1)
+
+        # s will go through tanh.
 
         s = self.balancer(s)
         s = self.tanh(s)
 
         s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
         x = self.whiten1(x)
-        x = self.activation(x)  # diagnostics only, it's the identity.
         x = x * s
+        x = self.identity1(x)  # diagnostics only, it's the identity.
 
         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
@@ -1673,6 +1677,11 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
 
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+        y = self.identity2(y)
+        x = x * y
+        x = self.identity3(x)
+
         x = self.out_proj(x)
         x = self.whiten2(x)
         return x

From 048b6b6259a715c4b8225d493fdcd8df88e42b1f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 15 Jan 2023 00:21:01 +0800
Subject: [PATCH 3/3] Make scale in NonlinAttention have glu nonlinearity.
---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 789fe9489..91551f185 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1602,7 +1602,7 @@ class NonlinAttention(nn.Module):
 
         self.hidden_channels = hidden_channels
 
-        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+        self.in_proj = nn.Linear(channels, hidden_channels * 4, bias=True)
 
         # balancer that goes before the sigmoid.  Have quite a large min_abs value, at 2.0,
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
@@ -1655,7 +1655,9 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s, x, y = x.chunk(3, dim=-1)
+        x, y = x.chunk(2, dim=-1)
+
+        s, x = x.chunk(2, dim=-1)
 
         # s will go through tanh.
 
@@ -1677,7 +1679,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
 
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-
+        y = torch.nn.functional.glu(y, dim=-1)
         y = self.identity2(y)
         x = x * y
         x = self.identity3(x)
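
For context, a minimal sketch of what NonlinAttention.forward computes once all three patches are applied. This is an illustration only, not the icefall code: the class name NonlinAttentionSketch and the example values under __main__ are made up, the attention-mixing lines between the hunks are reconstructed from the diff context, and the Balancer, Whiten, ScaledLinear and Identity diagnostic modules of zipformer.py are omitted (plain nn.Linear stands in for ScaledLinear).

import torch
import torch.nn as nn


class NonlinAttentionSketch(nn.Module):
    # Hypothetical, simplified stand-in for NonlinAttention after patches 1-3.
    def __init__(self, channels: int, hidden_channels: int):
        super().__init__()
        self.hidden_channels = hidden_channels
        # Patch 3: the input projection is hidden_channels * 4
        # (it was * 2 before patch 2 and * 3 after it).
        self.in_proj = nn.Linear(channels, hidden_channels * 4, bias=True)
        self.tanh = nn.Tanh()
        # Plain Linear stands in for ScaledLinear(..., initial_scale=0.05).
        self.out_proj = nn.Linear(hidden_channels, channels, bias=True)

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch_size, channels)
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        x = self.in_proj(x)
        (seq_len, batch_size, _) = x.shape

        # Patch 3: split off y (2 * hidden_channels, later fed to GLU),
        # then split the rest into the tanh gate s and the main stream x.
        x, y = x.chunk(2, dim=-1)
        s, x = x.chunk(2, dim=-1)

        # Tanh gate on the main stream (already present before these patches).
        x = x * self.tanh(s)

        # Mix x over time with the externally supplied attention weights.
        num_heads = attn_weights.shape[0]
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        x = torch.matmul(attn_weights, x)  # (num_heads, batch, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        # Patch 2 added a multiplication by y; patch 3 passes y through GLU first,
        # so the multiplier is y1 * sigmoid(y2) with hidden_channels outputs.
        y = torch.nn.functional.glu(y, dim=-1)
        x = x * y

        return self.out_proj(x)


if __name__ == "__main__":
    # Example values only; patch 1 sets hidden_channels = embed_dim // 2 in
    # ZipformerEncoderLayer, e.g. 384 -> 192, which must divide by num_heads.
    embed_dim, hidden, heads, seq, batch = 384, 192, 8, 10, 2
    m = NonlinAttentionSketch(embed_dim, hidden)
    x = torch.randn(seq, batch, embed_dim)
    w = torch.softmax(torch.randn(heads, batch, seq, seq), dim=-1)
    print(m(x, w).shape)  # torch.Size([10, 2, 384])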