From 29d186118f6063ea87d62605abe0a96de354dc5f Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 3 Apr 2022 07:59:17 +0800 Subject: [PATCH] Exclude loss calculation from mix precision training --- .../ASR/pruned_transducer_stateless/model.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 2f019bcdb..810e7ace5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -128,17 +128,18 @@ class Transducer(nn.Module): boundary[:, 2] = y_lens boundary[:, 3] = x_lens - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=decoder_out, - am=encoder_out, - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) + with torch.autocast(device_type=x.device.type, enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=decoder_out.float(), + am=encoder_out.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( @@ -157,13 +158,14 @@ class Transducer(nn.Module): # logits : [B, T, prune_range, C] logits = self.joiner(am_pruned, lm_pruned) - pruned_loss = k2.rnnt_loss_pruned( - logits=logits, - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) + with torch.autocast(device_type=x.device.type, enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) return (simple_loss, pruned_loss)