Exclude loss calculation from autocast during mixed-precision training

This commit is contained in:
pkufool 2022-04-03 07:59:17 +08:00
parent 2ff81b2838
commit 29d186118f

View File

@ -128,9 +128,10 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
with torch.autocast(device_type=x.device.type, enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
am=encoder_out,
lm=decoder_out.float(),
am=encoder_out.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
@ -157,8 +158,9 @@ class Transducer(nn.Module):
# logits : [B, T, prune_range, C]
logits = self.joiner(am_pruned, lm_pruned)
with torch.autocast(device_type=x.device.type, enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,