update export.py and pretrained_ctc.py

yaozengwei 2024-05-26 17:46:20 +08:00
parent 84dfb5765b
commit acdc333971
4 changed files with 63 additions and 14 deletions

zipformer/ctc_decode.py

@@ -73,6 +73,29 @@ Usage:
     --nbest-scale 1.0 \
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
+
+(6) attention-decoder-rescoring-no-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --decoding-method attention-decoder-rescoring-no-ngram
+
+(7) attention-decoder-rescoring-with-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --hlg-scale 0.6 \
+    --nbest-scale 1.0 \
+    --lm-dir data/lm \
+    --decoding-method attention-decoder-rescoring-with-ngram
 """
@ -101,10 +124,10 @@ from icefall.decode import (
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
rescore_with_attention_decoder_no_ngram,
rescore_with_attention_decoder_with_ngram,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -214,6 +237,10 @@ def get_parser():
         - (6) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
+        - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+          lattice, rescore them with the attention decoder.
+        - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+          rescored lattice, rescore them with the attention decoder.
         """,
     )
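Both new methods share one mechanism: draw n paths from a lattice (the plain CTC decoding lattice for (7), the n-gram-rescored lattice for (8)) and let the attention decoder re-rank them. A minimal sketch of the re-ranking step, with made-up paths and a stand-in attention_score callable (neither is icefall's API):

def rerank_nbest(paths, attention_score):
    """Re-rank n-best paths: add an attention-decoder score to each
    path's lattice score and keep the best. paths is a list of
    (token_ids, lattice_log_score) pairs; all values are illustrative."""
    return max(paths, key=lambda p: p[1] + attention_score(p[0]))[0]

# Toy hypotheses: the lattice slightly prefers the first one.
paths = [([23, 7, 91], -11.2), ([23, 7, 52], -11.9)]
print(rerank_nbest(paths, attention_score=lambda ids: -0.1 * len(ids)))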

zipformer/export.py

@@ -404,6 +404,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)
@@ -466,8 +467,6 @@ def main():
                     device=device,
                 )
             )
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
         else:
             assert params.avg > 0, params.avg
             start = params.epoch - params.avg
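With the elif gone, --avg 1 appears to flow through the same else branch, which averages the weights of the last params.avg epoch checkpoints before export. A minimal sketch of the averaging arithmetic, assuming each checkpoint file stores its weights under a "model" key; the filenames are illustrative and the real script relies on icefall's own averaging helpers:

import torch

def average_models(filenames):
    # Element-wise mean of model state dicts (sketch only; integer
    # buffers would need special-casing in a real implementation).
    avg = None
    for name in filenames:
        state = torch.load(name, map_location="cpu")["model"]  # assumed layout
        if avg is None:
            avg = {k: v.to(torch.float64) for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].to(torch.float64)
    return {k: (v / len(filenames)).to(torch.float32) for k, v in avg.items()}

# e.g. --epoch 30 --avg 15 would average epoch-16.pt through epoch-30.pt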

zipformer/pretrained_ctc.py

@@ -81,6 +81,15 @@ Usage of this script:
     --sample-rate 16000 \
     /path/to/foo.wav \
     /path/to/bar.wav
+
+(5) attention-decoder-rescoring-no-ngram
+./zipformer/pretrained_ctc.py \
+    --checkpoint ./zipformer/exp/pretrained.pt \
+    --tokens data/lang_bpe_500/tokens.txt \
+    --method attention-decoder-rescoring-no-ngram \
+    --sample-rate 16000 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
 """

 import argparse
@@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
+    rescore_with_attention_decoder_no_ngram,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -172,6 +182,8 @@ def get_parser():
         decoding lattice and then use 1best to decode the
         rescored lattice.
         We call it HLG decoding + whole-lattice n-gram LM rescoring.
+    (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+        lattice, rescore them with the attention decoder.
         """,
     )
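Method (4) reuses the lattice built for ctc-decoding and only changes the final search, which is why the hunks below fold both methods into one branch. A schematic of that dispatch with stub handlers (the handler bodies are placeholders, not the script's code):

HANDLERS = {
    "ctc-decoding": lambda lattice: "shortest path through the lattice",
    "attention-decoder-rescoring-no-ngram": lambda lattice: "n-best + attention decoder",
}

def decode(method, lattice):
    # Mirrors the script's method check; unknown methods raise, as below.
    if method not in HANDLERS:
        raise ValueError(f"Unsupported decoding method: {method}")
    return HANDLERS[method](lattice)

print(decode("attention-decoder-rescoring-no-ngram", lattice=None))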
@@ -276,6 +288,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.vocab_size = num_tokens(token_table) + 1  # +1 for blank
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     assert params.blank_id == 0

     logging.info(f"{params}")
@@ -333,16 +346,13 @@ def main():
         dtype=torch.int32,
     )

-    if params.method == "ctc-decoding":
-        logging.info("Use CTC decoding")
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         max_token_id = params.vocab_size - 1
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )
         lattice = get_lattice(
             nnet_output=ctc_output,
             decoding_graph=H,
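For readers new to k2: ctc_topo builds the FST encoding of CTC's collapsing rule (merge adjacent repeats, then drop blanks), and get_lattice intersects it with the network output to produce the decoding lattice. The collapsing rule itself in plain Python (frame ids are illustrative; blank is 0, as the script asserts earlier):

from itertools import groupby

def ctc_collapse(frame_ids, blank=0):
    # Merge runs of repeated frame labels, then remove blanks.
    deduped = [tok for tok, _ in groupby(frame_ids)]
    return [tok for tok in deduped if tok != blank]

assert ctc_collapse([0, 5, 5, 0, 0, 9, 9, 0]) == [5, 9]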
@@ -354,9 +364,23 @@ def main():
             subsampling_factor=params.subsampling_factor,
         )

-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
+        if params.method == "ctc-decoding":
+            logging.info("Use CTC decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            logging.info("Use attention decoder rescoring without ngram")
+            best_path_dict = rescore_with_attention_decoder_no_ngram(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                attention_decoder=model.attention_decoder,
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
         token_ids = get_texts(best_path)
         hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
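Two details of the added branch are worth noting. Like icefall's other rescoring helpers, rescore_with_attention_decoder_no_ngram returns a dict of best paths keyed by scoring configuration; with no n-gram scales involved it should hold a single entry, hence next(iter(...)). And nbest_scale scales lattice scores down before paths are sampled, flattening the distribution so the n-best list is more diverse. A toy illustration of that flattening with a softmax over two made-up path scores:

import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    return [e / sum(exps) for e in exps]

scores = [-10.0, -12.0]
print(softmax(scores))                     # sharp: ~[0.88, 0.12]
print(softmax([0.5 * s for s in scores]))  # scaled by 0.5: ~[0.73, 0.27]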
@@ -430,7 +454,7 @@ def main():
         raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
-    if params.method == "ctc-decoding":
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         for filename, hyp in zip(params.sound_files, hyps):
             words = "".join(hyp)
             words = words.replace("▁", " ").strip()
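The final replace undoes SentencePiece BPE segmentation: the pieces are concatenated and the word-boundary marker ▁ (U+2581) becomes a space. The same two lines in isolation, with made-up pieces:

hyp = ["▁HELLO", "▁WOR", "LD"]
words = "".join(hyp).replace("▁", " ").strip()
assert words == "HELLO WORLD"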

zipformer/train.py

@@ -1199,8 +1199,7 @@ def run(rank, world_size, args):
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.eos_id = sp.piece_to_id("<sos/eos>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     if not params.use_transducer:
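The consolidated assignment reflects that these recipes use a single <sos/eos> piece for both the start and end symbol. A minimal sketch of reading the same ids via the sentencepiece API (the model path is illustrative):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # illustrative path

blank_id = sp.piece_to_id("<blk>")             # 0 in these recipes
sos_id = eos_id = sp.piece_to_id("<sos/eos>")
vocab_size = sp.get_piece_size()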