mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Add ctc-decoding to pretrained.py
This commit is contained in:
parent
d54828e73a
commit
6bc027ee8a
@ -528,10 +528,56 @@ displays the help information.
|
||||
|
||||
It supports four decoding methods:
|
||||
|
||||
- CTC decoding
|
||||
- HLG decoding
|
||||
- HLG + n-gram LM rescoring
|
||||
- HLG + n-gram LM rescoring + attention decoder rescoring
|
||||
|
||||
CTC decoding
|
||||
^^^^^^^^^^^^
|
||||
|
||||
CTC decoding uses the best path of the decoding lattice as the decoding result
|
||||
without any LM or lexicon.
|
||||
|
||||
The command to run CTC decoding is:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ cd egs/librispeech/ASR
|
||||
$ ./conformer_ctc/pretrained.py \
|
||||
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
|
||||
--lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
|
||||
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
|
||||
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
|
||||
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2021-10-09 21:06:57,154 INFO [pretrained.py:253] device: cuda:0
|
||||
2021-10-09 21:06:57,154 INFO [pretrained.py:255] Creating model
|
||||
2021-10-09 21:07:04,234 INFO [pretrained.py:272] Constructing Fbank computer
|
||||
2021-10-09 21:07:04,235 INFO [pretrained.py:282] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
|
||||
2021-10-09 21:07:04,248 INFO [pretrained.py:288] Decoding started
|
||||
2021-10-09 21:07:05,041 INFO [pretrained.py:306] Building CTC topology
|
||||
2021-10-09 21:07:05,334 INFO [lexicon.py:113] Loading pre-compiled tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/Linv.pt
|
||||
2021-10-09 21:07:05,380 INFO [pretrained.py:315] Loading BPE model
|
||||
2021-10-09 21:07:07,905 INFO [pretrained.py:330] Use CTC decoding
|
||||
2021-10-09 21:07:07,918 INFO [pretrained.py:407]
|
||||
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
|
||||
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
|
||||
|
||||
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
|
||||
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
|
||||
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
|
||||
|
||||
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
|
||||
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
|
||||
|
||||
|
||||
2021-10-09 21:07:07,918 INFO [pretrained.py:409] Decoding Done
|
||||
|
||||
HLG decoding
|
||||
^^^^^^^^^^^^
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -19,6 +20,7 @@
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import sentencepiece as spm
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
@ -28,6 +30,7 @@ import torchaudio
|
||||
from conformer import Conformer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
@ -54,12 +57,17 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
required=True,
|
||||
default="./tmp/icefall_asr_librispeech_conformer_ctc/ \
|
||||
data/lang_bpe/words.txt",
|
||||
help="Path to words.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||
"--HLG",
|
||||
type=str,
|
||||
default="./tmp/icefall_asr_librispeech_conformer_ctc/ \
|
||||
data/lang_bpe/HLG.pt",
|
||||
help="Path to HLG.pt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -68,6 +76,10 @@ def get_parser():
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||
piece model, i.e., lang_dir/bpe.model, to convert
|
||||
word pieces to words. It needs neither a lexicon
|
||||
nor an n-gram LM.
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the transformer encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
@ -157,6 +169,14 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="./tmp/icefall_asr_librispeech_conformer_ctc/ \
|
||||
data/lang_bpe",
|
||||
help="Path to lang bpe dir.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
@ -249,23 +269,6 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = G.to(device)
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
@ -299,52 +302,103 @@ def main():
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Building CTC topology")
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
logging.info("Loading BPE model")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir + "/bpe.model"))
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=H,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
logging.info("Use CTC decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "attention-decoder":
|
||||
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
|
||||
)
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
nbest_scale=params.nbest_scale,
|
||||
ngram_lm_scale=params.ngram_lm_scale,
|
||||
attention_scale=params.attention_decoder_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
else:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = G.to(device)
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "attention-decoder":
|
||||
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
|
||||
)
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
nbest_scale=params.nbest_scale,
|
||||
ngram_lm_scale=params.ngram_lm_scale,
|
||||
attention_scale=params.attention_decoder_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
|
Loading…
x
Reference in New Issue
Block a user