Merge branch 'master' into wenetspeech

Commit: 503275e649
Author: PingFeng Luo
Date: 2021-12-31 18:28:22 +08:00
3 changed files with 45 additions and 13 deletions

View File

@@ -107,6 +107,17 @@ The best CER we currently have is:
We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
#### Transducer Stateless Model
The best CER we currently have is:
| | test |
|-----|------|
| CER | 5.7 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC#scrollTo=I30mgIz31SUF)
#### TDNN LSTM CTC Model
The CER for this model is:

View File

@@ -40,6 +40,7 @@ from icefall.utils import (
    setup_logger,
    store_transcripts,
    write_error_stats,
    str2bool,
)
@@ -108,6 +109,16 @@ def get_parser():
        default=3,
        help="Maximum number of symbols per frame",
    )
    parser.add_argument(
        "--export",
        type=str2bool,
        default=False,
        help="""When enabled, the averaged model is saved to
        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
        pretrained.pt contains a dict {"model": model.state_dict()},
        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
        """,
    )
    return parser
@@ -417,6 +428,13 @@ def main():
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    if params.export:
        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
        torch.save(
            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
        )
        return
    model.to(device)
    model.eval()
    model.device = device
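
The `--export` option added above only persists the network weights: `pretrained.pt` is a plain dict `{"model": model.state_dict()}` with no optimizer or scheduler state. A minimal sketch of that save/load round trip, with `torch.nn.Linear` standing in for the recipe's real model:

```python
# Sketch only: nn.Linear is a stand-in for the actual transducer/CTC model.
import torch
import torch.nn as nn

model = nn.Linear(80, 10)

# What `decode.py --export true` writes: only the state_dict, wrapped in a dict.
torch.save({"model": model.state_dict()}, "pretrained.pt")

# Loading it back; pretrained.py (and icefall.checkpoint.load_checkpoint)
# expect the same {"model": ...} layout.
checkpoint = torch.load("pretrained.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
```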

View File

@@ -45,9 +45,9 @@ import argparse
import logging
import math
from typing import List
from pathlib import Path
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
@@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
from icefall.lexicon import Lexicon
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
def get_parser():
@@ -76,9 +78,9 @@ def get_parser():
    )
    parser.add_argument(
        "--bpe-model",
        "--lang-dir",
        type=str,
        help="""Path to bpe.model.
        help="""Path to lang.
        Used only when method is ctc-decoding.
        """,
    )
@@ -220,18 +222,10 @@ def read_sound_files(
def main():
    parser = get_parser()
    args = parser.parse_args()
    args.lang_dir = Path(args.lang_dir)
    params = get_params()
    params.update(vars(args))
    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)
    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()
    logging.info(f"{params}")
    device = torch.device("cpu")
@@ -240,6 +234,15 @@ def main():
    logging.info(f"device: {device}")
    lexicon = Lexicon(params.lang_dir)
    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
    params.vocab_size = max(lexicon.tokens) + 1
    logging.info("Creating model")
    model = get_transducer_model(params)
@@ -303,7 +306,7 @@ def main():
        else:
            raise ValueError(f"Unsupported method: {params.method}")
        hyps.append(sp.decode(hyp).split())
        hyps.append([lexicon.token_table[i] for i in hyp])
    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
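
With the BPE model removed, each hypothesis is now a list of integer token IDs that `lexicon.token_table` maps back to characters, as in the new `hyps.append(...)` line above; the same table also supplies `params.blank_id` and `params.vocab_size` earlier in `main()`. A minimal sketch of the mapping, with a hand-made dict standing in for the table the Lexicon reads from the lang dir:

```python
# Sketch only: a made-up token table replaces lexicon.token_table, which is
# normally loaded from <lang-dir>/tokens.txt by icefall.lexicon.Lexicon.
token_table = {0: "<blk>", 1: "你", 2: "好", 3: "世", 4: "界"}

hyp = [1, 2, 3, 4]  # token IDs produced by greedy_search / beam_search
chars = [token_table[i] for i in hyp]  # same lookup as the new hyps.append line
print("".join(chars))  # -> 你好世界
```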