From 277cc3f9bf07d58a3f0b833ed74bca3fde49bf1a Mon Sep 17 00:00:00 2001 From: PF Luo Date: Sat, 19 Feb 2022 15:56:39 +0800 Subject: [PATCH] update aishell-1 recipe with k2.rnnt_loss (#215) * update aishell-1 recipe with k2.rnnt_loss * fix flake8 style * typo * add pretrained model link to RESULTS.md --- README.md | 2 +- egs/aishell/ASR/RESULTS.md | 7 ++++--- egs/aishell/ASR/transducer_stateless/decode.py | 8 ++------ egs/aishell/ASR/transducer_stateless/model.py | 18 +++++++++--------- egs/aishell/ASR/transducer_stateless/train.py | 3 ++- 5 files changed, 18 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 28c9b6ce4..214e85ad0 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ The best CER we currently have is: | | test | |-----|------| -| CER | 5.7 | +| CER | 5.4 | We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 61f7e500e..688e0f60c 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -45,12 +45,13 @@ You can use the following commands to reproduce our results: ### Aishell training results (Transducer-stateless) -#### 2021-12-29 -(Pingfeng Luo) : The tensorboard log for training is available at +#### 2022-02-18 +(Pingfeng Luo) : The tensorboard log for training is available at +And pretrained model is available at ||test| |--|--| -|CER| 5.7% | +|CER| 5.4% | You can use the following commands to reproduce our results: diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index f27e4cdcf..a7b030fa5 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -31,7 +31,6 @@ from decoder import Decoder from joiner import Joiner from model import 
Transducer -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info from icefall.lexicon import Lexicon @@ -403,12 +402,9 @@ def main(): logging.info(f"Device: {device}") lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - params.blank_id = graph_compiler.texts_to_ids("")[0][0] + # params.blank_id = graph_compiler.texts_to_ids("")[0][0] + params.blank_id = 0 params.vocab_size = max(lexicon.tokens) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 2f0f9a183..0322edeed 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -108,18 +108,18 @@ class Transducer(nn.Module): # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + assert hasattr(torchaudio.functional, "rnnt_loss"), ( f"Current torchaudio version: {torchaudio.__version__}\n" "Please install a version >= 0.10.0" ) - loss = torchaudio.functional.rnnt_loss( - logits=logits, - targets=y_padded, - logit_lengths=x_lens, - target_lengths=y_lens, - blank=blank_id, - reduction="sum", - ) + loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary) - return loss + return torch.sum(loss) diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index 0c180b260..b562f9dd4 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -558,7 +558,8 @@ def run(rank, world_size, args): oov="", ) - params.blank_id = graph_compiler.texts_to_ids("")[0][0] + # params.blank_id = 
graph_compiler.texts_to_ids("")[0][0] + params.blank_id = 0 params.vocab_size = max(lexicon.tokens) + 1 logging.info(params)