From ad388890d94eaeab105374d88d0ff707735806d1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 10 Feb 2023 16:34:01 +0800
Subject: [PATCH] Make most forms of sequence dropout be separate per sequence.

---
 .../pruned_transducer_stateless7/zipformer.py | 52 ++++++++++++-------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 089b918e0..7057cc292 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -623,6 +623,26 @@ class ZipformerEncoderLayer(nn.Module):
         # on which we have randomly chosen to do layer-skipping.
         return ans
 
+    def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
+            return None
+        batch_size = x.shape[1]
+        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+        return mask
+
+
+    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+        """
+        Apply sequence-level dropout to x.
+        x shape: (seq_len, batch_size, embed_dim)
+        """
+        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+        if dropout_mask is None:
+            return x
+        else:
+            return x * dropout_mask
+
+
     def forward(
         self,
         src: Tensor,
@@ -670,37 +690,33 @@ class ZipformerEncoderLayer(nn.Module):
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2
 
-        use_self_attn = (random.random() >= attention_skip_rate)
-        if use_self_attn:
+        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+
+        if True:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
                 # encourage these modules to do something similar to an
                 # averaging-over-time operation.
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                 selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.balancer_na(self.nonlin_attention(src,
-                                                               selected_attn_weights[0:1]))
+        na = self.balancer_na(self.nonlin_attention(src,
+                                                    selected_attn_weights[0:1]))
+        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
 
         src = src + self.feed_forward1(src)
 
-        ## pooling module
-        #if torch.jit.is_scripting() or use_self_attn:
-        #    src = src + self.balancer_as(
-        #        self.attention_squeeze(src, selected_attn_weights[1:2]))
+        self_attn = self.self_attn(
+            src, attn_weights)
+        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn(
-                src, attn_weights)
+        src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
+                                                           src_key_padding_mask=src_key_padding_mask),
+                                          float(self.conv_skip_rate))
 
-        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
-            src = src + self.conv_module(src, chunk_size=chunk_size,
-                                         src_key_padding_mask=src_key_padding_mask)
-
-        if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
-            src = src + self.balancer_ff2(self.feed_forward2(src))
+        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
+                                          float(self.ff2_skip_rate))
 
         src = self.balancer1(src)
         src = self.norm(src)
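
The core change above is that each submodule's output is dropped per sequence rather than per batch: a (batch_size, 1) keep-mask broadcasts over a (seq_len, batch_size, embed_dim) tensor, so every sequence in the batch makes its own keep/drop decision instead of one random.random() call skipping the module for the whole batch. Below is a minimal standalone sketch of that mechanism, mirroring get_sequence_dropout_mask / sequence_dropout from the diff; the free functions and the explicit `training` flag are illustrative stand-ins for the nn.Module plumbing (self.training, torch.jit.is_scripting()), not icefall API.

# Minimal standalone sketch (illustrative, not the icefall module itself) of the
# per-sequence dropout introduced in this patch.
from typing import Optional

import torch
from torch import Tensor


def get_sequence_dropout_mask(x: Tensor, dropout_rate: float,
                              training: bool = True) -> Optional[Tensor]:
    # x: (seq_len, batch_size, embed_dim).  Returns a (batch_size, 1) tensor of
    # 0.0/1.0 keep-flags, one per sequence, or None if dropout is disabled.
    # (The patch additionally returns None under torch.jit.is_scripting().)
    if dropout_rate == 0.0 or not training:
        return None
    batch_size = x.shape[1]
    return (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)


def sequence_dropout(x: Tensor, dropout_rate: float, training: bool = True) -> Tensor:
    # Zero out entire sequences of x with probability dropout_rate.  The
    # (batch_size, 1) mask broadcasts against (seq_len, batch_size, embed_dim).
    mask = get_sequence_dropout_mask(x, dropout_rate, training)
    return x if mask is None else x * mask


if __name__ == "__main__":
    src = torch.randn(10, 4, 8)  # (seq_len, batch_size, embed_dim)
    out = sequence_dropout(src, dropout_rate=0.5)
    # Each of the 4 sequences is either kept intact or zeroed entirely:
    print((out.abs().sum(dim=(0, 2)) == 0).tolist())

In the diff itself, nonlin_attention and self_attn reuse one mask drawn at attention_skip_rate, while the conv_module and feed_forward2 branches each draw an independent mask via sequence_dropout; the mask is plain 0/1 with no 1/(1 - p) rescaling.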