Limit bypass scale to >= 0.1

This commit is contained in:
Daniel Povey 2022-10-08 21:37:21 +08:00
parent bc9fbe2579
commit d467338837

View File

@@ -375,6 +375,9 @@ class ConformerEncoderLayer(nn.Module):
keep_prob = 0.5 + 0.5 * warmup_value
# the :1 means the mask is chosen per frame.
delta = delta * (torch.rand_like(delta[...,:1]) < keep_prob)
bypass_scale = self.bypass_scale
if random.random() > 0.1:
bypass_scale = bypass_scale.clamp(min=0.1)
src = src_orig + delta * self.bypass_scale
return src, attn_scores_out