update export.py and pretrained_ctc.py

yaozengwei 2024-05-26 17:46:20 +08:00
parent 84dfb5765b
commit acdc333971
4 changed files with 63 additions and 14 deletions

zipformer/ctc_decode.py (View File)

@@ -73,6 +73,29 @@ Usage:
     --nbest-scale 1.0 \
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
+
+(6) attention-decoder-rescoring-no-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --decoding-method attention-decoder-rescoring-no-ngram
+
+(7) attention-decoder-rescoring-with-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --hlg-scale 0.6 \
+    --nbest-scale 1.0 \
+    --lm-dir data/lm \
+    --decoding-method attention-decoder-rescoring-with-ngram
 """
@@ -101,10 +124,10 @@ from icefall.decode import (
     nbest_decoding,
     nbest_oracle,
     one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
     rescore_with_attention_decoder_no_ngram,
     rescore_with_attention_decoder_with_ngram,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -214,6 +237,10 @@ def get_parser():
         - (6) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
+        - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+          lattice, rescore them with the attention decoder.
+        - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+          rescored lattice, rescore them with the attention decoder.
         """,
     )
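As a companion to the help text added above, here is a sketch of what attention-decoder-rescoring-no-ngram does at the call site. The `rescore_with_attention_decoder_no_ngram` signature is the one used later in this commit; the wrapper function, its name, and its defaults are illustrative only.

# Sketch only: the icefall call matches this commit; the wrapper is assumed.
import torch
from icefall.decode import rescore_with_attention_decoder_no_ngram


def attention_rescore(
    lattice,                    # k2.Fsa lattice built from the CTC output
    model,                      # trained model exposing .attention_decoder
    encoder_out: torch.Tensor,  # (N, T, C) encoder activations
    encoder_out_lens: torch.Tensor,
    num_paths: int = 100,
    nbest_scale: float = 1.0,
):
    # n paths are drawn from the lattice and rescored by the attention
    # decoder; the result maps a scale label to a best-path FsaVec.
    best_path_dict = rescore_with_attention_decoder_no_ngram(
        lattice=lattice,
        num_paths=num_paths,
        attention_decoder=model.attention_decoder,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        nbest_scale=nbest_scale,
    )
    # Take one entry, exactly as the updated pretrained_ctc.py does below.
    return next(iter(best_path_dict.values()))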

zipformer/export.py (View File)

@@ -404,6 +404,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)
@@ -466,8 +467,6 @@ def main():
                 device=device,
             )
         )
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         assert params.avg > 0, params.avg
         start = params.epoch - params.avg
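Note on the removed `elif params.avg == 1` branch: a single-epoch export now flows through the same averaging path as `--avg > 1`. For intuition, a simplified sketch of averaging the last `avg` epoch checkpoints (icefall's real logic lives in its `average_checkpoints`-style helpers and does more bookkeeping; the function below and its epoch range are illustrative).

# Simplified sketch; paths and range are illustrative, not icefall's exact code.
import torch


def average_last_epochs(exp_dir: str, epoch: int, avg: int) -> dict:
    """Average the "model" state_dicts of epochs (epoch - avg + 1) .. epoch."""
    paths = [f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
    acc = None
    for p in paths:
        state = torch.load(p, map_location="cpu")["model"]
        if acc is None:
            acc = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in acc:
                acc[k] += state[k].float()
    return {k: v / len(paths) for k, v in acc.items()}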

zipformer/pretrained_ctc.py (View File)

@@ -81,6 +81,15 @@ Usage of this script:
     --sample-rate 16000 \
     /path/to/foo.wav \
     /path/to/bar.wav
+
+(5) attention-decoder-rescoring-no-ngram
+./zipformer/pretrained_ctc.py \
+    --checkpoint ./zipformer/exp/pretrained.pt \
+    --tokens data/lang_bpe_500/tokens.txt \
+    --method attention-decoder-rescoring-no-ngram \
+    --sample-rate 16000 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
 """
import argparse import argparse
@@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
+    rescore_with_attention_decoder_no_ngram,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -172,6 +182,8 @@ def get_parser():
           decoding lattice and then use 1best to decode the
           rescored lattice.
           We call it HLG decoding + whole-lattice n-gram LM rescoring.
+        (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+          lattice, rescore them with the attention decoder.
         """,
     )
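The "n paths" in the new help text are sampled from the lattice before rescoring. A minimal sketch of the idea, assuming k2's `random_paths` API (icefall wraps this in its `Nbest` helper; the function below is illustrative):

# Core idea only; icefall's Nbest.from_lattice does this plus deduplication.
import k2


def sample_nbest_paths(lattice: k2.Fsa, num_paths: int, nbest_scale: float):
    # Scale scores before sampling: a smaller nbest_scale flattens the
    # distribution and yields more diverse candidate paths.
    saved = lattice.scores.clone()
    lattice.scores = lattice.scores * nbest_scale
    paths = k2.random_paths(lattice, use_double_scores=True, num_paths=num_paths)
    lattice.scores = saved  # restore the original scores
    return paths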
@@ -276,6 +288,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.vocab_size = num_tokens(token_table) + 1  # +1 for blank
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     assert params.blank_id == 0

     logging.info(f"{params}")
@@ -333,16 +346,13 @@ def main():
         dtype=torch.int32,
     )

-    if params.method == "ctc-decoding":
-        logging.info("Use CTC decoding")
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         max_token_id = params.vocab_size - 1
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )
-
         lattice = get_lattice(
             nnet_output=ctc_output,
             decoding_graph=H,
@@ -354,9 +364,23 @@ def main():
             subsampling_factor=params.subsampling_factor,
         )

-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
+        if params.method == "ctc-decoding":
+            logging.info("Use CTC decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            logging.info("Use attention decoder rescoring without ngram")
+            best_path_dict = rescore_with_attention_decoder_no_ngram(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                attention_decoder=model.attention_decoder,
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))

         token_ids = get_texts(best_path)
         hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
@@ -430,7 +454,7 @@ def main():
         raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
-    if params.method == "ctc-decoding":
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         for filename, hyp in zip(params.sound_files, hyps):
             words = "".join(hyp)
             words = words.replace("▁", " ").strip()
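For reference, the character being replaced above is SentencePiece's word-boundary marker "▁" (U+2581), not an ASCII underscore. A tiny self-contained example of the piece-to-text conversion:

# BPE pieces -> text: "▁" marks the start of a word.
pieces = ["▁HELLO", "▁WOR", "LD"]
text = "".join(pieces).replace("▁", " ").strip()
assert text == "HELLO WORLD"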

zipformer/train.py (View File)

@@ -1199,8 +1199,7 @@ def run(rank, world_size, args):
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.eos_id = sp.piece_to_id("<sos/eos>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     if not params.use_transducer:
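Since `<sos/eos>` is a single shared symbol, `sos_id` and `eos_id` end up equal; the attention decoder uses them to frame its target sequences. A hedged sketch of the typical usage (names are illustrative, not the exact icefall API):

# Illustrative only: how sos_id/eos_id typically frame decoder targets.
from typing import List


def add_sos_eos(token_ids: List[List[int]], sos_id: int, eos_id: int) -> List[List[int]]:
    """Prepend <sos> and append <eos> to every target sequence."""
    return [[sos_id] + seq + [eos_id] for seq in token_ids]


# With a shared <sos/eos> symbol, both ends use the same ID.
assert add_sos_eos([[25, 7, 103]], sos_id=1, eos_id=1) == [[1, 25, 7, 103, 1]]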