diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py index 8794b49f2..f40d22cd8 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py @@ -302,7 +302,9 @@ def decode_one_batch( en_hyps.append(en_text) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, ) for i in range(encoder_out.size(0)): hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -358,7 +360,9 @@ def decode_one_batch( ) elif params.decoding_method == "beam_search": hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size, + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, ) else: raise ValueError( @@ -722,13 +726,19 @@ def main(): sp=sp, ) save_results( - params=params, test_set_name=test_set, results_dict=results_dict, + params=params, + test_set_name=test_set, + results_dict=results_dict, ) save_results( - params=params, test_set_name=test_set, results_dict=zh_results_dict, + params=params, + test_set_name=test_set, + results_dict=zh_results_dict, ) save_results( - params=params, test_set_name=test_set, results_dict=en_results_dict, + params=params, + test_set_name=test_set, + results_dict=en_results_dict, ) logging.info("Done!") diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py index 9816b87b1..83e1b8936 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/export-for-ncnn.py @@ -107,7 +107,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", type=str, default="data/lang_char", help="Path to the lang", + "--lang-dir", + type=str, + default="data/lang_char", + help="Path to the lang", ) parser.add_argument( @@ -134,7 +137,8 @@ def get_parser(): def export_encoder_model_jit_trace( - encoder_model: torch.nn.Module, encoder_filename: str, + encoder_model: torch.nn.Module, + encoder_filename: str, ) -> None: """Export the given encoder model with torch.jit.trace() @@ -156,7 +160,8 @@ def export_encoder_model_jit_trace( def export_decoder_model_jit_trace( - decoder_model: torch.nn.Module, decoder_filename: str, + decoder_model: torch.nn.Module, + decoder_filename: str, ) -> None: """Export the given decoder model with torch.jit.trace() @@ -177,7 +182,8 @@ def export_decoder_model_jit_trace( def export_joiner_model_jit_trace( - joiner_model: torch.nn.Module, joiner_filename: str, + joiner_model: torch.nn.Module, + joiner_filename: str, ) -> None: """Export the given joiner model with torch.jit.trace() diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py index 05e6f2b0e..910a799bd 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py @@ -37,31 +37,45 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--tokens", type=str, help="Path to tokens.txt", + "--tokens", + type=str, + help="Path to tokens.txt", ) parser.add_argument( - "--encoder-param-filename", type=str, help="Path to encoder.ncnn.param", + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", ) parser.add_argument( - "--encoder-bin-filename", type=str, help="Path to encoder.ncnn.bin", + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", ) parser.add_argument( - "--decoder-param-filename", type=str, help="Path to decoder.ncnn.param", + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", ) parser.add_argument( - "--decoder-bin-filename", type=str, help="Path to decoder.ncnn.bin", + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", ) parser.add_argument( - "--joiner-param-filename", type=str, help="Path to joiner.ncnn.param", + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", ) parser.add_argument( - "--joiner-bin-filename", type=str, help="Path to joiner.ncnn.bin", + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", ) parser.add_argument( @@ -72,15 +86,23 @@ def get_args(): ) parser.add_argument( - "--encoder-dim", type=int, default=512, help="Encoder output dimesion.", + "--encoder-dim", + type=int, + default=512, + help="Encoder output dimesion.", ) parser.add_argument( - "--rnn-hidden-size", type=int, default=2048, help="Dimension of feed forward.", + "--rnn-hidden-size", + type=int, + default=2048, + help="Dimension of feed forward.", ) parser.add_argument( - "sound_filename", type=str, help="Path to foo.wav", + "sound_filename", + type=str, + help="Path to foo.wav", ) return parser.parse_args() @@ -264,7 +286,8 @@ def main(): logging.info(f"Reading sound files: {sound_file}") wave_samples = read_sound_files( - filenames=[sound_file], expected_sample_rate=sample_rate, + filenames=[sound_file], + expected_sample_rate=sample_rate, )[0] logging.info(wave_samples.shape) @@ -275,7 +298,11 @@ def main(): states = ( torch.zeros(num_encoder_layers, batch_size, d_model), - torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size,), + torch.zeros( + num_encoder_layers, + batch_size, + rnn_hidden_size, + ), ) hyp = None @@ -294,7 +321,8 @@ def main(): start += chunk online_fbank.accept_waveform( - sampling_rate=sample_rate, waveform=samples, + sampling_rate=sample_rate, + waveform=samples, ) while online_fbank.num_frames_ready - num_processed_frames >= segment: frames = [] diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py index 51c5617d2..f49de2983 100644 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -215,7 +215,10 @@ def get_parser(): ) parser.add_argument( - "--sampling-rate", type=float, default=16000, help="Sample rate of the audio", + "--sampling-rate", + type=float, + default=16000, + help="Sample rate of the audio", ) parser.add_argument( @@ -231,7 +234,9 @@ def get_parser(): def greedy_search( - model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream], + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], ) -> None: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. @@ -288,12 +293,18 @@ def greedy_search( device=device, dtype=torch.int64, ) - decoder_out = model.decoder(decoder_input, need_pad=False,) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) decoder_out = model.joiner.decoder_proj(decoder_out) def modified_beam_search( - model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream], beam: int = 4, + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], + beam: int = 4, ): """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -347,7 +358,9 @@ def modified_beam_search( # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor # as index, so we use `to(torch.int64)` below. current_encoder_out = torch.index_select( - current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) logits = model.joiner(current_encoder_out, decoder_out, project_input=False) @@ -534,19 +547,26 @@ def decode_one_chunk( pad_length = tail_length - features.size(1) feature_lens += pad_length features = torch.nn.functional.pad( - features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPSILON, + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPSILON, ) # Stack states of all streams states = stack_states(state_list) encoder_out, encoder_out_lens, states = model.encoder( - x=features, x_lens=feature_lens, states=states, + x=features, + x_lens=feature_lens, + states=states, ) if params.decoding_method == "greedy_search": greedy_search( - model=model, streams=streams, encoder_out=encoder_out, + model=model, + streams=streams, + encoder_out=encoder_out, ) elif params.decoding_method == "modified_beam_search": modified_beam_search( @@ -705,7 +725,10 @@ def decode_dataset( while len(streams) > 0: finished_streams = decode_one_chunk( - model=model, streams=streams, params=params, decoding_graph=decoding_graph, + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, ) for i in sorted(finished_streams, reverse=True): @@ -825,7 +848,10 @@ def main(): sp.load(bpe_model) lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 @@ -953,7 +979,9 @@ def main(): ) save_results( - params=params, test_set_name=test_set, results_dict=results_dict, + params=params, + test_set_name=test_set, + results_dict=results_dict, ) logging.info("Done!") diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py index 7eec51c24..bc1b9290e 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py @@ -103,23 +103,38 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( - "--encoder-dim", type=int, default=512, help="Encoder output dimesion.", + "--encoder-dim", + type=int, + default=512, + help="Encoder output dimesion.", ) parser.add_argument( - "--decoder-dim", type=int, default=512, help="Decoder output dimension.", + "--decoder-dim", + type=int, + default=512, + help="Decoder output dimension.", ) parser.add_argument( - "--joiner-dim", type=int, default=512, help="Joiner output dimension.", + "--joiner-dim", + type=int, + default=512, + help="Joiner output dimension.", ) parser.add_argument( - "--dim-feedforward", type=int, default=2048, help="Dimension of feed forward.", + "--dim-feedforward", + type=int, + default=2048, + help="Dimension of feed forward.", ) parser.add_argument( - "--rnn-hidden-size", type=int, default=1024, help="Hidden dim for LSTM layers.", + "--rnn-hidden-size", + type=int, + default=1024, + help="Hidden dim for LSTM layers.", ) parser.add_argument( @@ -156,7 +171,10 @@ def get_parser(): ) parser.add_argument( - "--world-size", type=int, default=1, help="Number of GPUs for DDP training.", + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", ) parser.add_argument( @@ -174,7 +192,10 @@ def get_parser(): ) parser.add_argument( - "--num-epochs", type=int, default=40, help="Number of epochs to train.", + "--num-epochs", + type=int, + default=40, + help="Number of epochs to train.", ) parser.add_argument( @@ -825,7 +846,9 @@ def train_one_epoch( and params.batch_idx_train % params.average_period == 0 ): update_averaged_model( - params=params, model_cur=model, model_avg=model_avg, + params=params, + model_cur=model, + model_avg=model_avg, ) if ( @@ -847,7 +870,9 @@ def train_one_epoch( ) del params.cur_batch_idx remove_checkpoints( - out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank, + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, ) if batch_idx % params.log_interval == 0 and not params.print_diagnostics: @@ -935,7 +960,10 @@ def run(rank, world_size, args): sp.load(bpe_model) lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 @@ -986,7 +1014,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 2ef4e9860..2240c1c1d 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -210,7 +210,9 @@ class TAL_CSASRAsrDataModule: ) def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, ) -> DataLoader: """ Args: @@ -355,7 +357,8 @@ class TAL_CSASRAsrDataModule: ) else: validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, return_cuts=self.args.return_cuts, + cut_transforms=transforms, + return_cuts=self.args.return_cuts, ) valid_sampler = DynamicBucketingSampler( cuts_valid, @@ -392,7 +395,10 @@ class TAL_CSASRAsrDataModule: ) logging.info("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers, + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) return test_dl