diff --git a/README.md b/README.md index ff93e8fad..cddc1de12 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,17 @@ The best CER we currently have is: We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing) +#### Transducer Stateless Model + +The best CER we currently have is: + +| | test | +|-----|------| +| CER | 5.7 | + + +We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC#scrollTo=I30mgIz31SUF) + #### TDNN LSTM CTC Model The CER for this model is: diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index 22640131c..a915b971f 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -40,6 +40,7 @@ from icefall.utils import ( setup_logger, store_transcripts, write_error_stats, + str2bool, ) @@ -108,6 +109,16 @@ def get_parser(): default=3, help="Maximum number of symbols per frame", ) + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) return parser @@ -417,6 +428,13 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + model.to(device) model.eval() model.device = device diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index e5dba8f0e..65ac5f3ff 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -45,9 +45,9 @@ import argparse import logging import math from typing import List +from pathlib import Path import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import beam_search, greedy_search @@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence from icefall.env import get_env_info from icefall.utils import AttributeDict +from icefall.lexicon import Lexicon +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler def get_parser(): @@ -76,9 +78,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - help="""Path to bpe.model. + help="""Path to lang. Used only when method is ctc-decoding. """, ) @@ -220,18 +222,10 @@ def read_sound_files( def main(): parser = get_parser() args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) params = get_params() - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - logging.info(f"{params}") device = torch.device("cpu") @@ -240,6 +234,15 @@ def main(): logging.info(f"device: {device}") + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = graph_compiler.texts_to_ids("")[0][0] + params.vocab_size = max(lexicon.tokens) + 1 + logging.info("Creating model") model = get_transducer_model(params) @@ -303,7 +306,7 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append([lexicon.token_table[i] for i in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps):