Increase final conv_skip_rate from 0.0 to 0.01

Daniel Povey 2022-12-31 15:10:52 +08:00
parent 577c3ad390
commit c533c30442


@@ -403,8 +403,9 @@ class ZipformerEncoderLayer(nn.Module):
         # to work correctly.
         layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
         attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
-        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
+        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.01), default=0),
         const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
+        ff2_skip_rate: FloatLike = 0.01,
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
         bypass_max: FloatLike = 1.0,
     ) -> None:
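
Judging by the (batch_count, value) pairs in the signature above, ScheduledFloat acts as a piecewise-linear function of the training batch count, with `default` returned when no batch count is available (e.g. at inference or under scripting). A minimal sketch of such a schedule; the class name and batch_count attribute are hypothetical stand-ins, the real ScheduledFloat lives in icefall's scaling.py:

    class PiecewiseLinearSchedule:
        """Hypothetical stand-in for ScheduledFloat: a piecewise-linear
        function of the training batch count; `default` is returned when
        no batch count has been set."""
        def __init__(self, *points, default=0.0):
            self.points = sorted(points)   # (batch_count, value) pairs
            self.default = default
            self.batch_count = None        # the training loop would set this

        def __float__(self):
            if self.batch_count is None:
                return float(self.default)
            x = self.batch_count
            x0, y0 = self.points[0]
            if x <= x0:
                return float(y0)
            for x1, y1 in self.points[1:]:
                if x <= x1:
                    # linear interpolation between the surrounding points
                    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
                x0, y0 = x1, y1
            return float(y0)               # hold the last value afterwards

    # The new conv schedule: 0.2 at batch 0, 0.05 at 4000, then 0.01 at
    # 16000 and beyond (previously it hit exactly 0.0 at 16000).
    conv_skip_rate = PiecewiseLinearSchedule(
        (0.0, 0.2), (4000.0, 0.05), (16000, 0.01), default=0
    )
    conv_skip_rate.batch_count = 20000
    print(float(conv_skip_rate))  # -> 0.01

Under this reading, the commit changes only the tail of the conv schedule: the skip rate now bottoms out at 0.01 rather than reaching exactly 0.0, so a small fraction of batches keep skipping the ConvModule even late in training.
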
@@ -418,6 +419,9 @@ class ZipformerEncoderLayer(nn.Module):
         # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
         self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+        # compared to its residual.
+        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
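
The trailing context lines mention clamping self.bypass_scale to [bypass_min, bypass_max] only with probability 0.5: when the clamp is active and the parameter sits outside the range, the clamp's gradient to the raw parameter is zero, so applying it only half the time lets the other steps still update it. A hedged sketch of that idea; the helper name and the way the scale combines the layer's input and output are assumptions, not the actual module code:

    import random
    import torch

    def apply_bypass(src_orig: torch.Tensor,
                     src: torch.Tensor,
                     bypass_scale: torch.Tensor,
                     bypass_min: float,
                     bypass_max: float,
                     training: bool = True) -> torch.Tensor:
        # Illustrative only: clamp the learned scale on roughly half of
        # training steps; on the un-clamped steps the raw parameter still
        # receives a nonzero gradient even if its value lies outside
        # [bypass_min, bypass_max].
        scale = bypass_scale
        if training and random.random() < 0.5:
            scale = scale.clamp(min=bypass_min, max=bypass_max)
        # interpolate between the layer input and the layer output
        return src_orig + scale * (src - src_orig)
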
@@ -608,7 +612,8 @@ class ZipformerEncoderLayer(nn.Module):
         if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
-        src = src + self.balancer_ff2(self.feed_forward2(src))
+        if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
+            src = src + self.balancer_ff2(self.feed_forward2(src))
         src = self.balancer1(src)
         src = self.norm(src)
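
The pattern above, running the residual branch only when torch.jit.is_scripting() is true or a uniform draw clears the skip rate, is the same module-level skip already applied to attention and conv. A minimal standalone sketch; maybe_add is a hypothetical helper, not from the repo:

    import random
    import torch
    import torch.nn as nn

    def maybe_add(src: torch.Tensor, module: nn.Module, skip_rate: float) -> torch.Tensor:
        # In scripted/inference graphs the branch always runs; during
        # training it is skipped outright with probability skip_rate.
        # Unlike dropout, the surviving outputs are not rescaled by
        # 1 / (1 - skip_rate); the whole residual branch is simply bypassed.
        if torch.jit.is_scripting() or random.random() >= skip_rate:
            return src + module(src)
        return src

With this commit, ff2 gets the same treatment at a constant rate of 0.01, while conv's scheduled rate now bottoms out at 0.01 instead of 0.0.
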