update export.py and pretrained_ctc.py

commit acdc333971
parent 84dfb5765b
@@ -73,6 +73,29 @@ Usage:
     --nbest-scale 1.0 \
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
+
+(6) attention-decoder-rescoring-no-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --decoding-method attention-decoder-rescoring-no-ngram
+
+(7) attention-decoder-rescoring-with-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --hlg-scale 0.6 \
+    --nbest-scale 1.0 \
+    --lm-dir data/lm \
+    --decoding-method attention-decoder-rescoring-with-ngram
 """

@@ -101,10 +124,10 @@ from icefall.decode import (
     nbest_decoding,
     nbest_oracle,
     one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
     rescore_with_attention_decoder_no_ngram,
     rescore_with_attention_decoder_with_ngram,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -214,6 +237,10 @@ def get_parser():
        - (6) nbest-oracle. Its WER is the lower bound of any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring method.
+       - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+         lattice, rescore them with the attention decoder.
+       - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+         rescored lattice, rescore them with the attention decoder.
        """,
    )

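The two new help entries above describe a two-stage procedure: obtain a CTC decoding lattice (for method (8), one that has already been rescored with an n-gram LM), extract n paths from it, and let the attention decoder re-rank those paths. As a rough sketch of the no-ngram variant, reusing the helper imported earlier in this commit and the same call signature the commit adds to pretrained_ctc.py further down; `lattice`, `model`, `encoder_out`, `encoder_out_lens`, and `token_table` are assumed to be prepared as in the surrounding decode code:

    # Sketch only: sample paths from the lattice and let the attention
    # decoder rescore them; the helper returns a dict of best paths.
    best_path_dict = rescore_with_attention_decoder_no_ngram(
        lattice=lattice,
        num_paths=params.num_paths,
        attention_decoder=model.attention_decoder,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        nbest_scale=params.nbest_scale,
    )
    best_path = next(iter(best_path_dict.values()))  # keep one candidate
    token_ids = get_texts(best_path)                 # FSA paths -> token IDs
    hyps = [[token_table[i] for i in ids] for ids in token_ids]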
@@ -404,6 +404,7 @@ def main():

     token_table = k2.SymbolTable.from_file(params.tokens)
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)
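The added line supplies the start/end-of-sentence IDs that the attention decoder presumably needs for rescoring; the BPE recipes define a single combined <sos/eos> piece, so both IDs point at the same table entry. A minimal illustration of the lookup (the tokens file path is taken from the usage examples and may differ per recipe):

    import k2

    token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")
    blank_id = token_table["<blk>"]              # CTC blank, expected to be 0
    sos_id = eos_id = token_table["<sos/eos>"]   # one piece serves as both SOS and EOS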
@@ -466,8 +467,6 @@ def main():
                     device=device,
                 )
             )
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         assert params.avg > 0, params.avg
         start = params.epoch - params.avg
@@ -81,6 +81,15 @@ Usage of this script:
   --sample-rate 16000 \
   /path/to/foo.wav \
   /path/to/bar.wav
+
+(5) attention-decoder-rescoring-no-ngram
+./zipformer/pretrained_ctc.py \
+  --checkpoint ./zipformer/exp/pretrained.pt \
+  --tokens data/lang_bpe_500/tokens.txt \
+  --method attention-decoder-rescoring-no-ngram \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
 """

 import argparse
@@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
+    rescore_with_attention_decoder_no_ngram,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -172,6 +182,8 @@ def get_parser():
          decoding lattice and then use 1best to decode the
          rescored lattice.
          We call it HLG decoding + whole-lattice n-gram LM rescoring.
+    (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+        lattice, rescore them with the attention decoder.
        """,
    )

@@ -276,6 +288,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.vocab_size = num_tokens(token_table) + 1  # +1 for blank
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     assert params.blank_id == 0

     logging.info(f"{params}")
@@ -333,16 +346,13 @@ def main():
         dtype=torch.int32,
     )

-    if params.method == "ctc-decoding":
-        logging.info("Use CTC decoding")
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         max_token_id = params.vocab_size - 1
-
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )
-
         lattice = get_lattice(
             nnet_output=ctc_output,
             decoding_graph=H,
@@ -354,9 +364,23 @@ def main():
             subsampling_factor=params.subsampling_factor,
         )

-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
+        if params.method == "ctc-decoding":
+            logging.info("Use CTC decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            logging.info("Use attention decoder rescoring without ngram")
+            best_path_dict = rescore_with_attention_decoder_no_ngram(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                attention_decoder=model.attention_decoder,
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
         token_ids = get_texts(best_path)
         hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
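rescore_with_attention_decoder_no_ngram returns a dict of best paths rather than a single path, and this demo script simply keeps the first entry. A small sketch of how one could inspect every candidate instead, reusing only names that already appear in this and the next hunk (the dict keys are whatever setting names the helper uses, an assumption here):

    # Sketch: decode every entry of best_path_dict, not just the first.
    for key, path in best_path_dict.items():
        token_ids = get_texts(path)
        hyps = [[token_table[i] for i in ids] for ids in token_ids]
        texts = ["".join(h).replace("▁", " ").strip() for h in hyps]
        logging.info(f"{key}: {texts}")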
@@ -430,7 +454,7 @@ def main():
         raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
-    if params.method == "ctc-decoding":
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         for filename, hyp in zip(params.sound_files, hyps):
             words = "".join(hyp)
             words = words.replace("▁", " ").strip()
@@ -1199,8 +1199,7 @@ def run(rank, world_size, args):

     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.eos_id = sp.piece_to_id("<sos/eos>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     if not params.use_transducer:
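The last hunk (apparently from the training script, given run(rank, world_size, args)) collapses the separate sos_id/eos_id assignments into one, drawing both from the same <sos/eos> BPE piece to match the decode scripts above. For context, a self-contained illustration of these sentencepiece lookups (the model path is hypothetical; the real script receives sp already loaded):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path, for illustration only

    blank_id = sp.piece_to_id("<blk>")             # <blk> is defined in local/train_bpe_model.py
    sos_id = eos_id = sp.piece_to_id("<sos/eos>")  # one combined SOS/EOS piece
    vocab_size = sp.get_piece_size()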