From 92e524ea7fac0db1af1ad9dcf9f9cdc38e23a5ab Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 20 Jan 2022 11:27:29 +0800 Subject: [PATCH] Use modified transducer loss in training. --- .../ASR/transducer_stateless/model.py | 16 ++++++++++++++++ .../ASR/transducer_stateless/train.py | 18 +++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 98a6f0f37..ca688bd5a 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random + import k2 import torch import torch.nn as nn @@ -62,6 +64,7 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + modified_transducer_prob: float = 0.0, ) -> torch.Tensor: """ Args: @@ -73,6 +76,8 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + modified_transducer_prob: + The probability to use modified transducer loss. Returns: Return the transducer loss. """ @@ -113,6 +118,16 @@ class Transducer(nn.Module): # reference stage import optimized_transducer + assert 0 <= modified_transducer_prob <= 1 + + if modified_transducer_prob == 0: + one_sym_per_frame = False + elif random.random() < modified_transducer_prob: + # random.random() returns a float in the range [0, 1) + one_sym_per_frame = True + else: + one_sym_per_frame = False + loss = optimized_transducer.transducer_loss( logits=logits, targets=y_padded, @@ -120,6 +135,7 @@ class Transducer(nn.Module): target_lengths=y_lens, blank=blank_id, reduction="sum", + one_sym_per_frame=one_sym_per_frame, ) return loss diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 694ebf1d5..1accda09a 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -138,6 +138,17 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--modified-transducer-prob", + type=float, + default=0.25, + help="""The probability to use modified transducer loss. + In modified transduer, it limits the maximum number of symbols + per frame to 1. See also the option --max-sym-per-frame in + transducer_stateless/decode.py + """, + ) + return parser @@ -383,7 +394,12 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + loss = model( + x=feature, + x_lens=feature_lens, + y=y, + modified_transducer_prob=params.modified_transducer_prob, + ) assert loss.requires_grad == is_training