add transducer-stateless-aishell recipe to readme
commit f68dc1893f
parent 1717b26cab

README.md: 11 additions
@@ -107,6 +107,17 @@ The best CER we currently have is:
 
 We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
 
+#### Transducer Stateless Model
+
+The best CER we currently have is:
+
+|     | test |
+|-----|------|
+| CER | 5.7  |
+
+We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC#scrollTo=I30mgIz31SUF)
+
 #### TDNN LSTM CTC Model
 
 The CER for this model is:
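Note: the 5.7 figure above is a character error rate on the test set, i.e. the total edit distance between hypothesis and reference character sequences divided by the number of reference characters, times 100. The snippet below is only a toy illustration of that definition under made-up example strings; it is not the recipe's scoring code, which relies on icefall's own utilities such as the write_error_stats import visible in the next hunk.

```python
def edit_distance(ref: str, hyp: str) -> int:
    """Levenshtein distance between two character sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(
                min(
                    prev[j] + 1,             # deletion
                    cur[j - 1] + 1,          # insertion
                    prev[j - 1] + (r != h),  # substitution
                )
            )
        prev = cur
    return prev[-1]


def cer(refs, hyps) -> float:
    """100 * total edit distance / total number of reference characters."""
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    total = sum(len(r) for r in refs)
    return 100.0 * errors / total


# One substituted character out of 13 reference characters -> CER of about 7.69
print(cer(["甚至出现交易几乎停滞的情况"], ["甚至出现交易几乎停止的情况"]))
```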
@@ -40,6 +40,7 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     write_error_stats,
+    str2bool,
 )
 
 
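The new str2bool import exists so that the --export option added in the next hunk can be given on the command line as a string like "true" or "false" and still arrive as a Python bool. The helper below only sketches the usual shape of such a function for argparse; it is an illustration of the pattern, not a copy of icefall.utils.str2bool.

```python
import argparse


def str2bool(v: str) -> bool:
    """Illustrative helper: map common true/false strings to bool for argparse."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


parser = argparse.ArgumentParser()
parser.add_argument("--export", type=str2bool, default=False)
print(parser.parse_args(["--export", "true"]).export)  # True
```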
@@ -108,6 +109,16 @@ def get_parser():
         default=3,
         help="Maximum number of symbols per frame",
     )
+    parser.add_argument(
+        "--export",
+        type=str2bool,
+        default=False,
+        help="""When enabled, the averaged model is saved to
+        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        pretrained.pt contains a dict {"model": model.state_dict()},
+        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+        """,
+    )
 
     return parser
 
@@ -417,6 +428,13 @@ def main():
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
 
+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+        return
+
     model.to(device)
     model.eval()
     model.device = device
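Because pretrained.pt holds nothing but {"model": model.state_dict()}, restoring it amounts to rebuilding the model and loading that dict, either through icefall.checkpoint.load_checkpoint() as the help text says or with plain PyTorch. The sketch below uses a toy nn.Linear as a stand-in for the recipe's model to show the save/load round trip; it is not the recipe's own code.

```python
import torch

# Toy stand-in for the recipe's model; the real model object comes from the
# recipe's own model-building code, which is not reproduced here.
model = torch.nn.Linear(80, 4)

# Save side: the same structure the diff writes, only the state_dict under "model".
torch.save({"model": model.state_dict()}, "pretrained.pt")

# Load side: rebuild the model, then restore the saved parameters.
restored = torch.nn.Linear(80, 4)
checkpoint = torch.load("pretrained.pt", map_location="cpu")
restored.load_state_dict(checkpoint["model"])
restored.eval()
```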
@@ -45,9 +45,9 @@ import argparse
 import logging
 import math
 from typing import List
+from pathlib import Path
 
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import beam_search, greedy_search
@@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence
 
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict
+from icefall.lexicon import Lexicon
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 
 
 def get_parser():
@@ -76,9 +78,9 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        help="""Path to bpe.model.
+        help="""Path to lang.
         Used only when method is ctc-decoding.
         """,
     )
@@ -220,18 +222,10 @@ def read_sound_files(
 def main():
     parser = get_parser()
     args = parser.parse_args()
+    args.lang_dir = Path(args.lang_dir)
 
     params = get_params()
     params.update(vars(args))
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
-
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -240,6 +234,15 @@ def main():
 
     logging.info(f"device: {device}")
 
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
+    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+    params.vocab_size = max(lexicon.tokens) + 1
+
     logging.info("Creating model")
     model = get_transducer_model(params)
 
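In this character-based setup there is no BPE model, so the blank ID is found by passing the literal "<blk>" token through the graph compiler, and vocab_size is one more than the largest token ID known to the lexicon. The sketch below mimics that arithmetic with a plain dict standing in for the lexicon's token table; the IDs are made up and this is not the icefall Lexicon API.

```python
# Toy stand-in for the lexicon's token table: token string -> integer ID.
# The real table is built from the recipe's lang directory; these IDs are invented.
token_table = {"<blk>": 0, "<unk>": 1, "你": 2, "好": 3, "爱": 4}

# graph_compiler.texts_to_ids("<blk>")[0][0] in the diff amounts to looking up
# the ID assigned to the blank token:
blank_id = token_table["<blk>"]

# The model's output layer must cover every token ID, hence max ID + 1:
vocab_size = max(token_table.values()) + 1

print(blank_id, vocab_size)  # 0 5
```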
@@ -303,7 +306,7 @@ def main():
         else:
             raise ValueError(f"Unsupported method: {params.method}")
 
-        hyps.append(sp.decode(hyp).split())
+        hyps.append([lexicon.token_table[i] for i in hyp])
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
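With the SentencePiece decoder removed, a hypothesis is just a list of token IDs, and the text is recovered by looking each ID up in lexicon.token_table. The toy example below uses a made-up table to show what the replaced line does; it is not the recipe's code.

```python
# Toy illustration of the new line: map each hypothesis token ID back to its
# character through the token table (IDs and table are invented for the example).
token_table = {2: "你", 3: "好"}

hyp = [2, 3]  # token IDs produced by greedy_search / beam_search
chars = [token_table[i] for i in hyp]  # what the new line in the diff does
print("".join(chars))  # 你好
```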