mirror of https://github.com/k2-fsa/icefall.git

commit acdc333971 (parent 84dfb5765b)

update export.py and pretrained_ctc.py
@@ -73,6 +73,29 @@ Usage:
     --nbest-scale 1.0 \
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
+
+(6) attention-decoder-rescoring-no-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --decoding-method attention-decoder-rescoring-no-ngram
+
+(7) attention-decoder-rescoring-with-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --hlg-scale 0.6 \
+    --nbest-scale 1.0 \
+    --lm-dir data/lm \
+    --decoding-method attention-decoder-rescoring-with-ngram
 """

@@ -101,10 +124,10 @@ from icefall.decode import (
     nbest_decoding,
     nbest_oracle,
     one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
+    rescore_with_attention_decoder_no_ngram,
+    rescore_with_attention_decoder_with_ngram,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -214,6 +237,10 @@ def get_parser():
         - (6) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
+        - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+          lattice, rescore them with the attention decoder.
+        - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+          rescored lattice, rescore them with the attention decoder.
         """,
     )
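The two added methods share one idea: sample n paths from a lattice and let the attention decoder re-rank them. A minimal sketch of that re-ranking step, with hypothetical inputs and helper names (not the icefall implementation):

    from typing import List

    def pick_best_path_sketch(
        paths: List[List[int]],       # n token-id sequences sampled from the lattice
        lattice_scores: List[float],  # per-path scores from the (LM-rescored) lattice
        attn_scores: List[float],     # per-path log-probs from the attention decoder
        attn_scale: float = 1.0,
    ) -> List[int]:
        # Combine lattice and attention-decoder scores; keep the best path.
        totals = [s + attn_scale * a for s, a in zip(lattice_scores, attn_scores)]
        return paths[max(range(len(paths)), key=totals.__getitem__)]

    # toy usage
    best = pick_best_path_sketch([[2, 3], [2, 4]], [-1.0, -1.2], [-0.5, -0.1])
    assert best == [2, 4]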
@@ -404,6 +404,7 @@ def main():

     token_table = k2.SymbolTable.from_file(params.tokens)
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)
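For reference, params.tokens points at a tokens.txt of "symbol id" pairs. A small sketch of the lookups used in this hunk (the toy table below is ours, not from the repo):

    import k2

    # write a toy tokens.txt in the "symbol id" format k2.SymbolTable expects
    with open("tokens.txt", "w", encoding="utf-8") as f:
        f.write("<blk> 0\n<sos/eos> 1\n▁HE 2\nLLO 3\n")

    token_table = k2.SymbolTable.from_file("tokens.txt")
    assert token_table["<blk>"] == 0  # str -> id, as in params.blank_id above
    assert token_table[2] == "▁HE"    # int -> str, used when mapping hyps back to text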
@@ -466,8 +467,6 @@
                 device=device,
             )
         )
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         assert params.avg > 0, params.avg
         start = params.epoch - params.avg
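With the elif params.avg == 1 branch gone, the surviving else averages the last params.avg epoch checkpoints, which also covers avg == 1 (averaging a single checkpoint just loads it). A sketch of that pattern, assuming icefall's average_checkpoints helper:

    import torch
    from icefall.checkpoint import average_checkpoints

    def load_averaged_sketch(
        model: torch.nn.Module, exp_dir: str, epoch: int, avg: int, device: torch.device
    ) -> None:
        # Average epoch-{epoch-avg+1}.pt .. epoch-{epoch}.pt; for avg == 1 the
        # list holds one file, so this degenerates to a plain checkpoint load.
        start = epoch - avg
        filenames = [f"{exp_dir}/epoch-{i}.pt" for i in range(start + 1, epoch + 1)]
        model.load_state_dict(average_checkpoints(filenames, device=device))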
@@ -81,6 +81,15 @@ Usage of this script:
     --sample-rate 16000 \
     /path/to/foo.wav \
     /path/to/bar.wav
+
+(5) attention-decoder-rescoring-no-ngram
+./zipformer/pretrained_ctc.py \
+    --checkpoint ./zipformer/exp/pretrained.pt \
+    --tokens data/lang_bpe_500/tokens.txt \
+    --method attention-decoder-rescoring-no-ngram \
+    --sample-rate 16000 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
 """

 import argparse
@@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
+    rescore_with_attention_decoder_no_ngram,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -172,6 +182,8 @@ def get_parser():
           decoding lattice and then use 1best to decode the
           rescored lattice.
           We call it HLG decoding + whole-lattice n-gram LM rescoring.
+        (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+          lattice, rescore them with the attention decoder.
         """,
     )
@@ -276,6 +288,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.vocab_size = num_tokens(token_table) + 1  # +1 for blank
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     assert params.blank_id == 0

     logging.info(f"{params}")
@@ -333,16 +346,13 @@
         dtype=torch.int32,
     )

-    if params.method == "ctc-decoding":
-        logging.info("Use CTC decoding")
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         max_token_id = params.vocab_size - 1

         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )

         lattice = get_lattice(
             nnet_output=ctc_output,
             decoding_graph=H,
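Both methods now share the lattice built here: the CTC posteriors are intersected with the CTC topology H. A toy version of the H construction, using a 3-symbol vocabulary of our own:

    import k2

    # CTC topology over {0: <blk>, 1, 2}; max_token is vocab_size - 1,
    # mirroring max_token_id in the hunk above
    H = k2.ctc_topo(max_token=2, modified=False)
    print(H.num_arcs)  # arcs cover blanks, repeats, and label changes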
@@ -354,9 +364,23 @@
             subsampling_factor=params.subsampling_factor,
         )

-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
+        if params.method == "ctc-decoding":
+            logging.info("Use CTC decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            logging.info("Use attention decoder rescoring without ngram")
+            best_path_dict = rescore_with_attention_decoder_no_ngram(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                attention_decoder=model.attention_decoder,
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))

         token_ids = get_texts(best_path)
         hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
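The rescoring helper returns a dict of 1-best paths keyed by scoring setting; next(iter(...)) takes the first (and here only) entry. A toy illustration with a hypothetical key:

    # hypothetical key name; the real keys encode the rescoring scales
    best_path_dict = {"attention_scale_1.0": "best_path_at_scale_1.0"}
    best_path = next(iter(best_path_dict.values()))  # dicts preserve insertion order
    assert best_path == "best_path_at_scale_1.0"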
@@ -430,7 +454,7 @@ def main():
         raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
-    if params.method == "ctc-decoding":
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         for filename, hyp in zip(params.sound_files, hyps):
             words = "".join(hyp)
             words = words.replace("▁", " ").strip()
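The hunk above reuses the BPE-to-words step for the new method: SentencePiece marks word starts with "▁" (U+2581), so joining the pieces and swapping the marker for a space recovers the text. A quick check:

    hyp = ["▁HE", "LLO", "▁WORLD"]
    words = "".join(hyp).replace("▁", " ").strip()
    assert words == "HELLO WORLD"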
@@ -1199,8 +1199,7 @@ def run(rank, world_size, args):

     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.eos_id = sp.piece_to_id("<sos/eos>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     if not params.use_transducer:
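The train.py change folds two identical piece_to_id lookups into one chained assignment. A stub-based equivalence check (the stub stands in for a sentencepiece.SentencePieceProcessor):

    class SpStub:
        # minimal stand-in for sentencepiece.SentencePieceProcessor
        def piece_to_id(self, piece: str) -> int:
            return {"<blk>": 0, "<sos/eos>": 1}[piece]

    sp = SpStub()
    sos_id = eos_id = sp.piece_to_id("<sos/eos>")  # one lookup, two names
    assert sos_id == eos_id == 1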