diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 876d0a0c3..105a63943 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -338,8 +338,14 @@ class Zipformer(EncoderInterface):
             # this how we implement U-net-like skipping of some series of
             # stacks.  The layer_skip_dropout_prob is to discourage it, especially
             # early in training, from completely ignoring the middle layers.
-            if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
-                x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+            batch_size = x.shape[1]
+            skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+            if self.training:
+                skip_mask = (torch.rand((1, batch_size, 1), device=x.device) >
+                             float(self.layer_skip_dropout_prob))
+                x = torch.where(skip_mask, skip_x, x)
+            else:
+                x = skip_x
             x = module(x, feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])

@@ -463,13 +469,24 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None

-    def get_bypass_scale(self):
+    def get_bypass_scale(self, batch_size: int):
+        # Returns the bypass scale, of shape (num_channels,) or
+        # (batch_size, num_channels).  This is actually the scale on the
+        # delta (src - src_orig), so 0 corresponds to bypassing this
+        # module.
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
         else:
-            return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_min),
-                                     max=float(self.bypass_max))
+            ans = limit_param_value(self.bypass_scale,
+                                    min=float(self.bypass_min),
+                                    max=float(self.bypass_max))
+            layer_skip_rate = float(self.layer_skip_rate)
+            if layer_skip_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate
+                ans = ans * mask
+                # now ans has shape (batch_size, num_channels), and is zero for
+                # sequences on which we have randomly chosen to do layer-skipping.
+            return ans

     def forward(
         self,
@@ -521,10 +538,6 @@ class ZipformerEncoderLayer(nn.Module):
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2

-        if self.training and random.random() < float(self.layer_skip_rate):
-            # skip the layer
-            return src, attn_weights
-
         use_self_attn = (random.random() >= attention_skip_rate)
         if use_self_attn:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
@@ -564,7 +577,7 @@ class ZipformerEncoderLayer(nn.Module):

         delta = src - src_orig

-        src = src_orig + delta * self.get_bypass_scale()
+        src = src_orig + delta * self.get_bypass_scale(src.shape[1])
         src = self.whiten(src)

         return src, attn_weights
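
Note on the pattern this patch introduces: the old code made one whole-batch `random.random()` decision per skip connection (and per layer), whereas the new code draws a boolean mask and selects with `torch.where`, so the skip/bypass decision is made independently for each sequence in the batch. Since the encoder tensors here are laid out as `(seq_len, batch_size, num_channels)`, the mask in the first hunk is shaped to broadcast over the batch dimension, and `get_bypass_scale` receives `src.shape[1]` as the batch size. Below is a minimal standalone sketch of that per-sequence skipping pattern; the helper name `sequence_level_skip` and the toy shapes are illustrative only and are not part of the patch.

```python
import torch

def sequence_level_skip(x: torch.Tensor, new_x: torch.Tensor,
                        skip_prob: float, training: bool) -> torch.Tensor:
    """Keep `new_x` for a random subset of sequences, fall back to `x` otherwise.

    x, new_x: tensors of shape (seq_len, batch_size, num_channels).
    skip_prob: per-sequence probability of skipping (i.e. keeping `x`).
    """
    if not training:
        # At inference time the skip is always applied.
        return new_x
    batch_size = x.shape[1]
    # Boolean mask of shape (1, batch_size, 1); it broadcasts over seq_len and
    # channels, so each sequence is either fully kept or fully skipped.
    keep = torch.rand(1, batch_size, 1, device=x.device) >= skip_prob
    return torch.where(keep, new_x, x)

if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.zeros(5, 4, 3)      # (seq_len, batch, channels)
    new_x = torch.ones(5, 4, 3)
    out = sequence_level_skip(x, new_x, skip_prob=0.5, training=True)
    # Each column (sequence) is either all zeros (skipped) or all ones (kept).
    print(out[:, :, 0])
```

A plausible reason for this change (my reading, not stated in the patch itself) is that a masked `torch.where` keeps the computation graph identical from batch to batch and moves the randomness from a per-batch to a per-sequence granularity, rather than having an entire minibatch either use or ignore a given layer or skip connection.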