Exclude loss calculation from mix precision training

This commit is contained in:
pkufool 2022-04-03 07:59:17 +08:00
parent 2ff81b2838
commit 29d186118f

View File

@ -128,17 +128,18 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
am=encoder_out,
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
with torch.autocast(device_type=x.device.type, enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out.float(),
am=encoder_out.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
@ -157,13 +158,14 @@ class Transducer(nn.Module):
# logits : [B, T, prune_range, C]
logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
with torch.autocast(device_type=x.device.type, enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)