Introduce dropout schedule for NonlinAttentionModule

Daniel Povey 2022-12-01 15:19:51 +08:00
parent dcf6fced40
commit 4621e924ba


@@ -388,6 +388,7 @@ class ZipformerEncoderLayer(nn.Module):
 # to work correctly.
 layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
 dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
 const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
 bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
 bypass_max: FloatLike = 1.0,
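
The new nonlin_skip_rate argument follows the same ScheduledFloat convention as the other rates: a piecewise-linear function of the training batch count, here starting at 0.1 and decaying to 0.0 by batch 20000. As a rough illustration only (a simplified stand-in, not the actual ScheduledFloat class used in this file, which tracks the batch count internally), the schedule could be evaluated like this:

import bisect

def scheduled_value(points, batch_count):
    # points: sorted (batch_index, value) pairs, e.g. [(0.0, 0.1), (20000.0, 0.0)].
    # Returns the piecewise-linear interpolation of the schedule at batch_count,
    # clamped to the end values outside the covered range.
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    if batch_count <= xs[0]:
        return ys[0]
    if batch_count >= xs[-1]:
        return ys[-1]
    i = bisect.bisect_right(xs, batch_count)
    x0, x1 = xs[i - 1], xs[i]
    y0, y1 = ys[i - 1], ys[i]
    return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

nonlin_skip_schedule = [(0.0, 0.1), (20000.0, 0.0)]
print(scheduled_value(nonlin_skip_schedule, 0))      # 0.1 at the start of training
print(scheduled_value(nonlin_skip_schedule, 10000))  # roughly 0.05 halfway through the ramp
print(scheduled_value(nonlin_skip_schedule, 30000))  # 0.0 once the schedule has ended
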
@@ -397,8 +398,12 @@ class ZipformerEncoderLayer(nn.Module):
 # probability of skipping the entire layer.
 self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
-# skip probability for dynamic modules (meaning: anything but feedforward)
+# skip probability for dynamic modules (meaning: anything but feedforward).
 self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
+# an additional skip probability that applies to NonlinAttentionModule to stop it from
+# contributing too much early on.
+self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate)
 # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
 # ever becoming zero.
 self.bypass_min = copy.deepcopy(bypass_min)
@@ -521,7 +526,7 @@ class ZipformerEncoderLayer(nn.Module):
 first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
 first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
-if torch.jit.is_scripting() or use_self_attn:
+if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)):
     src = src + self.nonlin_attention_module(src,
                                              first_attn_weights[0:1])
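
The effect of the changed condition is that, outside of torch.jit scripting, the NonlinAttentionModule's residual contribution is dropped for a given batch with probability nonlin_skip_rate (in addition to the existing use_self_attn gating); since the schedule decays to 0.0 by around batch 20000, this extra skipping disappears later in training. A minimal standalone sketch of that pattern, using an illustrative helper rather than the real ZipformerEncoderLayer internals:

import random

def apply_with_skip(src, module, skip_rate, is_scripting=False):
    # Sketch of the condition introduced above: unless we are scripting, skip the
    # module's contribution for this batch with probability skip_rate; otherwise
    # run the module and add its output to the residual stream.
    if is_scripting or random.random() >= float(skip_rate):
        return src + module(src)
    return src
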