From 3ef2a1d81e869e6105049fe8fa44881411d5b642 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 20 Dec 2022 22:21:56 +0800
Subject: [PATCH] Make some of the layer-skipping logic be done per sequence.

---
 .../pruned_transducer_stateless7/zipformer.py | 35 +++++++++++++------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6d078345c..fd67256c8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -336,8 +336,14 @@ class Zipformer(EncoderInterface):
             # this how we implement U-net-like skipping of some series of
             # stacks.  The layer_skip_dropout_prob is to discourage it, especially
             # early in training, from completely ignoring the middle layers.
-            if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
-                x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+            batch_size = x.shape[1]
+            skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+            if self.training:
+                skip_mask = (torch.rand((1, batch_size, 1), device=x.device) >
+                             float(self.layer_skip_dropout_prob))
+                x = torch.where(skip_mask, skip_x, x)
+            else:
+                x = skip_x
             x = module(x,
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
@@ -461,13 +467,24 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None
 
-    def get_bypass_scale(self):
+    def get_bypass_scale(self, batch_size: int):
+        # returns bypass-scale of shape (num_channels,)
+        # or (batch_size, num_channels).  This is actually the
+        # scale on the delta (src - src_orig), so 0 corresponds to bypassing
+        # this module.
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
         else:
-            return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_min),
-                                     max=float(self.bypass_max))
+            ans = limit_param_value(self.bypass_scale,
+                                    min=float(self.bypass_min),
+                                    max=float(self.bypass_max))
+            layer_skip_rate = float(self.layer_skip_rate)
+            if layer_skip_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate
+                ans = ans * mask
+            # if layer_skip_rate != 0.0, ans now has shape (batch_size, num_channels) and
+            # is zero for sequences on which we have randomly chosen to do layer-skipping.
+            return ans
 
     def forward(
         self,
@@ -519,10 +536,6 @@ class ZipformerEncoderLayer(nn.Module):
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2
 
-        if self.training and random.random() < float(self.layer_skip_rate):
-            # skip the layer
-            return src, attn_weights
-
         use_self_attn = (random.random() >= attention_skip_rate)
         if use_self_attn:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
@@ -560,7 +573,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         delta = src - src_orig
 
-        src = src_orig + delta * self.get_bypass_scale()
+        src = src_orig + delta * self.get_bypass_scale(src.shape[1])
 
         src = self.whiten(src)
         return src, attn_weights
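
For reference, here is a minimal standalone sketch of the per-sequence skip-connection dropout
added in the first hunk: during training, each sequence in the batch either takes the output of
the skip module or keeps its current value, chosen by an independent random draw per sequence.
The (seq_len, batch_size, num_channels) layout matches the encoder's convention; the function
name and the driver code below are illustrative only, not part of the patch or of icefall.

import torch

def apply_skip_per_sequence(x: torch.Tensor,
                            skip_x: torch.Tensor,
                            skip_dropout_prob: float,
                            training: bool) -> torch.Tensor:
    # x, skip_x: (seq_len, batch_size, num_channels)
    if not training or skip_dropout_prob == 0.0:
        return skip_x
    batch_size = x.shape[1]
    # True -> keep the skip-module output; False -> ignore it for that sequence.
    keep = torch.rand(1, batch_size, 1, device=x.device) > skip_dropout_prob
    return torch.where(keep, skip_x, x)

torch.manual_seed(0)
x = torch.zeros(10, 4, 8)       # (seq_len, batch, channels)
skip_x = torch.ones(10, 4, 8)
y = apply_skip_per_sequence(x, skip_x, skip_dropout_prob=0.5, training=True)
print(y.mean(dim=(0, 2)))       # each entry is 0.0 or 1.0: whole sequences are kept or skipped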
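
Similarly, a sketch of the per-sequence bypass scale computed in get_bypass_scale(): the clamped,
learned scale on the residual delta (src - src_orig) is zeroed for a randomly chosen subset of
sequences, which bypasses the whole encoder layer for those sequences only. The clamp here is a
stand-in for limit_param_value(), and the bypass_min/bypass_max defaults are assumptions made for
illustration.

import torch

def per_sequence_bypass_scale(bypass_scale: torch.Tensor,
                              batch_size: int,
                              layer_skip_rate: float,
                              bypass_min: float = 0.5,
                              bypass_max: float = 1.0) -> torch.Tensor:
    # bypass_scale: (num_channels,); returns (num_channels,) or (batch_size, num_channels)
    ans = bypass_scale.clamp(min=bypass_min, max=bypass_max)  # stand-in for limit_param_value()
    if layer_skip_rate != 0.0:
        # rows that come out as 0 drop the whole delta, i.e. the layer is bypassed
        mask = torch.rand(batch_size, 1, device=ans.device) > layer_skip_rate
        ans = ans * mask
    return ans

torch.manual_seed(0)
scale = per_sequence_bypass_scale(torch.full((8,), 0.8), batch_size=4, layer_skip_rate=0.5)
src_orig = torch.randn(10, 4, 8)            # (seq_len, batch, channels)
delta = torch.randn(10, 4, 8)
src = src_orig + delta * scale              # (batch, channels) broadcasts over seq_len
print(scale[:, 0])                          # 0.0 entries mean that sequence bypasses the layer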