diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index 0fe1c626a..dbcc18b7c 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index 6cc2856c5..36aa913c4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -835,6 +835,19 @@ def compute_loss(
 
     assert loss.requires_grad == is_training
 
+    if decode:
+        model.eval()
+        with torch.no_grad():
+            hypos = model.module.decode(
+                x=feature,
+                x_lens=feature_lens,
+                y=y,
+                sp=sp
+            )
+        logging.info(f'ref: {batch["supervisions"]["text"][0]}')
+        logging.info(f'hyp: {" ".join(hypos[0])}')
+        model.train()
+
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
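For context, the added block temporarily puts the model in eval mode, decodes the current batch under torch.no_grad(), logs the first reference/hypothesis pair, and then switches back to train mode, so a sample decode can be inspected during training without touching gradients. The sketch below illustrates that toggling pattern in isolation; DummyModel, its decode() signature, and log_sample_decode are hypothetical stand-ins, not the actual d2v transducer model, the icefall compute_loss signature, or the sp SentencePiece processor used in the patch.

# Minimal, self-contained sketch of the eval/no_grad/decode/train toggling
# pattern. All names here (DummyModel, log_sample_decode) are assumptions for
# illustration, not part of the recipe.
import logging

import torch
import torch.nn as nn

logging.basicConfig(level=logging.INFO)


class DummyModel(nn.Module):
    """Hypothetical stand-in for the transducer model; only decode() matters here."""

    def decode(self, x: torch.Tensor, x_lens: torch.Tensor) -> list:
        # Placeholder "decoder": emit a fixed word list per utterance.
        return [["HELLO", "WORLD"] for _ in range(x.size(0))]


def log_sample_decode(model, feature, feature_lens, ref_texts):
    # Switch to eval mode so dropout / batch norm behave deterministically while decoding.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        hypos = model.decode(x=feature, x_lens=feature_lens)
    logging.info(f"ref: {ref_texts[0]}")
    logging.info(f"hyp: {' '.join(hypos[0])}")
    # Restore the previous mode so the surrounding training step is unaffected.
    if was_training:
        model.train()


if __name__ == "__main__":
    model = DummyModel()
    feature = torch.randn(2, 100, 80)      # (batch, time, feature_dim)
    feature_lens = torch.tensor([100, 80])
    log_sample_decode(model, feature, feature_lens, ["HELLO WORLD", "FOO BAR"])

Unlike the patch, the sketch remembers whether the model was already in training mode and only then calls model.train(); restoring the previous mode may be worth considering in the patch as well, since compute_loss is also invoked on validation batches where is_training is False.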