diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py index d56beed9e..71bf8da86 100755 --- a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py +++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py @@ -503,11 +503,12 @@ def compute_validation_loss( batch=batch, is_training=False, ) - assert loss.requires_grad is False + assert transduer_loss.requires_grad is False + assert aux_loss.requires_grad is False tot_loss = tot_loss + loss_info if world_size > 1: - tot_loss.reduce(loss.device) + tot_loss.reduce(transduer_loss.device) loss_value = tot_loss["tot_loss"] / tot_loss["frames"] if loss_value < params.best_valid_loss: