Scale down pruned loss in warmup mode

This commit is contained in:
Daniel Povey 2022-03-17 16:09:35 +08:00
parent 13db33ffa2
commit acc0eda5b0

View File

@ -450,7 +450,9 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup_mode=warmup_mode,
)
loss = params.simple_loss_scale * simple_loss + pruned_loss
loss = params.simple_loss_scale * simple_loss
if not warmup_mode:
loss = loss + pruned_loss * (0.1 if warmup_mode else 1.0)
assert loss.requires_grad == is_training