Add ctc-decoding to pretrained.py

Mingshuang Luo 2021-10-12 11:00:08 +08:00
parent d54828e73a
commit 6bc027ee8a
2 changed files with 162 additions and 62 deletions


@@ -528,10 +528,56 @@ displays the help information.
It supports four decoding methods:

- CTC decoding
- HLG decoding
- HLG + n-gram LM rescoring
- HLG + n-gram LM rescoring + attention decoder rescoring

CTC decoding
^^^^^^^^^^^^

CTC decoding uses the best path of the decoding lattice as the decoding result
without any LM or lexicon.

The command to run CTC decoding is:

.. code-block:: bash

  $ cd egs/librispeech/ASR
  $ ./conformer_ctc/pretrained.py \
    --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
    ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
    ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
    ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac

The output is given below:

.. code-block::

  2021-10-09 21:06:57,154 INFO [pretrained.py:253] device: cuda:0
  2021-10-09 21:06:57,154 INFO [pretrained.py:255] Creating model
  2021-10-09 21:07:04,234 INFO [pretrained.py:272] Constructing Fbank computer
  2021-10-09 21:07:04,235 INFO [pretrained.py:282] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
  2021-10-09 21:07:04,248 INFO [pretrained.py:288] Decoding started
  2021-10-09 21:07:05,041 INFO [pretrained.py:306] Building CTC topology
  2021-10-09 21:07:05,334 INFO [lexicon.py:113] Loading pre-compiled tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/Linv.pt
  2021-10-09 21:07:05,380 INFO [pretrained.py:315] Loading BPE model
  2021-10-09 21:07:07,905 INFO [pretrained.py:330] Use CTC decoding
  2021-10-09 21:07:07,918 INFO [pretrained.py:407]
  ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
  AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
  ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
  GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
  BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
  ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
  YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
  2021-10-09 21:07:07,918 INFO [pretrained.py:409] Decoding Done
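Under the hood, CTC decoding builds a CTC topology ``H`` over the BPE tokens, intersects it with the
network output to get a lattice, keeps the single best path, and lets the BPE model join the resulting
word pieces into words. Below is a minimal sketch of that pipeline, not the literal ``pretrained.py``
code: it assumes ``nnet_output``, ``supervision_segments``, ``params`` and ``device`` are prepared as
in that script, and the ``get_texts`` import location is an assumption.

.. code-block:: python

  # A compact sketch of the ctc-decoding branch added by this commit.
  import k2
  import sentencepiece as spm

  from icefall.decode import get_lattice, one_best_decoding
  from icefall.lexicon import Lexicon
  from icefall.utils import get_texts  # assumed import location


  def ctc_decode(nnet_output, supervision_segments, params, device):
      # CTC topology over the BPE token inventory; no lexicon, no n-gram LM.
      lexicon = Lexicon(params.lang_dir)
      H = k2.ctc_topo(max_token=max(lexicon.tokens), modified=False, device=device)

      # Intersect the topology with the network output to get a lattice,
      # then keep only its best path.
      lattice = get_lattice(
          nnet_output=nnet_output,
          decoding_graph=H,
          supervision_segments=supervision_segments,
          search_beam=params.search_beam,
          output_beam=params.output_beam,
          min_active_states=params.min_active_states,
          max_active_states=params.max_active_states,
          subsampling_factor=params.subsampling_factor,
      )
      best_path = one_best_decoding(
          lattice=lattice, use_double_scores=params.use_double_scores
      )

      # Labels on the best path are BPE token IDs; the SentencePiece model
      # joins them back into whole words.
      bpe_model = spm.SentencePieceProcessor()
      bpe_model.load(params.lang_dir + "/bpe.model")
      return bpe_model.decode(get_texts(best_path))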

HLG decoding
^^^^^^^^^^^^


@@ -1,5 +1,6 @@

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -19,6 +20,7 @@

import argparse
import logging
import math
import sentencepiece as spm
from typing import List

import k2
@@ -28,6 +30,7 @@ import torchaudio

from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence

from icefall.lexicon import Lexicon
from icefall.decode import (
    get_lattice,
    one_best_decoding,
@@ -54,12 +57,17 @@ def get_parser():

    parser.add_argument(
        "--words-file",
        type=str,
        default="./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt",
        help="Path to words.txt",
    )

    parser.add_argument(
        "--HLG",
        type=str,
        default="./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt",
        help="Path to HLG.pt.",
    )

    parser.add_argument(
@@ -68,6 +76,10 @@ def get_parser():

        default="1best",
        help="""Decoding method.
        Possible values are:
        (0) ctc-decoding - Use CTC decoding. It uses a sentence
            piece model, i.e., lang_dir/bpe.model, to convert
            word pieces to words. It needs neither a lexicon
            nor an n-gram LM.
        (1) 1best - Use the best path as decoding output. Only
            the transformer encoder output is used for decoding.
            We call it HLG decoding.
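The practical difference the ctc-decoding option makes is in how the best path is turned into text:
its lattice labels are BPE token IDs that the SentencePiece model joins into words, whereas the
HLG-based methods produce word IDs that are looked up in words.txt. A small illustration of the two
mappings follows; the IDs and file paths below are made up for the example, and in the script the
real IDs come from get_texts(best_path) and the paths from --lang-dir and --words-file.

# Illustration only (not part of the commit): the two ways IDs become words.
import k2
import sentencepiece as spm

# ctc-decoding: BPE token IDs -> words via the SentencePiece model.
bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe/bpe.model")          # path assumed
token_ids = [[52, 8, 371]]                         # fake IDs, one utterance
hyps = [s.split() for s in bpe_model.decode(token_ids)]

# 1best / rescoring methods: word IDs -> words via words.txt.
word_sym_table = k2.SymbolTable.from_file("data/lang_bpe/words.txt")  # path assumed
word_ids = [[1523, 7, 964]]                        # fake IDs, one utterance
hyps = [[word_sym_table[i] for i in ids] for ids in word_ids]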
@@ -157,6 +169,14 @@ def get_parser():

        """,
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe",
        help="Path to lang bpe dir.",
    )

    parser.add_argument(
        "sound_files",
        type=str,
@@ -249,23 +269,6 @@ def main():

    model.to(device)
    model.eval()

    logging.info(f"Loading HLG from {params.HLG}")
    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
    HLG = HLG.to(device)
    if not hasattr(HLG, "lm_scores"):
        # For whole-lattice-rescoring and attention-decoder
        HLG.lm_scores = HLG.scores.clone()

    if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
        logging.info(f"Loading G from {params.G}")
        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
        # Add epsilon self-loops to G as we will compose
        # it with the whole lattice later
        G = G.to(device)
        G = k2.add_epsilon_self_loops(G)
        G = k2.arc_sort(G)
        G.lm_scores = G.scores.clone()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
@@ -299,52 +302,103 @@ def main():

        dtype=torch.int32,
    )

    if params.method == "ctc-decoding":
        logging.info("Building CTC topology")
        lexicon = Lexicon(params.lang_dir)
        max_token_id = max(lexicon.tokens)
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )

        logging.info("Loading BPE model")
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir + "/bpe.model"))

        lattice = get_lattice(
            nnet_output=nnet_output,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        logging.info("Use CTC decoding")
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        token_ids = get_texts(best_path)
        hyps = bpe_model.decode(token_ids)
        hyps = [s.split() for s in hyps]
    else:
        logging.info(f"Loading HLG from {params.HLG}")
        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
        HLG = HLG.to(device)
        if not hasattr(HLG, "lm_scores"):
            # For whole-lattice-rescoring and attention-decoder
            HLG.lm_scores = HLG.scores.clone()

        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
            logging.info(f"Loading G from {params.G}")
            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = G.to(device)
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G.lm_scores = G.scores.clone()

        lattice = get_lattice(
            nnet_output=nnet_output,
            decoding_graph=HLG,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        if params.method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
        elif params.method == "whole-lattice-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=[params.ngram_lm_scale],
            )
            best_path = next(iter(best_path_dict.values()))
        elif params.method == "attention-decoder":
            logging.info("Use HLG + LM rescoring + attention decoder rescoring")
            rescored_lattice = rescore_with_whole_lattice(
                lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
            )
            best_path_dict = rescore_with_attention_decoder(
                lattice=rescored_lattice,
                num_paths=params.num_paths,
                model=model,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                sos_id=params.sos_id,
                eos_id=params.eos_id,
                nbest_scale=params.nbest_scale,
                ngram_lm_scale=params.ngram_lm_scale,
                attention_scale=params.attention_decoder_scale,
            )
            best_path = next(iter(best_path_dict.values()))

        hyps = get_texts(best_path)
        word_sym_table = k2.SymbolTable.from_file(params.words_file)
        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):