diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_adapter.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_adapter.py.swp index 40c595ef9..695b64d9f 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_adapter.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_adapter.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_adapter.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_adapter.py index 44bc45337..ef48c720e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_adapter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_adapter.py @@ -913,12 +913,21 @@ def compute_loss( if decode: model.eval() with torch.no_grad(): - hypos = model.module.decode( - x=feature, - x_lens=feature_lens, - y=y, - sp=sp - ) + try: + hypos = model.module.decode( + x=feature, + x_lens=feature_lens, + y=y, + sp=sp + ) + except: + hypos = model.decode( + x=feature, + x_lens=feature_lens, + y=y, + sp=sp + ) + logging.info(f'ref: {batch["supervisions"]["text"][0]}') logging.info(f'hyp: {" ".join(hypos[0])}') model.train()