mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add transducer-stateless-aishell recipe to readme
This commit is contained in:
parent
1717b26cab
commit
f68dc1893f
11
README.md
11
README.md
@ -107,6 +107,17 @@ The best CER we currently have is:
|
|||||||
|
|
||||||
We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
|
We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
|
||||||
|
|
||||||
|
#### Transducer Stateless Model
|
||||||
|
|
||||||
|
The best CER we currently have is:
|
||||||
|
|
||||||
|
| | test |
|
||||||
|
|-----|------|
|
||||||
|
| CER | 5.7 |
|
||||||
|
|
||||||
|
|
||||||
|
We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC#scrollTo=I30mgIz31SUF)
|
||||||
|
|
||||||
#### TDNN LSTM CTC Model
|
#### TDNN LSTM CTC Model
|
||||||
|
|
||||||
The CER for this model is:
|
The CER for this model is:
|
||||||
|
@ -40,6 +40,7 @@ from icefall.utils import (
|
|||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
|
str2bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -108,6 +109,16 @@ def get_parser():
|
|||||||
default=3,
|
default=3,
|
||||||
help="Maximum number of symbols per frame",
|
help="Maximum number of symbols per frame",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--export",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""When enabled, the averaged model is saved to
|
||||||
|
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
|
||||||
|
pretrained.pt contains a dict {"model": model.state_dict()},
|
||||||
|
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -417,6 +428,13 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
|
if params.export:
|
||||||
|
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||||
|
torch.save(
|
||||||
|
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
@ -45,9 +45,9 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search
|
||||||
@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence
|
|||||||
|
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict
|
from icefall.utils import AttributeDict
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -76,9 +78,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--lang-dir",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="""Path to lang.
|
||||||
Used only when method is ctc-decoding.
|
Used only when method is ctc-decoding.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -220,18 +222,10 @@ def read_sound_files(
|
|||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
|
||||||
sp.load(params.bpe_model)
|
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
|
||||||
params.vocab_size = sp.get_piece_size()
|
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -240,6 +234,15 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
logging.info("Creating model")
|
logging.info("Creating model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
@ -303,7 +306,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append([lexicon.token_table[i] for i in hyp])
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user