This commit is contained in:
Daniel Povey 2022-03-25 20:35:11 +08:00
parent 4b650e9f01
commit d2ed3dfc90

View File

@ -506,8 +506,8 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (0.0 if warmup < 1.0 else
(0.1 if warmup > 1.0 and warmup < 2.0) else
1.0)
(0.1 if warmup > 1.0 and warmup < 2.0 else
1.0))
loss = (params.simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss)