Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 16:14:17 +00:00
Exclude loss calculation from mixed precision training

commit 29d186118f (parent 2ff81b2838)
@@ -128,9 +128,10 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
-            symbols=y_padded,
-            termination_symbol=blank_id,
-            lm_only_scale=lm_scale,
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=decoder_out.float(),
+                am=encoder_out.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
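This hunk wraps k2.rnnt_loss_smoothed in an autocast-disabled region and casts its inputs to float32, so the simple loss is computed in full precision even when the surrounding forward pass runs under mixed precision. Below is a minimal, self-contained sketch of the same pattern; ToyModel, the tensor shapes, and the cross-entropy loss are placeholders standing in for icefall's Transducer and the k2 loss, not the repository's actual code.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Hypothetical stand-in for the transducer encoder/decoder/joiner.
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(80, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 autocast on GPU; CPU autocast prefers bfloat16.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = ToyModel().to(device)
x = torch.randn(4, 80, device=device)
target = torch.randint(0, 10, (4,), device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    logits = model(x)  # may come out in reduced precision under autocast
    # As in the diff: disable autocast for the loss and cast back to float32.
    with torch.autocast(device_type=device, enabled=False):
        loss = nn.functional.cross_entropy(logits.float(), target)

print(loss.dtype)  # torch.float32, regardless of the autocast dtype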
@@ -157,8 +158,9 @@ class Transducer(nn.Module):
         # logits : [B, T, prune_range, C]
         logits = self.joiner(am_pruned, lm_pruned)
 
-        pruned_loss = k2.rnnt_loss_pruned(
-            logits=logits,
-            symbols=y_padded,
-            ranges=ranges,
-            termination_symbol=blank_id,
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
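The second hunk applies the same treatment to k2.rnnt_loss_pruned. For context, here is a hedged sketch of how such a forward pass typically sits inside a mixed-precision training step with a gradient scaler; the train_step signature, the batch layout, and everything other than the simple_loss/pruned_loss names are assumptions, not icefall's actual train.py.

import torch

def train_step(model, batch, optimizer, scaler, device_type="cuda"):
    # Assumed batch layout; icefall's Transducer.forward takes more
    # arguments (x, x_lens, y, prune_range, am_scale, lm_scale).
    x, y = batch
    optimizer.zero_grad()
    with torch.autocast(device_type=device_type):
        # Inside forward, the k2 losses are excluded from autocast and
        # computed on .float() inputs, as this commit arranges.
        simple_loss, pruned_loss = model(x, y)
        loss = simple_loss + pruned_loss
    scaler.scale(loss).backward()  # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()
    return loss.detach()

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

Casting only the loss inputs with .float() keeps the bulk of the network in reduced precision for speed, while the numerically sensitive k2 loss kernels run in float32; GradScaler still guards the fp16 portions of the backward pass.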