From b8db0f53f1c0da161b3da4fa77f915f77d0c69c8 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 31 Oct 2022 13:11:59 +0800
Subject: [PATCH 1/5] Change to schedule of bypass_scale min: make it larger,
 decrease slower.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 9985c9001..28a6980ea 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -371,9 +371,9 @@ class ZipformerEncoderLayer(nn.Module):
             # ensure we get grads if self.bypass_scale becomes out of range
             return self.bypass_scale
         # hardcode warmup period for bypass scale
-        warmup_period = 4000.0
-        initial_clamp_min = 0.5
-        final_clamp_min = 0.2
+        warmup_period = 20000.0
+        initial_clamp_min = 1.0
+        final_clamp_min = 0.3
         if self.batch_count > warmup_period:
             clamp_min = final_clamp_min
         else:

From 730e6c89146352ce866a001aa36ad681a62ba8df Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 31 Oct 2022 13:47:26 +0800
Subject: [PATCH 2/5] Change schedule after initial loss not promising

---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 28a6980ea..0bd90729c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -372,8 +372,8 @@ class ZipformerEncoderLayer(nn.Module):
             return self.bypass_scale
         # hardcode warmup period for bypass scale
         warmup_period = 20000.0
-        initial_clamp_min = 1.0
-        final_clamp_min = 0.3
+        initial_clamp_min = 0.75
+        final_clamp_min = 0.25
         if self.batch_count > warmup_period:
             clamp_min = final_clamp_min
         else:
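
The two patches above only retune the constants of the bypass-scale warm-up; the linear interpolation itself is untouched (its tail is still visible as context in PATCH 5/5 below). A minimal standalone sketch of that schedule with the PATCH 2/5 values; the helper name `bypass_clamp_min` is ours, not anything defined in zipformer.py:

```python
# Sketch only: the clamp_min schedule implied by the hunks above.
def bypass_clamp_min(batch_count: float,
                     warmup_period: float = 20000.0,
                     initial_clamp_min: float = 0.75,
                     final_clamp_min: float = 0.25) -> float:
    """Lower clamp for bypass_scale, linearly interpolated during warm-up."""
    if batch_count > warmup_period:
        return final_clamp_min
    return initial_clamp_min - (batch_count / warmup_period) * (
        initial_clamp_min - final_clamp_min
    )

# bypass_clamp_min(0.0) == 0.75, bypass_clamp_min(10000.0) == 0.5,
# and anything past 20000 batches stays clamped at 0.25.
```
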
From 5fda800b6d2d80ccb138b72c89178a10a4c818e7 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 31 Oct 2022 15:49:18 +0800
Subject: [PATCH 3/5] Implement pooling module, add it after initial
 feedforward.

---
 .../pruned_transducer_stateless7/zipformer.py | 43 +++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 0bd90729c..7e8bea503 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -330,6 +330,8 @@ class ZipformerEncoderLayer(nn.Module):
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
 
+        self.pooling = PoolingModule(d_model)
+
         self.feed_forward1 = FeedforwardModule(d_model,
                                                feedforward_dim, dropout)
 
@@ -410,6 +412,10 @@ class ZipformerEncoderLayer(nn.Module):
         # macaron style feed forward module
         src = src + self.feed_forward1(src)
 
+        # pooling module
+        src = src + self.pooling(src,
+                                 key_padding_mask=src_key_padding_mask)
+
         # multi-headed self-attention module
         src_att, attn_weights = self.self_attn(
             src,
@@ -1384,6 +1390,43 @@ class RelPositionMultiheadAttention(nn.Module):
             logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}")
 
 
+
+class PoolingModule(nn.Module):
+    """
+    Averages the input over the time dimension and project with a square matrix.
+    """
+    def __init__(self,
+                 d_model: int):
+        super().__init__()
+        self.proj = ScaledLinear(d_model, d_model,
+                                 initial_scale=0.1, bias=False)
+
+    def forward(self,
+                x: Tensor,
+                key_padding_mask):
+        """
+        Args:
+           x: a Tensor of shape (T, N, C)
+           key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked
+              positions.
+        Returns:
+           a Tensor of shape (1, N, C)
+        """
+        if key_padding_mask is not None:
+            pooling_mask = key_padding_mask.logical_not().to(src.dtype)  # (N, T)
+            pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
+            pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
+            # now pooling_mask: (T, N, 1)
+        else:
+            num_frames = x.shape[0]
+            pooling_mask = 1.0 / num_frames
+
+        x = (x * pooling_mask).sum(dim=0, keepdim=True)
+        x = self.proj(x)
+        return x
+
+
 class FeedforwardModule(nn.Module):
     """Feedforward module in Zipformer model.
     """

From 3de8a5aef2b054f53d05f270c7f616d9c62356e2 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 31 Oct 2022 15:50:46 +0800
Subject: [PATCH 4/5] Bug fix

---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 7e8bea503..aa6e3d1ae 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1414,7 +1414,7 @@ class PoolingModule(nn.Module):
            a Tensor of shape (1, N, C)
         """
         if key_padding_mask is not None:
-            pooling_mask = key_padding_mask.logical_not().to(src.dtype)  # (N, T)
+            pooling_mask = key_padding_mask.logical_not().to(x.dtype)  # (N, T)
             pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
             pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
             # now pooling_mask: (T, N, 1)
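
PATCH 3/5 adds a per-utterance pooling branch right after the first feed-forward module, and PATCH 4/5 fixes its one bug (the mask was cast with `src.dtype` although the argument is named `x`). Below is a small shape-check sketch of the masked mean pooling, written against plain PyTorch; `nn.Linear` stands in for the repo's `ScaledLinear` and the sizes are arbitrary:

```python
import torch
import torch.nn as nn

T, N, C = 100, 8, 384                        # frames, batch, d_model
x = torch.randn(T, N, C)
key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
key_padding_mask[:, 80:] = True              # last 20 frames are padding

# weights: 1/num_valid_frames on valid frames, 0 on padded ones
pooling_mask = key_padding_mask.logical_not().to(x.dtype)               # (N, T)
pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True)
pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)  # (T, N, 1)

proj = nn.Linear(C, C, bias=False)           # stand-in for ScaledLinear
pooled = proj((x * pooling_mask).sum(dim=0, keepdim=True))
assert pooled.shape == (1, N, C)
```

The (1, N, C) output is broadcast-added to the (T, N, C) residual stream in `forward()`, so every frame of an utterance receives the same pooled summary.
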
From 12f17f550e6054533d551f32b40da0b3c5c4af61 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 31 Oct 2022 16:18:52 +0800
Subject: [PATCH 5/5] Introduce dropout rate to dynamic submodules of
 conformer.

---
 .../pruned_transducer_stateless7/zipformer.py | 48 ++++++++++++++-----
 1 file changed, 36 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index aa6e3d1ae..05a6ea933 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -383,6 +383,21 @@ class ZipformerEncoderLayer(nn.Module):
                          (self.batch_count / warmup_period) *
                          (initial_clamp_min - final_clamp_min))
         return self.bypass_scale.clamp(min=clamp_min, max=1.0)
 
+    def get_dynamic_dropout_rate(self):
+        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
+        # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable
+        # at the beginning, by making the network focus on the feedforward modules.
+        if torch.jit.is_scripting() or not self.training:
+            return 0.0
+        warmup_period = 2000.0
+        initial_dropout_rate = 0.2
+        final_dropout_rate = 0.0
+        if self.batch_count > warmup_period:
+            return final_dropout_rate
+        else:
+            return (initial_dropout_rate -
+                    (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period))
+
     def forward(
         self,
         src: Tensor,
@@ -412,28 +427,37 @@ class ZipformerEncoderLayer(nn.Module):
         # macaron style feed forward module
         src = src + self.feed_forward1(src)
 
+        # dropout rate for submodules that interact with time.
+        dynamic_dropout = self.get_dynamic_dropout_rate()
+
         # pooling module
-        src = src + self.pooling(src,
-                                 key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.pooling(src,
+                                     key_padding_mask=src_key_padding_mask)
 
         # multi-headed self-attention module
-        src_att, attn_weights = self.self_attn(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
-        src = src + src_att
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
+            src_att, attn_weights = self.self_attn(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+            src = src + src_att
 
         # convolution module
-        src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward2(src)
 
-        src = src + self.self_attn.forward2(src, attn_weights)
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.self_attn.forward2(src, attn_weights)
 
-        src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward3(src)
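
PATCH 5/5 makes the layer randomly skip its "dynamic" submodules (pooling, self-attention, convolution) early in training, so the network first settles on the feed-forward paths. One schedule detail worth noting: with `final_dropout_rate = 0.0`, the returned expression `initial_dropout_rate - (initial_dropout_rate * final_dropout_rate) * (batch_count / warmup_period)` is a constant 0.2 throughout the warm-up and 0.0 afterwards, i.e. a step rather than the ramp the comment describes; a ramp would subtract `(initial_dropout_rate - final_dropout_rate)` instead. A standalone sketch of the skip pattern under that linear-ramp reading (our assumption, not the committed code):

```python
import random

def dynamic_dropout_rate(batch_count: float,
                         warmup_period: float = 2000.0,
                         initial: float = 0.2,
                         final: float = 0.0) -> float:
    # linear ramp from `initial` down to `final` over the warm-up period
    if batch_count > warmup_period:
        return final
    return initial - (initial - final) * (batch_count / warmup_period)

def keep(rate: float) -> bool:
    # one independent draw per dynamic branch, as in the patched forward()
    return random.random() > rate

rate = dynamic_dropout_rate(500.0)   # ~0.15, a quarter of the way through warm-up
use_self_attn = keep(rate)
# A single `use_self_attn` draw gates both the self_attn call and
# self_attn.forward2(), so attn_weights is only consumed when it was computed.
```
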