From 67fcae95a89951ee03028f6d3318e703c7072da0 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 5 Apr 2023 14:44:16 +0800
Subject: [PATCH] Refactor bypass, and add bypass in the middle of the layer.

---
 .../pruned_transducer_stateless7/zipformer.py | 90 ++++++++++++++-----
 1 file changed, 69 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index a549da5a6..0328b20ed 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -207,7 +207,7 @@ class Zipformer2(EncoderInterface):
                 dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
-                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                final_layerdrop_rate=0.02 * (downsampling_factor[i] ** 0.5),
             )

             if downsampling_factor[i] != 1:
@@ -519,23 +519,21 @@ class Zipformer2EncoderLayer(nn.Module):
         dropout: FloatLike = 0.1,
         cnn_module_kernel: int = 31,
         causal: bool = False,
-        # layer_skip_rate will be overwritten to change warmup begin and end times.
-        # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
-        # to work correctly.
-        layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
         attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
         conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
         const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
         ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
         ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
-        bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
-        bypass_max: FloatLike = 1.0,
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim

-        # probability of skipping the entire layer.
-        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
+        # self.bypass implements layer skipping as well as bypass; see its default values.
+        self.bypass = BypassModule(embed_dim)
+        # bypass_mid is bypass used in the middle of the layer.
+        self.bypass_mid = BypassModule(embed_dim)
+
+        # skip probability for dynamic modules (meaning: anything but feedforward).
         self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
         # an additional skip probability that applies to ConvModule to stop it from
@@ -547,10 +545,6 @@ class Zipformer2EncoderLayer(nn.Module):
         self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
         self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

-        # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
-        # ever becoming zero.
-        self.bypass_min = copy.deepcopy(bypass_min)
-        self.bypass_max = copy.deepcopy(bypass_max)
         self.const_attention_rate = copy.deepcopy(const_attention_rate)

         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
@@ -751,6 +745,9 @@ class Zipformer2EncoderLayer(nn.Module):
         src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
                                           float(self.ff2_skip_rate))

+        # bypass in the middle of the layer.
+        src = self.bypass_mid(src_orig, src)
+
         self_attn = self.self_attn2(
             src, attn_weights)

@@ -769,10 +766,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src = self.balancer1(src)
         src = self.norm(src)

-        bypass_scale = self.get_bypass_scale(src.shape[1])
-        # the next line equivalent to: src = src * bypass_scale + src_orig *
-        # (1.0 - bypass_scale), but more memory efficient for backprop.
-        src = src_orig + (src - src_orig) * bypass_scale
+        src = self.bypass(src_orig, src)

         src = self.balancer2(src)
         src = self.whiten(src)
@@ -819,10 +813,12 @@ class Zipformer2Encoder(nn.Module):
         cur_begin = warmup_begin  # interpreted as a training batch index
         for i in range(num_layers):
             cur_end = cur_begin + delta
-            # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
-                                                            (cur_end, final_layerdrop_rate),
-                                                            default=0.0)
+            self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                             (cur_end, final_layerdrop_rate),
+                                                             default=0.0)
+            self.layers[i].bypass_mid.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                                 (cur_end, final_layerdrop_rate),
+                                                                 default=0.0)
             cur_begin = cur_end

     def forward(
@@ -869,6 +865,58 @@ class Zipformer2Encoder(nn.Module):
         return output


+class BypassModule(nn.Module):
+    """
+    An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+    layer-skipping.  The bypass is limited during early stages of training to be close to
+    "straight-through", i.e. to not do the bypass operation much initially, in order to
+    force all the modules to learn something.
+    """
+    def __init__(
+            self,
+            embed_dim: int,
+            skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
+            scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+            scale_max: FloatLike = 1.0):
+        super().__init__()
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+        self.skip_rate = copy.deepcopy(skip_rate)
+        self.scale_min = copy.deepcopy(scale_min)
+        self.scale_max = copy.deepcopy(scale_max)
+
+    def _get_bypass_scale(self, batch_size: int):
+        # returns bypass-scale of shape (num_channels,), or (batch_size, num_channels).
+        # This is actually the scale on the non-residual term, so 0 corresponds to
+        # bypassing this module.
+        if torch.jit.is_scripting() or not self.training:
+            return self.bypass_scale
+        else:
+            ans = limit_param_value(self.bypass_scale,
+                                    min=float(self.scale_min),
+                                    max=float(self.scale_max))
+            skip_rate = float(self.skip_rate)
+            if skip_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+                ans = ans * mask
+                # now ans is of shape (batch_size, num_channels), and is zero for sequences
+                # on which we have randomly chosen to do layer-skipping.
+            return ans
+
+    def forward(self,
+                src_orig: Tensor,
+                src: Tensor):
+        """
+        Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+        Returns: a Tensor with the same shape as src and src_orig
+        """
+        bypass_scale = self._get_bypass_scale(src.shape[1])
+        return src_orig + (src - src_orig) * bypass_scale
+
+
+
 class DownsampledZipformer2Encoder(nn.Module):
     r"""
     DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
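
Note (editorial, not part of the patch): the standalone sketch below illustrates the arithmetic that BypassModule performs -- the combination src_orig + (src - src_orig) * scale plus the random per-sequence skip mask. The function name demo_bypass and the fixed skip_rate are assumptions made for illustration; the real module draws its skip rate and scale limits from ScheduledFloat schedules and limit_param_value, which are omitted here.

    # Standalone sketch (assumed names; not taken from the patch): shows the bypass
    # combination and per-sequence layer-skipping without the scheduling machinery.
    import torch


    def demo_bypass(src_orig: torch.Tensor,
                    src: torch.Tensor,
                    scale: torch.Tensor,
                    skip_rate: float = 0.0) -> torch.Tensor:
        # scale has shape (num_channels,); 0 means "fully bypass" (output == src_orig),
        # 1 means "fully pass through" (output == src).
        batch_size = src.shape[1]
        if skip_rate > 0.0:
            # zero the scale for randomly chosen sequences so they skip the layer;
            # broadcasting gives scale shape (batch_size, num_channels).
            keep = (torch.rand(batch_size, 1, device=src.device) > skip_rate).to(scale.dtype)
            scale = scale * keep
        # equivalent to src * scale + src_orig * (1 - scale), but cheaper for backprop.
        return src_orig + (src - src_orig) * scale


    seq_len, batch_size, num_channels = 10, 4, 8
    src_orig = torch.randn(seq_len, batch_size, num_channels)
    src = torch.randn(seq_len, batch_size, num_channels)
    scale = torch.full((num_channels,), 0.5)  # mirrors the 0.5 initialization of bypass_scale
    out = demo_bypass(src_orig, src, scale, skip_rate=0.5)
    print(out.shape)  # torch.Size([10, 4, 8])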