From 864ff96322352ab4e71b422fd97dc2173de3f072 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 15 Dec 2022 19:27:29 +0800
Subject: [PATCH] Remove nonlin_skip_rate, introduce conv_skip_rate.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 0c0b6ae21..0f45d0212 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -398,8 +398,8 @@ class ZipformerEncoderLayer(nn.Module):
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
-            dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
-            conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
+            attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+            conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
             bypass_max: FloatLike = 1.0,
@@ -410,7 +410,7 @@ class ZipformerEncoderLayer(nn.Module):
         # probability of skipping the entire layer.
         self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward).
-        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
+        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
         # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
         self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
@@ -507,7 +507,7 @@ class ZipformerEncoderLayer(nn.Module):
         src_orig = src
 
         # dropout rate for non-feedforward submodules
-        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
+        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
         if self.self_attn_weights is not None:
@@ -528,7 +528,7 @@ class ZipformerEncoderLayer(nn.Module):
                 # skip the layer
                 return src, attn_weights
 
-            use_self_attn = (random.random() >= dynamic_skip_rate)
+            use_self_attn = (random.random() >= attention_skip_rate)
             if use_self_attn:
                 selected_attn_weights = attn_weights[head_offset:head_offset+2]
                 if random.random() < float(self.const_attention_rate):
@@ -555,7 +555,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn(
                 src, attn_weights)
 
-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate + float(self.conv_skip_rate):
+        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
            src = src + self.conv_module(src,
                                         src_key_padding_mask=src_key_padding_mask)
 
        src = src + self.feed_forward2(src)
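
For reference (not part of the patch): a minimal sketch of how the two skip rates behave after this change. It assumes a simplified piecewise-linear stand-in for icefall's ScheduledFloat; the names scheduled_float and apply_skips below are illustrative only, not functions from the repository. The key point reflected from the diff is that conv_skip_rate is now consulted on its own, rather than being added on top of the former dynamic (now attention) skip rate.

import random


def scheduled_float(batch_count: float, points) -> float:
    # Simplified stand-in for ScheduledFloat: linear interpolation between two
    # (batch_count, value) points, clamped at the endpoints.
    (x0, y0), (x1, y1) = points
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    frac = (batch_count - x0) / (x1 - x0)
    return y0 + frac * (y1 - y0)


def apply_skips(batch_count: float, training: bool):
    # attention_skip_rate: decays from 0.2 to 0.0 over the first 4000 batches
    # (only active in training, as in the patched forward()).
    attention_skip_rate = (
        scheduled_float(batch_count, [(0.0, 0.2), (4000.0, 0.0)]) if training else 0.0
    )
    # conv_skip_rate: an independent schedule decaying from 0.2 to 0.0 over
    # 16000 batches; after this patch it no longer stacks on the attention rate.
    conv_skip_rate = scheduled_float(batch_count, [(0.0, 0.2), (16000.0, 0.0)])
    use_self_attn = random.random() >= attention_skip_rate
    use_conv = random.random() >= conv_skip_rate
    return use_self_attn, use_conv


# Example: early in training both sub-modules are skipped ~20% of the time;
# the attention gate reaches 0 by batch 4000, the conv gate by batch 16000.
print(apply_skips(batch_count=2000.0, training=True))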