Bug fix to warmup_scale

Daniel Povey 2022-03-31 17:30:51 +08:00
parent 49bc761ba1
commit 8caa18e2fe


@@ -221,7 +221,7 @@ class ConformerEncoderLayer(nn.Module):
 warmup_scale = min(0.1 + warmup, 1.0)
 # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely
 # bypass it.
-alpha = 0.1 if torch.rand(()).item() <= 0.9 else warmup_scale
+alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1
 # macaron style feed forward module
 src = src + self.dropout(self.feed_forward_macaron(src))
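For context, a minimal sketch of what the corrected line does. Before the fix, alpha was 0.1 about 90% of the time and warmup_scale only about 10% of the time; after the fix, the layer is scaled by warmup_scale about 90% of the time, with 0.1 as the occasional near-bypass. The helper names choose_alpha and blend below are hypothetical, not part of this commit, and the blend step is an assumption based on the comment that alpha = 1.0 fully uses the layer and 0.0 bypasses it.

import torch

def choose_alpha(warmup: float) -> float:
    # Sketch of the corrected selection, assuming warmup ramps from 0.0
    # towards 1.0 over the course of training.
    warmup_scale = min(0.1 + warmup, 1.0)
    # After the fix: use the warmup-scaled value ~90% of the time and
    # fall back to 0.1 (mostly bypassing the layer) ~10% of the time.
    return warmup_scale if torch.rand(()).item() <= 0.9 else 0.1

def blend(src_orig: torch.Tensor, layer_out: torch.Tensor, alpha: float) -> torch.Tensor:
    # Hypothetical illustration of the bypass: alpha = 1.0 keeps the layer
    # output, alpha = 0.0 would return the layer input unchanged.
    return alpha * layer_out + (1.0 - alpha) * src_orig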