diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index 38aff8834..e68417f0a 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -63,10 +63,9 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import WenetSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search,
@@ -81,6 +80,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -203,7 +203,7 @@ def get_parser():
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
@@ -222,8 +222,6 @@ def decode_one_batch(
         It's the return value of :func:`get_params`.
       model:
         The neural model.
-      sp:
-        The BPE model.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -260,8 +258,8 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
@@ -270,16 +268,16 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             beam=params.beam_size,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     else:
         batch_size = encoder_out.size(0)
@@ -303,7 +301,7 @@ def decode_one_batch(
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyps.append([lexicon.token_table[idx] for idx in hyp])
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -323,7 +321,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -335,8 +333,6 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      sp:
-        The BPE model.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search.
@@ -366,7 +362,7 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            sp=sp,
+            lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -438,7 +434,7 @@ def save_results(
 
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    WenetSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -473,12 +469,9 @@ def main():
 
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info(params)
@@ -514,26 +507,24 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    wenetspeech = WenetSpeechAsrDataModule(args)
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["TEST_NET", "TEST_MEETING"]
+    test_dl = [test_net_dl, test_meeting_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
-            sp=sp,
+            lexicon=lexicon,
             decoding_graph=decoding_graph,
         )
-
         save_results(
             params=params,
             test_set_name=test_set,
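
Note on the id-to-token mapping introduced above: `greedy_search_batch`, `modified_beam_search`, and `fast_beam_search` return `hyp_tokens` as one list of token ids per utterance, so each hypothesis must be indexed per utterance (`hyp_tokens[i]`) before the ids are mapped through `lexicon.token_table`. Below is a minimal sketch of that mapping, using a plain dict as a stand-in for the `k2.SymbolTable` that `Lexicon(params.lang_dir)` builds from the lang dir's `tokens.txt`; the table contents and hypotheses here are illustrative, not taken from the patch:

```python
# Stand-in for lexicon.token_table. In the recipe it is a k2.SymbolTable
# loaded from tokens.txt; a dict supports the same table[idx] lookup.
token_table = {0: "<blk>", 1: "你", 2: "好", 3: "吗"}  # illustrative ids/tokens

# hyp_tokens as returned by the batched searches: one id list per utterance.
hyp_tokens = [[1, 2, 3], [2, 3]]

hyps = []
for i in range(len(hyp_tokens)):  # encoder_out.size(0) in the real code
    # Map this utterance's token ids to token strings.
    hyps.append([token_table[idx] for idx in hyp_tokens[i]])

print(hyps)  # [['你', '好', '吗'], ['好', '吗']]
```

Because the WenetSpeech recipe scores on characters, there is no `sp.decode(...).split()` step: the per-token strings are used directly. Relatedly, `params.vocab_size = max(lexicon.tokens) + 1` sizes the output layer from the largest token id rather than by counting entries, which stays safe even if the ids in `tokens.txt` are non-contiguous.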