From d20e927e6ad332d9629adb108f74f97c093ea530 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Apr 2022 11:41:51 +0800 Subject: [PATCH] Update model.py to use torchaudio's RNN-T loss. --- .../ASR/transducer_stateless2/model.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless2/model.py b/egs/librispeech/ASR/transducer_stateless2/model.py index 8281e1fb5..78a96b983 100644 --- a/egs/librispeech/ASR/transducer_stateless2/model.py +++ b/egs/librispeech/ASR/transducer_stateless2/model.py @@ -13,12 +13,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Note we use `rnnt_loss` from torchaudio, which exists only in +torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 +""" import random import k2 import torch import torch.nn as nn +import torchaudio +import torchaudio.functional from encoder_interface import EncoderInterface from icefall.utils import add_sos @@ -102,42 +108,27 @@ class Transducer(nn.Module): decoder_out = self.decoder(sos_y_padded) - # +1 here since a blank is prepended to each utterance. logits = self.joiner( encoder_out=encoder_out, decoder_out=decoder_out, - encoder_out_len=x_lens, - decoder_out_len=y_lens + 1, ) # rnnt_loss requires 0 padded targets # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) - # We don't put this `import` at the beginning of the file - # as it is required only in the training, not during the - # reference stage - import optimized_transducer + assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" + ) - assert 0 <= modified_transducer_prob <= 1 - - if modified_transducer_prob == 0: - one_sym_per_frame = False - elif random.random() < modified_transducer_prob: - # random.random() returns a float in the range [0, 1) - one_sym_per_frame = True - else: - one_sym_per_frame = False - - loss = optimized_transducer.transducer_loss( + loss = torchaudio.functional.rnnt_loss( logits=logits, targets=y_padded, logit_lengths=x_lens, target_lengths=y_lens, blank=blank_id, reduction="sum", - one_sym_per_frame=one_sym_per_frame, - from_log_softmax=False, ) return loss