From 7b6a89749d5d1f80a70ea1ccaad2000bad139eaa Mon Sep 17 00:00:00 2001
From: Triplecq
Date: Sun, 14 Jan 2024 17:29:22 -0500
Subject: [PATCH] customize decoding script

---
 egs/reazonspeech/ASR/zipformer/decode.py | 68 ++++++++++++------------
 1 file changed, 34 insertions(+), 34 deletions(-)

diff --git a/egs/reazonspeech/ASR/zipformer/decode.py b/egs/reazonspeech/ASR/zipformer/decode.py
index 339e253e6..e2752c95c 100755
--- a/egs/reazonspeech/ASR/zipformer/decode.py
+++ b/egs/reazonspeech/ASR/zipformer/decode.py
@@ -103,10 +103,10 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-import sentencepiece as spm
+from tokenizer import Tokenizer
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import ReazonSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -204,7 +204,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bpe_500",
+        default="data/lang_char",
         help="The lang dir containing word table and LG graph",
     )
 
@@ -378,7 +378,7 @@
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    sp: Tokenizer,
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
@@ -459,7 +459,7 @@
             max_states=params.max_states,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "fast_beam_search_nbest_LG":
         hyp_tokens = fast_beam_search_nbest_LG(
             model=model,
@@ -487,7 +487,7 @@
             nbest_scale=params.nbest_scale,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
         hyp_tokens = fast_beam_search_nbest_oracle(
             model=model,
@@ -502,7 +502,7 @@
             nbest_scale=params.nbest_scale,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
@@ -510,7 +510,7 @@
             encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -520,7 +520,7 @@
             context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
         hyp_tokens = modified_beam_search_lm_shallow_fusion(
             model=model,
@@ -530,7 +530,7 @@
             LM=LM,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search_LODR":
         hyp_tokens = modified_beam_search_LODR(
             model=model,
@@ -543,7 +543,7 @@
             context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search_lm_rescore":
         lm_scale_list = [0.01 * i for i in range(10, 50)]
         ans_dict = modified_beam_search_lm_rescore(
@@ -589,7 +589,7 @@
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyps.append(sp.text2word(sp.decode(hyp)))
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -628,7 +628,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    sp: Tokenizer,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     context_graph: Optional[ContextGraph] = None,
@@ -694,7 +694,7 @@
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = ref_text.split()
+                ref_words = sp.text2word(ref_text)
                 this_batch.append((cut_id, ref_words, hyp_words))
 
             results[name].extend(this_batch)
@@ -755,8 +755,8 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
-    LmScorer.add_arguments(parser)
+    ReazonSpeechAsrDataModule.add_arguments(parser)
+    Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -839,10 +839,9 @@
 
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    sp = Tokenizer.load(params.lang, params.lang_type)
 
-    # <blk> and <unk> are defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/prepare_lang_char.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -1014,20 +1013,11 @@
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    librispeech = LibriSpeechAsrDataModule(args)
+    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
-
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
-
-    for test_set, test_dl in zip(test_sets, test_dl):
+    for subdir in ["valid"]:
         results_dict = decode_dataset(
-            dl=test_dl,
+            dl=reazonspeech_corpus.test_dataloaders(getattr(reazonspeech_corpus, f"{subdir}_cuts")()),
             params=params,
             model=model,
             sp=sp,
@@ -1034,15 +1024,25 @@
             word_table=word_table,
             decoding_graph=decoding_graph,
             context_graph=context_graph,
             LM=LM,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
         )
-
-        save_results(
+        tot_err = save_results(
             params=params,
-            test_set_name=test_set,
+            test_set_name=subdir,
             results_dict=results_dict,
         )
+        with (
+            params.res_dir
+            / (
+                f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
+                f"_{params.avg}_{params.epoch}.cer"
+            )
+        ).open("w") as fout:
+            if len(tot_err) == 1:
+                fout.write(f"{tot_err[0][1]}")
+            else:
+                fout.write("\n".join(f"{k}\t{v}" for k, v in tot_err))
 
     logging.info("Done!")
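Note on the Tokenizer interface assumed above: the patched decode.py calls
Tokenizer.add_arguments(), Tokenizer.load(), piece_to_id(), get_piece_size(),
decode(), and text2word(). The recipe ships its own tokenizer.py; the sketch
below is only a hypothetical minimal implementation of that surface for a
char-level data/lang_char setup. The --lang/--lang-type options, the
tokens.txt layout, and every helper name here are assumptions made for
illustration, not the recipe's actual code.

    # Hypothetical sketch (not the recipe's tokenizer.py): the minimal
    # surface that the patched decode.py relies on. Assumes a char-level
    # lang dir containing a tokens.txt with one "<symbol> <id>" per line.
    import argparse
    from pathlib import Path
    from typing import List, Union


    class Tokenizer:
        def __init__(self, sym2id: dict) -> None:
            self.sym2id = sym2id
            self.id2sym = {i: s for s, i in sym2id.items()}

        @staticmethod
        def add_arguments(parser: argparse.ArgumentParser) -> None:
            # decode.py reads params.lang and params.lang_type after
            # parsing, so add_arguments presumably registers these two.
            parser.add_argument("--lang", type=Path, default=Path("data/lang_char"))
            parser.add_argument("--lang-type", type=str, default="char")

        @classmethod
        def load(cls, lang: Path, lang_type: str) -> "Tokenizer":
            # lang_type is accepted for parity with the call site; this
            # sketch only implements the char-level case.
            sym2id = {}
            with (lang / "tokens.txt").open() as f:
                for line in f:
                    sym, idx = line.split()
                    sym2id[sym] = int(idx)
            return cls(sym2id)

        def piece_to_id(self, piece: str) -> int:
            return self.sym2id.get(piece, self.sym2id.get("<unk>", 0))

        def get_piece_size(self) -> int:
            return len(self.sym2id)

        def decode(self, tokens: Union[List[int], List[List[int]]]):
            # Mirrors sentencepiece: one id list -> str, a batch -> List[str].
            if tokens and isinstance(tokens[0], list):
                return ["".join(self.id2sym[t] for t in hyp) for hyp in tokens]
            return "".join(self.id2sym[t] for t in tokens)

        def text2word(self, text: str) -> List[str]:
            # Char-level segmentation: every non-space character becomes
            # one "word", so the recipe's WER computation effectively
            # scores CER, the usual metric for Japanese ASR.
            return list(text.replace(" ", ""))

With such a tokenizer, sp.text2word() yields one token per character on both
the reference and the hypothesis side, which is why each former .split() call
in the diff could be swapped one-for-one. A decoding run over the valid set
would then look something like the following (--lang per the assumption
above; --epoch, --avg, --exp-dir, and --decoding-method already exist in
decode.py, and the values shown are placeholders):

    ./zipformer/decode.py \
        --epoch 30 \
        --avg 9 \
        --exp-dir zipformer/exp \
        --lang data/lang_char \
        --decoding-method greedy_search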