mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
Exclude loss calculation from mix precision training
This commit is contained in:
parent
2ff81b2838
commit
29d186118f
@ -128,17 +128,18 @@ class Transducer(nn.Module):
|
|||||||
boundary[:, 2] = y_lens
|
boundary[:, 2] = y_lens
|
||||||
boundary[:, 3] = x_lens
|
boundary[:, 3] = x_lens
|
||||||
|
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
with torch.autocast(device_type=x.device.type, enabled=False):
|
||||||
lm=decoder_out,
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
am=encoder_out,
|
lm=decoder_out.float(),
|
||||||
symbols=y_padded,
|
am=encoder_out.float(),
|
||||||
termination_symbol=blank_id,
|
symbols=y_padded,
|
||||||
lm_only_scale=lm_scale,
|
termination_symbol=blank_id,
|
||||||
am_only_scale=am_scale,
|
lm_only_scale=lm_scale,
|
||||||
boundary=boundary,
|
am_only_scale=am_scale,
|
||||||
reduction="sum",
|
boundary=boundary,
|
||||||
return_grad=True,
|
reduction="sum",
|
||||||
)
|
return_grad=True,
|
||||||
|
)
|
||||||
|
|
||||||
# ranges : [B, T, prune_range]
|
# ranges : [B, T, prune_range]
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
@ -157,13 +158,14 @@ class Transducer(nn.Module):
|
|||||||
# logits : [B, T, prune_range, C]
|
# logits : [B, T, prune_range, C]
|
||||||
logits = self.joiner(am_pruned, lm_pruned)
|
logits = self.joiner(am_pruned, lm_pruned)
|
||||||
|
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
with torch.autocast(device_type=x.device.type, enabled=False):
|
||||||
logits=logits,
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
symbols=y_padded,
|
logits=logits.float(),
|
||||||
ranges=ranges,
|
symbols=y_padded,
|
||||||
termination_symbol=blank_id,
|
ranges=ranges,
|
||||||
boundary=boundary,
|
termination_symbol=blank_id,
|
||||||
reduction="sum",
|
boundary=boundary,
|
||||||
)
|
reduction="sum",
|
||||||
|
)
|
||||||
|
|
||||||
return (simple_loss, pruned_loss)
|
return (simple_loss, pruned_loss)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user