Reduce initial pruned_loss scale from 0.01 to 0.0

Daniel Povey 2022-03-22 12:30:48 +08:00
parent b7e84d5d77
commit b82a505dfc


@@ -496,7 +496,7 @@ def compute_loss(
         warmup_mode=warmup_mode,
     )
     loss = (params.simple_loss_scale * simple_loss +
-            (pruned_loss * 0.01 if warmup_mode else pruned_loss))
+            (pruned_loss * 0.0 if warmup_mode else pruned_loss))
     assert loss.requires_grad == is_training
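
For context, a minimal sketch of what this change does, not the repository's actual training loop: the tensor values and the simple_loss_scale below are placeholders. With the scale reduced to 0.0, the pruned loss is still computed during warmup, but it no longer contributes to the total loss or its gradient, so early updates are driven by the simple loss alone.

    import torch

    # Hypothetical stand-ins for the two losses returned by the model's
    # forward pass (in icefall they come from k2's pruned RNN-T loss).
    simple_loss = torch.tensor(3.2, requires_grad=True)
    pruned_loss = torch.tensor(7.8, requires_grad=True)

    simple_loss_scale = 0.5   # placeholder for params.simple_loss_scale
    warmup_mode = True        # True during the initial warmup phase

    # Before this commit: pruned_loss * 0.01 in warmup_mode.
    # After this commit:  pruned_loss * 0.0, so warmup updates are driven
    # solely by the (scaled) simple loss.
    loss = (simple_loss_scale * simple_loss +
            (pruned_loss * 0.0 if warmup_mode else pruned_loss))

    loss.backward()
    print(pruned_loss.grad)   # tensor(0.) -- no gradient from the pruned loss
    print(simple_loss.grad)   # tensor(0.5000) -- gradient from the simple loss only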