From 6de0a849ce71d9ff876e60034ffb5ac66433d6e0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 12 Mar 2022 00:55:04 +0800 Subject: [PATCH] Support modified transducer. --- .../model.py | 16 ++++++++++++++++ .../train.py | 12 ++++++++++++ 2 files changed, 28 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py index e816d5233..ae6049e70 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py @@ -15,6 +15,7 @@ # limitations under the License. +import random from typing import Optional import k2 @@ -119,6 +120,7 @@ class Transducer(nn.Module): x_lens: torch.Tensor, y: k2.RaggedTensor, libri: bool = True, + modified_transducer_prob: float = 0.0, prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, @@ -136,6 +138,8 @@ class Transducer(nn.Module): libri: True to use the decoder and joiner for the LibriSpeech dataset. False to use the decoder and joiner for the GigaSpeech dataset. + modified_transducer_prob: + The probability to use modified transducer loss. prune_range: The prune range for rnnt loss, it means how many symbols(context) we are considering for each frame to compute the loss. 
@@ -163,6 +167,16 @@ class Transducer(nn.Module): encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) + assert 0 <= modified_transducer_prob <= 1 + + if modified_transducer_prob == 0: + modified = False + elif random.random() < modified_transducer_prob: + # random.random() returns a float in the range [0, 1) + modified = True + else: + modified = False + # Now for the decoder, i.e., the prediction network row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -213,6 +227,7 @@ class Transducer(nn.Module): lm_only_scale=lm_scale, am_only_scale=am_scale, boundary=boundary, + modified=modified, reduction="sum", return_grad=True, ) @@ -243,6 +258,7 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, + modified=modified, reduction="sum", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py index c03291113..2a1c6e475 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py @@ -180,6 +180,17 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--modified-transducer-prob", + type=float, + default=0.25, + help="""The probability to use modified transducer loss. + In modified transducer, it limits the maximum number of symbols + per frame to 1. See also the option --max-sym-per-frame in + pruned_transducer_stateless_multi_datasets/decode.py + """, + ) + parser.add_argument( "--prune-range", type=int, @@ -498,6 +509,7 @@ def compute_loss( x_lens=feature_lens, y=y, libri=libri, + modified_transducer_prob=params.modified_transducer_prob, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale,