Change to warmup schedule.

This commit is contained in:
Daniel Povey 2022-10-25 12:27:00 +08:00
parent 36cb279318
commit 1e8984174b

View File

@ -666,10 +666,14 @@ def compute_loss(
s if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = (
simple_loss_scale * simple_loss
+ pruned_loss
simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss
)
assert loss.requires_grad == is_training