customize decoding script

Triplecq 2024-01-14 17:29:22 -05:00
parent 04fa9e3e8c
commit 7b6a89749d


@@ -103,10 +103,10 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
-import sentencepiece as spm
+from tokenizer import Tokenizer
import torch
import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import ReazonSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -204,7 +204,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
@@ -378,7 +378,7 @@ def get_parser():
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
-sp: spm.SentencePieceProcessor,
+sp: Tokenizer,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
@@ -459,7 +459,7 @@ def decode_one_batch(
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
-hyps.append(hyp.split())
+hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@@ -487,7 +487,7 @@ def decode_one_batch(
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
-hyps.append(hyp.split())
+hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
@@ -502,7 +502,7 @@ def decode_one_batch(
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
-hyps.append(hyp.split())
+hyps.append(sp.text2word(hyp))
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
@@ -510,7 +510,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
-hyps.append(hyp.split())
+hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@@ -520,7 +520,7 @@ def decode_one_batch(
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
-hyps.append(hyp.split())
+hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
@@ -530,7 +530,7 @@ def decode_one_batch(
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
-hyps.append(hyp.split())
+hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
@@ -543,7 +543,7 @@ def decode_one_batch(
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
-hyps.append(hyp.split())
+hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore(
@@ -589,7 +589,7 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
-hyps.append(sp.decode(hyp).split())
+hyps.append(sp.text2word(sp.decode(hyp)))
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
@@ -628,7 +628,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
-sp: spm.SentencePieceProcessor,
+sp: Tokenizer,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
@@ -694,7 +694,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-ref_words = ref_text.split()
+ref_words = sp.text2word(ref_text)
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@@ -755,8 +755,8 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
-LibriSpeechAsrDataModule.add_arguments(parser)
-LmScorer.add_arguments(parser)
+ReazonSpeechAsrDataModule.add_arguments(parser)
+Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -839,10 +839,9 @@ def main():
logging.info(f"Device: {device}")
-sp = spm.SentencePieceProcessor()
-sp.load(params.bpe_model)
+sp = Tokenizer.load(params.lang, params.lang_type)
-# <blk> and <unk> are defined in local/train_bpe_model.py
+# <blk> and <unk> are defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
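
The surrounding calls (piece_to_id, get_piece_size, and the decode calls above) are kept unchanged, so the Tokenizer that replaces spm.SentencePieceProcessor is expected to mirror that part of the SentencePiece API while adding load(), add_arguments() and text2word(). A rough sketch of the assumed interface, not the recipe's actual tokenizer.py:

import argparse
from pathlib import Path
from typing import Iterable, List

class Tokenizer:
    # Interface sketch inferred from how decode.py uses `sp`; details are assumptions.

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> None:
        # decode.py calls Tokenizer.add_arguments(parser); the flag names are guesses.
        parser.add_argument("--lang", type=Path, default=Path("data/lang_char"))
        parser.add_argument("--lang-type", type=str, default=None)

    @classmethod
    def load(cls, lang: Path, lang_type: str) -> "Tokenizer":
        raise NotImplementedError  # build the symbol table from files under `lang`

    def piece_to_id(self, piece: str) -> int:  # used for params.blank_id / params.unk_id
        raise NotImplementedError

    def get_piece_size(self) -> int:  # used for params.vocab_size
        raise NotImplementedError

    def decode(self, token_ids: Iterable[List[int]]) -> List[str]:  # token ids -> text
        raise NotImplementedError

    def text2word(self, text: str) -> List[str]:  # text -> scoring units (e.g. characters)
        raise NotImplementedError
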
@@ -1014,20 +1013,11 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
-librispeech = LibriSpeechAsrDataModule(args)
+reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-test_clean_cuts = librispeech.test_clean_cuts()
-test_other_cuts = librispeech.test_other_cuts()
-test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-test_sets = ["test-clean", "test-other"]
-test_dl = [test_clean_dl, test_other_dl]
-for test_set, test_dl in zip(test_sets, test_dl):
+for subdir in ["valid"]:
results_dict = decode_dataset(
-dl=test_dl,
+dl=reazonspeech_corpus.test_dataloaders(getattr(reazonspeech_corpus, f"{subdir}_cuts")()),
params=params,
model=model,
sp=sp,
@@ -1038,12 +1028,22 @@ def main():
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
-save_results(
+tot_err = save_results(
params=params,
-test_set_name=test_set,
+test_set_name=subdir,
results_dict=results_dict,
)
+with (
+params.res_dir
+/ (
+f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
+f"_{params.avg}_{params.epoch}.cer"
+)
+).open("w") as fout:
+if len(tot_err) == 1:
+fout.write(f"{tot_err[0][1]}")
+else:
+fout.write("\n".join(f"{k}\t{v}" for k, v in tot_err))
logging.info("Done!")