From f7c99ed1d1a8f8e4fb6d3f91c8666fa721062d99 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 18 Nov 2022 12:06:12 +0800
Subject: [PATCH 1/9] Introduce random shift with stddev=1.0 into pos_emb

---
 .../pruned_transducer_stateless7/zipformer.py | 30 ++++++++++++-------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index efcb25754..c6d44bbe1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -890,27 +890,32 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     Args:
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
+        random_shift: standard deviation of the random distance by which we shift the
+           position embedding each time, when training.
         max_len: Maximum input length: just a heuristic for initialization.
     """
     def __init__(
         self,
         embed_dim: int,
-        dropout_rate: float,
+        dropout_rate: FloatLike = 0.0,
+        random_shift: FloatLike = 1.0,
         max_len: int = 1000
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate
+        self.random_shift = random_shift
         self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(max_len))
+        self.extend_pe(torch.tensor(0.0).expand(max_len), 0)

-    def extend_pe(self, x: Tensor) -> None:
+    def extend_pe(self, x: Tensor, shift: int) -> None:
         """Reset the positional encodings."""
+        T = x.size(0) + abs(shift)
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(0) >= x.size(0) * 2 - 1:
+            if self.pe.size(0) >= T * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -918,7 +923,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return

-        T = x.size(0)
         # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
         x = torch.arange(-(T-1), T, device=x.device).to(torch.float32).unsqueeze(1)

@@ -973,16 +977,22 @@ class CompactRelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).

         """
-        self.extend_pe(x)
+
+        shift = 0 if self.training else int(round(random.normalvariate(0, 1) * float(self.random_shift)))
+        self.extend_pe(x, shift)
+
         pos_emb = self.pe[
             self.pe.size(0) // 2
-            - x.size(0)
+            - x.size(0) + shift
             + 1 : self.pe.size(0) // 2  # noqa E203
-            + x.size(0),
+            + x.size(0) + shift,
             :
         ]
         pos_emb = pos_emb.unsqueeze(0)
-        return self.dropout(pos_emb)
+        pos_emb = torch.nn.functional.dropout(pos_emb,
+                                              p=float(self.dropout_rate),
+                                              training=self.training)
+        return pos_emb

From 8a095c1cd1de2c541f855f47d0498e37ffb35493 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 18 Nov 2022 12:46:40 +0800
Subject: [PATCH 2/9] Add SmallConvModule; decrease feedforward dims to keep
 about same num params.
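A note on the window arithmetic introduced in PATCH 1/9 above: the slice taken in
forward() still has length 2*time - 1, but it is re-centered at offset `shift`
instead of offset zero, and extend_pe() over-allocates the buffer by |shift| on
each side so the slice stays in range. A minimal standalone sketch with toy
values (the names below are illustrative, not part of the icefall API):

    import random
    import torch

    T = 4                          # sequence length, i.e. x.size(0)
    shift = random.randint(-2, 2)  # stands in for round(normalvariate(0, 1) * random_shift)

    # extend_pe() covers offsets -(T + |shift| - 1) .. (T + |shift| - 1)
    offsets = torch.arange(-(T + abs(shift) - 1), T + abs(shift))
    center = offsets.numel() // 2  # index of offset 0

    # the slice in forward(): same length as before, but centered at `shift`
    window = offsets[center - T + shift + 1 : center + T + shift]
    assert window.numel() == 2 * T - 1
    assert window[T - 1].item() == shift   # middle element is offset `shift`, not 0

Note also that, as committed here, the `0 if self.training else ...` conditional
applies the shift at inference time and disables it in training; PATCH 3/9 below
fixes that.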
---
 .../ASR/pruned_transducer_stateless7/train.py |  2 +-
 .../pruned_transducer_stateless7/zipformer.py | 80 +++++++++++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 45513ef5c..8d1c65362 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1536,1536,2048,1536,1536,1536",
+        default="1280,1280,1536,1280,1280,1280",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c6d44bbe1..994150b2e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -405,6 +405,8 @@ class ZipformerEncoderLayer(nn.Module):

         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)

+        self.small_conv_module = SmallConvolutionModule(embed_dim)
+
         self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel)


@@ -483,6 +485,10 @@ class ZipformerEncoderLayer(nn.Module):

             src = src + self.nonlin_attention_module(src, attn_weights[0:1])

+
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
+
         src = src + self.feed_forward1(src)

         # pooling module
@@ -1569,6 +1575,80 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)


+class SmallConvolutionModule(nn.Module):
+    """Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        hidden_dim (int): The number of channels of the hidden conv layer
+            (default=256).
+
+    """
+
+    def __init__(
+        self, channels: int, hidden_dim: int = 256,
+    ) -> None:
+        super().__init__()
+
+        self.conv1 = nn.Conv1d(
+            channels,
+            hidden_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
+
+        self.deriv_balancer = ActivationBalancer(
+            hidden_dim, channel_dim=1,
+            min_positive=0.05, max_positive=1.0,
+            max_abs=20.0,
+        )
+
+        self.activation = DoubleSwish()
+
+        self.conv2 = ScaledConv1d(
+            hidden_dim,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            initial_scale=0.05,
+        )
+
+    def forward(self,
+                x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+               (batch, #time), contains bool in masked positions.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (batch, channels, time).
+
+        x = self.conv1(x)  # (batch, hidden_dim, time)
+
+        x = self.deriv_balancer(x)
+        x = self.activation(x)
+
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x = self.conv2(x)
+
+        return x.permute(2, 0, 1)
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
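A quick way to sanity-check the new module's shape contract and its parameter
cost (which the feedforward-dim reduction in this patch is meant to roughly
offset) is sketched below; it assumes the recipe's zipformer.py is importable,
and the dimensions are illustrative:

    import torch
    from zipformer import SmallConvolutionModule  # assumption: run from the recipe directory

    m = SmallConvolutionModule(channels=384, hidden_dim=256)
    x = torch.randn(100, 8, 384)   # (time, batch, channels), as in the docstring
    y = m(x)
    assert y.shape == x.shape      # same shape out, so it can be added residually

    n_params = sum(p.numel() for p in m.parameters())
    # conv1: 256*384*3 + 256 = 295,168; conv2: 384*256*1 + 384 = 98,688;
    # DoubleSwish and ActivationBalancer are parameter-free in icefall's scaling.py
    print(n_params)                # 393,856 with these illustrative dims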
From 0601dd72fd1d9f019c0c148e91d13018ea69aeba Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 18 Nov 2022 14:53:03 +0800
Subject: [PATCH 3/9] Bug-fix RE random shift

---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c6d44bbe1..f5bbe9d48 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -978,7 +978,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         """

-        shift = 0 if self.training else int(round(random.normalvariate(0, 1) * float(self.random_shift)))
+        shift = int(round(random.normalvariate(0, 1) * float(self.random_shift))) if self.training else 0
         self.extend_pe(x, shift)

         pos_emb = self.pe[

From d23fda7c5f1b2eaf93ba82a25a8ef0d7428e51ca Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 19 Nov 2022 13:36:16 +0800
Subject: [PATCH 4/9] Multiply length_factor by 2.0.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index efcb25754..52d171a14 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -937,7 +937,10 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # length_factor is chosen so that the FFT can exactly separate points
         # close to the origin (T == 0).  So this part of the formulation is not really
         # heuristic.
-        length_factor = self.embed_dim / (2.0 * math.pi)  # todo: test this.
+        length_factor = self.embed_dim / (2.0 * math.pi)
+        # multiplying length_factor by this heuristic constant should reduce the resolution near the
+        # origin, i.e. reduce its ability to separate points near zero.
+        length_factor *= 2.0

         # note for machine implementations: if atan is not available, we can use:
         #   x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)

From 8b3303594cf554b16ad2398da43fe51c2012bcac Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 20 Nov 2022 13:07:20 +0800
Subject: [PATCH 5/9] Revert 419->420 change, regarding random shift in pos
 embedding

---
 .../pruned_transducer_stateless7/zipformer.py | 30 +++++++------------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e502c991e..eff4f65c2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -896,32 +896,27 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     Args:
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
-        random_shift: standard deviation of the random distance by which we shift the
-           position embedding each time, when training.
         max_len: Maximum input length: just a heuristic for initialization.
""" def __init__( self, embed_dim: int, - dropout_rate: FloatLike = 0.0, - random_shift: FloatLike = 1.0, + dropout_rate: float, max_len: int = 1000 ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() self.embed_dim = embed_dim assert embed_dim % 2 == 0 - self.dropout_rate = dropout_rate - self.random_shift = random_shift + self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None - self.extend_pe(torch.tensor(0.0).expand(max_len), 0) + self.extend_pe(torch.tensor(0.0).expand(max_len)) - def extend_pe(self, x: Tensor, shift: int) -> None: + def extend_pe(self, x: Tensor) -> None: """Reset the positional encodings.""" - T = x.size(0) + abs(shift) if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: + if self.pe.size(0) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device if self.pe.dtype != x.dtype or str(self.pe.device) != str( x.device @@ -929,6 +924,7 @@ class CompactRelPositionalEncoding(torch.nn.Module): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return + T = x.size(0) # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] x = torch.arange(-(T-1), T, device=x.device).to(torch.float32).unsqueeze(1) @@ -983,22 +979,16 @@ class CompactRelPositionalEncoding(torch.nn.Module): torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ - - shift = int(round(random.normalvariate(0, 1) * float(self.random_shift))) if self.training else 0 - self.extend_pe(x, shift) - + self.extend_pe(x) pos_emb = self.pe[ self.pe.size(0) // 2 - - x.size(0) + shift + - x.size(0) + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0) + shift, + + x.size(0), : ] pos_emb = pos_emb.unsqueeze(0) - pos_emb = torch.nn.functional.dropout(pos_emb, - p=float(self.dropout_rate), - training=self.training) - return pos_emb + return self.dropout(pos_emb) From 31b2a735b8344abe6f73d952fed60f30b3115686 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 20 Nov 2022 13:17:39 +0800 Subject: [PATCH 6/9] Move feedforward1 to the beginning, separating it from small_conv_module. 
---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f22a8a39a..ff9468167 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -472,6 +472,8 @@ class ZipformerEncoderLayer(nn.Module):
         # multi-headed self-attention module
         use_self_attn = (random.random() >= dynamic_skip_rate)

+        src = src + self.feed_forward1(src)
+
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
@@ -489,8 +491,6 @@ class ZipformerEncoderLayer(nn.Module):
         if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
             src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)

-        src = src + self.feed_forward1(src)
-
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.attention_squeeze1(src, attn_weights[1:2])

From a52ec3da28baa688d68b4abf42364a25f962bee1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 20 Nov 2022 14:24:41 +0800
Subject: [PATCH 7/9] Change feedforward dims: increase 1536->1792 for largest
 ff dim and move it one step later.

---
 egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 8d1c65362..d364560f3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1280,1280,1536,1280,1280,1280",
+        default="1280,1280,1280,1792,1280,1280",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )

From cdfbbdded28d3d683fafef3e73e8a6f03f4d1e7e Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 20 Nov 2022 16:34:51 +0800
Subject: [PATCH 8/9] Refactoring, and change length_factor from 2.0 to 1.5.

---
 .../pruned_transducer_stateless7/zipformer.py | 34 +++++++++----------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index ff9468167..5b510a71d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -544,7 +544,8 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
+                                                        length_factor=1.5)

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -897,11 +898,14 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
+        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+           less weight to small differences of offset near the origin.
""" def __init__( self, embed_dim: int, dropout_rate: float, - max_len: int = 1000 + max_len: int = 1000, + length_factor: float = 1.0, ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() @@ -909,8 +913,12 @@ class CompactRelPositionalEncoding(torch.nn.Module): assert embed_dim % 2 == 0 self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor) -> None: """Reset the positional encodings.""" if self.pe is not None: @@ -940,18 +948,16 @@ class CompactRelPositionalEncoding(torch.nn.Module): # is important. x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) - # length_factor is chosen so that the FFT can exactly separate points - # close to the origin (T == 0). So this part of the formulation is not really - # heuristic. - length_factor = self.embed_dim / (2.0 * math.pi) - # multiplying length_factor by this heuristic constant should reduce the resolution near to the - # origin, i.e. reduce its ability to separate points near zero. - length_factor *= 2.0 + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) # note for machine implementations: if atan is not available, we can use: # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_factor).atan() # results between -pi and pi + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi cosines = (x_atan * freqs).cos() sines = (x_atan * freqs).sin() @@ -961,14 +967,6 @@ class CompactRelPositionalEncoding(torch.nn.Module): pe[:, 1::2] = sines pe[:, -1] = 1.0 # for bias. - # if we have the length_factor correct, the cosines around 0 offset (T in the array) - # should be oscillating in sign like -1, 1, -1; and the sines should all be close to - # zero. - #r = 2 - #print("cosines = ", cosines[T-r:T+r,-5:]) - #print("sines = ", sines[T-r:T+r,-5:]) - - self.pe = pe.to(dtype=x.dtype) From a10a0bce7d51df2e53d97c4e5b31b3974a3d01eb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 20 Nov 2022 16:36:18 +0800 Subject: [PATCH 9/9] Increase length_factor from 1.5 to 3.0. --- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 5b510a71d..47a5efea1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -545,7 +545,7 @@ class ZipformerEncoder(nn.Module): ) -> None: super().__init__() self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, - length_factor=1.5) + length_factor=3.0) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)]