This commit is contained in:
Daniel Povey 2022-03-25 20:35:11 +08:00
parent 4b650e9f01
commit d2ed3dfc90

View File

@ -506,8 +506,8 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = (0.0 if warmup < 1.0 else pruned_loss_scale = (0.0 if warmup < 1.0 else
(0.1 if warmup > 1.0 and warmup < 2.0) else (0.1 if warmup > 1.0 and warmup < 2.0 else
1.0) 1.0))
loss = (params.simple_loss_scale * simple_loss + loss = (params.simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss) pruned_loss_scale * pruned_loss)