From fd6416e6c1ed3ae08f8ca2a41bf78fa6baa04e1e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Apr 2022 11:43:35 +0800 Subject: [PATCH] Update train.py to use torchaudio's RNN-T loss. --- egs/librispeech/ASR/transducer_stateless2/model.py | 3 --- egs/librispeech/ASR/transducer_stateless2/train.py | 12 ------------ 2 files changed, 15 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless2/model.py b/egs/librispeech/ASR/transducer_stateless2/model.py index 78a96b983..9208d0654 100644 --- a/egs/librispeech/ASR/transducer_stateless2/model.py +++ b/egs/librispeech/ASR/transducer_stateless2/model.py @@ -70,7 +70,6 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, - modified_transducer_prob: float = 0.0, ) -> torch.Tensor: """ Args: @@ -82,8 +81,6 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. - modified_transducer_prob: - The probability to use modified transducer loss. Returns: Return the transducer loss. """ diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py index d6827c17c..81acf9706 100755 --- a/egs/librispeech/ASR/transducer_stateless2/train.py +++ b/egs/librispeech/ASR/transducer_stateless2/train.py @@ -140,17 +140,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--modified-transducer-prob", - type=float, - default=0.25, - help="""The probability to use modified transducer loss. - In modified transduer, it limits the maximum number of symbols - per frame to 1. See also the option --max-sym-per-frame in - transducer_stateless/decode.py - """, - ) - parser.add_argument( "--seed", type=int, @@ -414,7 +403,6 @@ def compute_loss( x=feature, x_lens=feature_lens, y=y, - modified_transducer_prob=params.modified_transducer_prob, ) assert loss.requires_grad == is_training