customize decoding script

This commit is contained in:
Triplecq 2024-01-14 17:29:22 -05:00
parent 04fa9e3e8c
commit 7b6a89749d

View File

@ -103,10 +103,10 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
import sentencepiece as spm from tokenizer import Tokenizer
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import ReazonSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest, fast_beam_search_nbest,
@ -204,7 +204,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--lang-dir", "--lang-dir",
type=Path, type=Path,
default="data/lang_bpe_500", default="data/lang_char",
help="The lang dir containing word table and LG graph", help="The lang dir containing word table and LG graph",
) )
@ -378,7 +378,7 @@ def get_parser():
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: Tokenizer,
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
@ -459,7 +459,7 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_LG": elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG( hyp_tokens = fast_beam_search_nbest_LG(
model=model, model=model,
@ -487,7 +487,7 @@ def decode_one_batch(
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_oracle": elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle( hyp_tokens = fast_beam_search_nbest_oracle(
model=model, model=model,
@ -502,7 +502,7 @@ def decode_one_batch(
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(sp.text2word(hyp))
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
@ -510,7 +510,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -520,7 +520,7 @@ def decode_one_batch(
context_graph=context_graph, context_graph=context_graph,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion( hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model, model=model,
@ -530,7 +530,7 @@ def decode_one_batch(
LM=LM, LM=LM,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_LODR": elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR( hyp_tokens = modified_beam_search_LODR(
model=model, model=model,
@ -543,7 +543,7 @@ def decode_one_batch(
context_graph=context_graph, context_graph=context_graph,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_lm_rescore": elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)] lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore( ans_dict = modified_beam_search_lm_rescore(
@ -589,7 +589,7 @@ def decode_one_batch(
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyps.append(sp.decode(hyp).split()) hyps.append(sp.text2word(sp.decode(hyp)))
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
@ -628,7 +628,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: Tokenizer,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None, context_graph: Optional[ContextGraph] = None,
@ -694,7 +694,7 @@ def decode_dataset(
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split() ref_words = sp.text2word(ref_text)
this_batch.append((cut_id, ref_words, hyp_words)) this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch) results[name].extend(this_batch)
@ -755,8 +755,8 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) ReazonSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser) Tokenizer.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -839,10 +839,9 @@ def main():
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor() sp = Tokenizer.load(params.lang, params.lang_type)
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py # <blk> and <unk> are defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
@ -1014,20 +1013,11 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args) reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts() for subdir in ["valid"]:
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=reazonspeech_corpus.test_dataloaders(getattr(reazonspeech_corpus, f"{subdir}_cuts")()),
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
@ -1038,12 +1028,22 @@ def main():
ngram_lm=ngram_lm, ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale, ngram_lm_scale=ngram_lm_scale,
) )
tot_err = save_results(
save_results(
params=params, params=params,
test_set_name=test_set, test_set_name=subdir,
results_dict=results_dict, results_dict=results_dict,
) )
with (
params.res_dir
/ (
f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
f"_{params.avg}_{params.epoch}.cer"
)
).open("w") as fout:
if len(tot_err) == 1:
fout.write(f"{tot_err[0][1]}")
else:
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
logging.info("Done!") logging.info("Done!")