diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py
index 3f0ee3103..df2d555a0 100755
--- a/egs/librispeech/ASR/zipformer/decode.py
+++ b/egs/librispeech/ASR/zipformer/decode.py
@@ -103,9 +103,10 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
+import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -120,7 +121,6 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
-from tokenizer import Tokenizer
 from train import add_model_arguments, get_model, get_params
 
 from icefall import ContextGraph, LmScorer, NgramLm
@@ -133,7 +133,6 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    make_pad_mask,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -204,7 +203,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_char",
+        default="data/lang_bpe_500",
         help="The lang dir containing word table and LG graph",
     )
 
@@ -370,19 +369,6 @@ def get_parser():
         modified_beam_search_LODR.
         """,
     )
-
-    parser.add_argument(
-        "--blank-penalty",
-        type=float,
-        default=0.0,
-        help="""
-        The penalty applied on blank symbol during decoding.
-        Note: It is a positive value that would be applied to logits like
-        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
-        [batch_size, vocab] and blank id is 0).
-        """,
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -391,7 +377,7 @@ def get_parser():
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    sp: Tokenizer,
+    sp: spm.SentencePieceProcessor,
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
@@ -470,10 +456,9 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
         hyp_tokens = fast_beam_search_nbest_LG(
             model=model,
@@ -485,7 +470,6 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -500,10 +484,9 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
         hyp_tokens = fast_beam_search_nbest_oracle(
             model=model,
@@ -516,19 +499,17 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
"modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -536,10 +517,9 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, context_graph=context_graph, - blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, @@ -549,7 +529,7 @@ def decode_one_batch( LM=LM, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -562,7 +542,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -608,9 +588,8 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.text2word(sp.decode(hyp))) + hyps.append(sp.decode(hyp).split()) - key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: @@ -648,7 +627,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - sp: Tokenizer, + sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, context_graph: Optional[ContextGraph] = None, @@ -714,7 +693,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = sp.text2word(ref_text) + ref_words = ref_text.split() this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -775,8 +754,8 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -847,8 +826,6 @@ def main(): f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" ) - params.suffix += f"-blank-penalty-{params.blank_penalty}" - if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -861,9 +838,10 @@ def main(): logging.info(f"Device: {device}") - sp = Tokenizer.load(params.lang, params.lang_type) + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) - # and are defined in local/prepare_lang_char.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -1035,13 +1013,20 @@ def main(): # we need cut ids to display recognition results. 
     args.return_cuts = True
-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+    librispeech = LibriSpeechAsrDataModule(args)
 
-    for subdir in ["valid"]:
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
-            dl=reazonspeech_corpus.test_dataloaders(
-                getattr(reazonspeech_corpus, f"{subdir}_cuts")()
-            ),
+            dl=test_dl,
             params=params,
             model=model,
             sp=sp,
@@ -1052,25 +1037,15 @@ def main():
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
         )
-        tot_err = save_results(
+
+        save_results(
             params=params,
-            test_set_name=subdir,
+            test_set_name=test_set,
             results_dict=results_dict,
         )
-        # with (
-        #     params.res_dir
-        #     / (
-        #         f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
-        #         f"_{params.avg}_{params.epoch}.cer"
-        #     )
-        # ).open("w") as fout:
-        #     if len(tot_err) == 1:
-        #         fout.write(f"{tot_err[0][1]}")
-        #     else:
-        #         fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
 
     logging.info("Done!")
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11
deleted file mode 100644
index e20ff996d..000000000
--- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11
+++ /dev/null
@@ -1,2 +0,0 @@
-2024-07-29 17:52:11,668 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:52:11,669 INFO [streaming_decode.py:742] Device: cuda:0
diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22
deleted file mode 100644
index 6ad2cb344..000000000
--- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22
+++ /dev/null
@@ -1,2 +0,0 @@
-2024-07-29 17:54:22,556 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:54:22,556 INFO [streaming_decode.py:742] Device: cuda:0
diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15
deleted file mode 100644
index 40a190b75..000000000
--- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15
+++ /dev/null
@@ -1,2 +0,0 @@
-2024-07-29 17:55:15,276 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:55:15,277 INFO [streaming_decode.py:742] Device: cuda:0
diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02
deleted file mode 100644
index 4c721e7d2..000000000
--- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02
+++ /dev/null
@@ -1,2 +0,0 @@
-2024-07-29 17:59:02,028 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:59:02,029 INFO [streaming_decode.py:742] Device: cuda:0
diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06
deleted file mode 100644
index 6fe40297f..000000000
--- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06
+++ /dev/null
@@ -1,5 +0,0 @@
-2024-07-29 18:01:06,736 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 18:01:06,736 INFO [streaming_decode.py:742] Device: cuda:0
-2024-07-29 18:01:06,740 INFO [streaming_decode.py:753] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.24.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '8f976a1e1407e330e2a233d68f81b1eb5269fdaa', 'k2-git-date': 'Thu Jun 6 02:13:08 2024', 'lhotse-version': '1.26.0.dev+git.bd12d5d.clean', 'torch-version': '2.3.1+cu121', 'torch-cuda-available': True, 'torch-cuda-version': '12.1', 'python-version': '3.10', 'icefall-git-branch': 'jp-streaming', 'icefall-git-sha1': '4af81af-dirty', 'icefall-git-date': 'Thu Jul 18 22:05:59 2024', 'icefall-path': '/root/tmp/icefall', 'k2-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/k2/__init__.py', 'lhotse-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'KDA00', 'IP address': '192.168.0.1'}, 'epoch': 28, 'iter': 0, 'avg': 15, 'use_averaged_model': True, 'exp_dir': PosixPath('zipformer'), 'bpe_model': 'data/lang_bpe_500/bpe.model', 'lang_dir': PosixPath('data/lang_char'), 'decoding_method': 'greedy_search', 'num_active_paths': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 32, 'context_size': 2, 'num_decode_streams': 2000, 'num_encoder_layers': '2,2,3,4,3,2', 'downsampling_factor': '1,2,4,8,4,2', 'feedforward_dim': '512,768,1024,1536,1024,768', 'num_heads': '4,4,4,8,4,4', 'encoder_dim': '192,256,384,512,384,256', 'query_head_dim': '32', 'value_head_dim': '12', 'pos_head_dim': '4', 'pos_dim': 48, 'encoder_unmasked_dim': '192,192,256,256,256,192', 'cnn_module_kernel': '31,31,15,15,15,31', 'decoder_dim': 512, 'joiner_dim': 512, 'causal': True, 'chunk_size': '32', 'left_context_frames': '256', 'use_transducer': True, 'use_ctc': False, 'manifest_dir': PosixPath('data/manifests'), 'max_duration': 200.0, 'bucketing_sampler': True, 'num_buckets': 30, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'drop_last': True, 'return_cuts': False, 'num_workers': 2, 'enable_spec_aug': True, 'spec_aug_time_warp_factor': 80, 'enable_musan': False, 'lang': PosixPath('data/lang_char'), 'lang_type': None, 'res_dir': PosixPath('zipformer/streaming/greedy_search'), 'suffix': 'epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model', 'blank_id': 0, 'unk_id': 2990, 'vocab_size': 2992}
-2024-07-29 18:01:06,740 INFO [streaming_decode.py:755] About to create model
-2024-07-29 18:01:07,118 INFO [streaming_decode.py:822] Calculating the averaged model over epoch range from 13 (excluded) to 28