From f6f29b932135dd58aa188eca7e67ed4d2186282d Mon Sep 17 00:00:00 2001 From: marcoyang Date: Tue, 14 Feb 2023 15:52:54 +0800 Subject: [PATCH] reformat --- .../ASR/lstm_transducer_stateless3/decode.py | 20 +++++-------------- .../export-for-ncnn.py | 14 ++++--------- .../asr_datamodule.py | 12 +++-------- 3 files changed, 12 insertions(+), 34 deletions(-) diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py index f40d22cd8..8794b49f2 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py @@ -302,9 +302,7 @@ def decode_one_batch( en_hyps.append(en_text) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) for i in range(encoder_out.size(0)): hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -360,9 +358,7 @@ def decode_one_batch( ) elif params.decoding_method == "beam_search": hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, + model=model, encoder_out=encoder_out_i, beam=params.beam_size, ) else: raise ValueError( @@ -726,19 +722,13 @@ def main(): sp=sp, ) save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, + params=params, test_set_name=test_set, results_dict=results_dict, ) save_results( - params=params, - test_set_name=test_set, - results_dict=zh_results_dict, + params=params, test_set_name=test_set, results_dict=zh_results_dict, ) save_results( - params=params, - test_set_name=test_set, - results_dict=en_results_dict, + params=params, test_set_name=test_set, results_dict=en_results_dict, ) logging.info("Done!") diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py index 83e1b8936..9816b87b1 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py @@ -107,10 +107,7 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="Path to the lang", + "--lang-dir", type=str, default="data/lang_char", help="Path to the lang", ) parser.add_argument( @@ -137,8 +134,7 @@ def get_parser(): def export_encoder_model_jit_trace( - encoder_model: torch.nn.Module, - encoder_filename: str, + encoder_model: torch.nn.Module, encoder_filename: str, ) -> None: """Export the given encoder model with torch.jit.trace() @@ -160,8 +156,7 @@ def export_encoder_model_jit_trace( def export_decoder_model_jit_trace( - decoder_model: torch.nn.Module, - decoder_filename: str, + decoder_model: torch.nn.Module, decoder_filename: str, ) -> None: """Export the given decoder model with torch.jit.trace() @@ -182,8 +177,7 @@ def export_decoder_model_jit_trace( def export_joiner_model_jit_trace( - joiner_model: torch.nn.Module, - joiner_filename: str, + joiner_model: torch.nn.Module, joiner_filename: str, ) -> None: """Export the given joiner model with torch.jit.trace() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 2240c1c1d..2ef4e9860 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -210,9 +210,7 @@ class TAL_CSASRAsrDataModule: ) def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, ) -> DataLoader: """ Args: @@ -357,8 +355,7 @@ class TAL_CSASRAsrDataModule: ) else: validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, + cut_transforms=transforms, return_cuts=self.args.return_cuts, ) valid_sampler = DynamicBucketingSampler( cuts_valid, @@ -395,10 +392,7 @@ class TAL_CSASRAsrDataModule: ) logging.info("About to create test dataloader") test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, + test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers, ) return test_dl