add transducer-stateless-aishell recipe to readme

This commit is contained in:
PingFeng Luo 2021-12-31 18:24:26 +08:00
parent 1717b26cab
commit f68dc1893f
3 changed files with 45 additions and 13 deletions

View File

@ -107,6 +107,17 @@ The best CER we currently have is:
We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing) We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
#### Transducer Stateless Model
The best CER we currently have is:
| | test |
|-----|------|
| CER | 5.7 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC#scrollTo=I30mgIz31SUF)
#### TDNN LSTM CTC Model #### TDNN LSTM CTC Model
The CER for this model is: The CER for this model is:

View File

@ -40,6 +40,7 @@ from icefall.utils import (
setup_logger, setup_logger,
store_transcripts, store_transcripts,
write_error_stats, write_error_stats,
str2bool,
) )
@ -108,6 +109,16 @@ def get_parser():
default=3, default=3,
help="Maximum number of symbols per frame", help="Maximum number of symbols per frame",
) )
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser return parser
@ -417,6 +428,13 @@ def main():
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device) model.to(device)
model.eval() model.eval()
model.device = device model.device = device

View File

@ -45,9 +45,9 @@ import argparse
import logging import logging
import math import math
from typing import List from typing import List
from pathlib import Path
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import beam_search, greedy_search from beam_search import beam_search, greedy_search
@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import AttributeDict from icefall.utils import AttributeDict
from icefall.lexicon import Lexicon
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
def get_parser(): def get_parser():
@ -76,9 +78,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
help="""Path to bpe.model. help="""Path to lang.
Used only when method is ctc-decoding. Used only when method is ctc-decoding.
""", """,
) )
@ -220,18 +222,10 @@ def read_sound_files(
def main(): def main():
parser = get_parser() parser = get_parser()
args = parser.parse_args() args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -240,6 +234,15 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.vocab_size = max(lexicon.tokens) + 1
logging.info("Creating model") logging.info("Creating model")
model = get_transducer_model(params) model = get_transducer_model(params)
@ -303,7 +306,7 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append([lexicon.token_table[i] for i in hyp])
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):