diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py index ae589813d..d74a65763 100644 --- a/egs/libriheavy/LM/zipformer1/subformer.py +++ b/egs/libriheavy/LM/zipformer1/subformer.py @@ -353,8 +353,7 @@ class SubformerEncoderLayer(nn.Module): causal: bool = False, memory_dim: int = -1, attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), + const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.0), default=0), ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), @@ -363,17 +362,15 @@ class SubformerEncoderLayer(nn.Module): self.embed_dim = embed_dim # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, - straight_through_rate=0.025) + self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate) + # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0.025) + self.bypass_mid = BypassModule(embed_dim) # skip probability for dynamic modules (meaning: anything but feedforward). self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + # ff2_skip_rate is to prevent the ff2 module from having output that's too big # compared to its residual.