Exclude loss calculation from mixed precision training

pkufool 2022-04-03 07:59:17 +08:00
parent 2ff81b2838
commit 29d186118f


@@ -128,17 +128,18 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
-            symbols=y_padded,
-            termination_symbol=blank_id,
-            lm_only_scale=lm_scale,
-            am_only_scale=am_scale,
-            boundary=boundary,
-            reduction="sum",
-            return_grad=True,
-        )
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=decoder_out.float(),
+                am=encoder_out.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
 
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
@@ -157,13 +158,14 @@ class Transducer(nn.Module):
         # logits : [B, T, prune_range, C]
         logits = self.joiner(am_pruned, lm_pruned)
 
-        pruned_loss = k2.rnnt_loss_pruned(
-            logits=logits,
-            symbols=y_padded,
-            ranges=ranges,
-            termination_symbol=blank_id,
-            boundary=boundary,
-            reduction="sum",
-        )
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
 
         return (simple_loss, pruned_loss)
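
Note: the change follows the standard PyTorch recipe for excluding a numerically sensitive computation from automatic mixed precision: open a nested autocast context with enabled=False and cast the half-precision activations back to fp32 before the k2 loss functions consume them. Below is a minimal sketch of the same pattern at the training-loop level; model, loss_fn, optimizer, and scaler are hypothetical placeholders, not names from this repository.

import torch

def train_step(model, loss_fn, x, y, optimizer, scaler):
    # Hypothetical AMP training step illustrating the commit's pattern;
    # this is not icefall's actual training loop.
    device_type = x.device.type
    # The forward pass runs under autocast and may emit fp16 activations.
    with torch.autocast(device_type=device_type, dtype=torch.float16):
        logits = model(x)
    # Exclude the loss from mixed precision: disable autocast locally
    # and cast the activations back to fp32, mirroring the diff above.
    with torch.autocast(device_type=device_type, enabled=False):
        loss = loss_fn(logits.float(), y)
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # standard torch.cuda.amp.GradScaler usage
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()

The explicit .float() casts matter: disabling autocast only stops further ops from being autocast, so tensors produced under the outer fp16 context stay fp16 until they are cast.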