Reduce warmup scale on pruned loss from 0.1 to 0.01.

This commit is contained in:
Daniel Povey 2022-03-17 16:46:59 +08:00
parent acc0eda5b0
commit cbe6b175d1

View File

@ -452,7 +452,7 @@ def compute_loss(
) )
loss = params.simple_loss_scale * simple_loss loss = params.simple_loss_scale * simple_loss
if not warmup_mode: if not warmup_mode:
loss = loss + pruned_loss * (0.1 if warmup_mode else 1.0) loss = loss + (pruned_loss * 0.01 if warmup_mode else pruned_loss)
assert loss.requires_grad == is_training assert loss.requires_grad == is_training