From c195a12a36ce8e153041d4e2903c9d11c5285ee4 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Tue, 21 Nov 2023 17:51:38 +0800
Subject: [PATCH] minor fixes

---
 egs/libriheavy/ASR/zipformer/decode.py             | 29 +++++--------------
 .../ASR/zipformer/text_normalization.py            |  2 --
 2 files changed, 7 insertions(+), 24 deletions(-)

diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py
index 2cb917920..8227fccdb 100644
--- a/egs/libriheavy/ASR/zipformer/decode.py
+++ b/egs/libriheavy/ASR/zipformer/decode.py
@@ -107,7 +107,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from text_normalization import remove_punc_to_upper,
+from text_normalization import remove_punc_to_upper
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
@@ -174,10 +174,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="zipformer/exp",
-        help="The experiment dir",
+        "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
     )
 
     parser.add_argument(
@@ -285,7 +282,7 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""Set to True, if the model was trained on texts with casing
-        and punctuation."""
+        and punctuation.""",
     )
 
     parser.add_argument(
@@ -352,9 +349,7 @@ def decode_one_batch(
         pad_len = 30
         feature_lens += pad_len
         feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, pad_len),
-            value=LOG_EPS,
+            feature, pad=(0, 0, 0, pad_len), value=LOG_EPS,
         )
 
     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
@@ -404,9 +399,7 @@ def decode_one_batch(
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
+            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -434,9 +427,7 @@ def decode_one_batch(
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
+                    model=model, encoder_out=encoder_out_i, beam=params.beam_size,
                 )
             else:
                 raise ValueError(
@@ -505,9 +496,6 @@ def decode_dataset(
         warnings.simplefilter("ignore")
         for batch_idx, batch in enumerate(dl):
             texts = batch["supervisions"]["text"]
-            texts = [
-                simple_normalization(t) for t in texts
-            ]  # Do a simple normalization, as this is done during training
 
             cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
@@ -537,7 +525,6 @@ def decode_dataset(
 
                 results[f"{name}_norm"].extend(this_batch)
 
-
             num_cuts += len(texts)
 
             if batch_idx % log_interval == 0:
@@ -786,9 +773,7 @@ def main():
         )
 
         save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
+            params=params, test_set_name=test_set, results_dict=results_dict,
         )
 
     logging.info("Done!")

diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py
index 9c9ee6725..92590769c 100644
--- a/egs/libriheavy/ASR/zipformer/text_normalization.py
+++ b/egs/libriheavy/ASR/zipformer/text_normalization.py
@@ -43,8 +43,6 @@ def text_normalization(text: str) -> str:
 
 if __name__ == "__main__":
-
-
     assert simple_cleanup("I like this 《book>") == "I like this <book>"
     assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
     assert (
         text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
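
-- 
Notes:

The import change in decode.py is a correctness fix, not just style: outside
parentheses, a trailing comma in a from-import is a syntax error, so the old
line would have crashed decode.py before any decoding started. A quick
illustration, assuming it is run from egs/libriheavy/ASR/zipformer so that
text_normalization is importable:

    from text_normalization import (remove_punc_to_upper,)  # legal: comma inside parens
    from text_normalization import remove_punc_to_upper     # legal: the patched form
    # from text_normalization import remove_punc_to_upper,  # SyntaxError before this patch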
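
For readers skimming the decode_one_batch hunk: torch.nn.functional.pad reads
its pad argument from the last dimension backwards, so pad=(0, 0, 0, pad_len)
leaves the feature dimension untouched and appends pad_len frames at the end
of the time axis. A minimal sketch of that behavior; the tensor sizes and the
LOG_EPS value here are illustrative assumptions, not taken from the patch:

    import math

    import torch

    LOG_EPS = math.log(1e-10)  # assumed small log-probability fill value
    feature = torch.randn(2, 100, 80)  # (batch, time, feature), made-up sizes
    padded = torch.nn.functional.pad(feature, pad=(0, 0, 0, 30), value=LOG_EPS)
    assert padded.shape == (2, 130, 80)  # 30 extra frames on the time axis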
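
remove_punc_to_upper, now imported cleanly, appears to be the normalization
behind the f"{name}_norm" results extended in decode_dataset. A rough sketch
of such a helper, written only to satisfy the assert shown in the
text_normalization.py hunk above; the real implementation in that file may
differ:

    def remove_punc_to_upper(text: str) -> str:
        # Keep letters and apostrophes, upper-case them, turn every other
        # character into a space, then collapse runs of spaces.
        keep = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'")
        chars = [c.upper() if c in keep else " " for c in text]
        return " ".join("".join(chars).split())

    assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"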