diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f0a52f605..0c0b6ae21 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -399,7 +399,7 @@ class ZipformerEncoderLayer(nn.Module):
         # to work correctly.
         layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
         dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
-        nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
+        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
         const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
         bypass_max: FloatLike = 1.0,
@@ -411,9 +411,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward).
         self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
-        # an additional skip probability that applies to NoninAttentionModule to stop it from
+        # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
-        self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate)
+        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
 
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
@@ -541,7 +541,7 @@ class ZipformerEncoderLayer(nn.Module):
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                 selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
-        if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)):
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src, selected_attn_weights[0:1])
@@ -555,7 +555,7 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.self_attn(
             src, attn_weights)
 
-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate + float(self.conv_skip_rate):
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward2(src)
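
Note: the sketch below illustrates the skip behavior this diff changes, assuming icefall's ScheduledFloat interpolates its value piecewise-linearly against the global batch count (with float(...) reading the current value); scheduled_float and should_run_conv_module are hypothetical stand-ins for illustration, not code from the repo.

```python
import random
from typing import Tuple


def scheduled_float(step: float, *points: Tuple[float, float], default: float = 0.0) -> float:
    # Piecewise-linear interpolation through (step, value) breakpoints,
    # clamped to the first/last value outside their range. A stand-in for
    # icefall's ScheduledFloat, which tracks the global batch count itself.
    pts = sorted(points)
    if not pts:
        return default
    if step <= pts[0][0]:
        return pts[0][1]
    if step >= pts[-1][0]:
        return pts[-1][1]
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        if x0 <= step <= x1:
            return y0 + (y1 - y0) * (step - x0) / (x1 - x0)
    return default


def should_run_conv_module(step: float) -> bool:
    # After this diff, a single random draw is compared against the *sum* of
    # the two rates, so conv_skip_rate adds extra skip probability for the
    # conv module on top of dynamic_skip_rate early in training.
    dynamic_skip_rate = scheduled_float(step, (0.0, 0.2), (4000.0, 0.0))
    conv_skip_rate = scheduled_float(step, (0.0, 0.1), (20000.0, 0.0))
    return random.random() >= dynamic_skip_rate + conv_skip_rate
```

Under these schedules the conv module would be skipped about 30% of the time at step 0 (0.2 + 0.1); the dynamic term decays to zero by step 4000 and the conv term by step 20000, after which the module always runs.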