From a57c54124ad252a9c77528cdd814e27afdc285e8 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Mon, 13 Feb 2023 17:55:34 +0800
Subject: [PATCH] fix code style

---
 .../streaming-ncnn-decode.py                  | 54 +++++--------------
 .../streaming_decode.py                       | 50 ++++-------------
 .../ASR/lstm_transducer_stateless3/train.py   | 54 +++++--------------
 3 files changed, 37 insertions(+), 121 deletions(-)

diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py
index 910a799bd..05e6f2b0e 100755
--- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py
+++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming-ncnn-decode.py
@@ -37,45 +37,31 @@ def get_args():
     parser = argparse.ArgumentParser()

     parser.add_argument(
-        "--tokens",
-        type=str,
-        help="Path to tokens.txt",
+        "--tokens", type=str, help="Path to tokens.txt",
     )

     parser.add_argument(
-        "--encoder-param-filename",
-        type=str,
-        help="Path to encoder.ncnn.param",
+        "--encoder-param-filename", type=str, help="Path to encoder.ncnn.param",
     )

     parser.add_argument(
-        "--encoder-bin-filename",
-        type=str,
-        help="Path to encoder.ncnn.bin",
+        "--encoder-bin-filename", type=str, help="Path to encoder.ncnn.bin",
     )

     parser.add_argument(
-        "--decoder-param-filename",
-        type=str,
-        help="Path to decoder.ncnn.param",
+        "--decoder-param-filename", type=str, help="Path to decoder.ncnn.param",
     )

     parser.add_argument(
-        "--decoder-bin-filename",
-        type=str,
-        help="Path to decoder.ncnn.bin",
+        "--decoder-bin-filename", type=str, help="Path to decoder.ncnn.bin",
     )

     parser.add_argument(
-        "--joiner-param-filename",
-        type=str,
-        help="Path to joiner.ncnn.param",
+        "--joiner-param-filename", type=str, help="Path to joiner.ncnn.param",
     )

     parser.add_argument(
-        "--joiner-bin-filename",
-        type=str,
-        help="Path to joiner.ncnn.bin",
+        "--joiner-bin-filename", type=str, help="Path to joiner.ncnn.bin",
     )

     parser.add_argument(
@@ -86,23 +72,15 @@ def get_args():
     )

     parser.add_argument(
-        "--encoder-dim",
-        type=int,
-        default=512,
-        help="Encoder output dimesion.",
+        "--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
     )

     parser.add_argument(
-        "--rnn-hidden-size",
-        type=int,
-        default=2048,
-        help="Dimension of feed forward.",
+        "--rnn-hidden-size", type=int, default=2048, help="Dimension of feed forward.",
     )

     parser.add_argument(
-        "sound_filename",
-        type=str,
-        help="Path to foo.wav",
+        "sound_filename", type=str, help="Path to foo.wav",
     )

     return parser.parse_args()
@@ -286,8 +264,7 @@ def main():

     logging.info(f"Reading sound files: {sound_file}")
     wave_samples = read_sound_files(
-        filenames=[sound_file],
-        expected_sample_rate=sample_rate,
+        filenames=[sound_file], expected_sample_rate=sample_rate,
     )[0]
     logging.info(wave_samples.shape)

@@ -298,11 +275,7 @@ def main():

     states = (
         torch.zeros(num_encoder_layers, batch_size, d_model),
-        torch.zeros(
-            num_encoder_layers,
-            batch_size,
-            rnn_hidden_size,
-        ),
+        torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size,),
     )

     hyp = None
@@ -321,8 +294,7 @@ def main():
         start += chunk

         online_fbank.accept_waveform(
-            sampling_rate=sample_rate,
-            waveform=samples,
+            sampling_rate=sample_rate, waveform=samples,
         )
         while online_fbank.num_frames_ready - num_processed_frames >= segment:
             frames = []
diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py
index f49de2983..51c5617d2 100644
--- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py
+++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py
@@ -215,10 +215,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--sampling-rate",
-        type=float,
-        default=16000,
-        help="Sample rate of the audio",
+        "--sampling-rate", type=float, default=16000, help="Sample rate of the audio",
     )

     parser.add_argument(
@@ -234,9 +231,7 @@ def get_parser():


 def greedy_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    streams: List[Stream],
+    model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream],
 ) -> None:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

@@ -293,18 +288,12 @@ def greedy_search(
                 device=device,
                 dtype=torch.int64,
             )
-            decoder_out = model.decoder(
-                decoder_input,
-                need_pad=False,
-            )
+            decoder_out = model.decoder(decoder_input, need_pad=False,)
             decoder_out = model.joiner.decoder_proj(decoder_out)


 def modified_beam_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    streams: List[Stream],
-    beam: int = 4,
+    model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream], beam: int = 4,
 ):
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

@@ -358,9 +347,7 @@ def modified_beam_search(
         # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
         # as index, so we use `to(torch.int64)` below.
         current_encoder_out = torch.index_select(
-            current_encoder_out,
-            dim=0,
-            index=hyps_shape.row_ids(1).to(torch.int64),
+            current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64),
         )  # (num_hyps, encoder_out_dim)

         logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
@@ -547,26 +534,19 @@ def decode_one_chunk(
         pad_length = tail_length - features.size(1)
         feature_lens += pad_length
         features = torch.nn.functional.pad(
-            features,
-            (0, 0, 0, pad_length),
-            mode="constant",
-            value=LOG_EPSILON,
+            features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPSILON,
         )

     # Stack states of all streams
     states = stack_states(state_list)

     encoder_out, encoder_out_lens, states = model.encoder(
-        x=features,
-        x_lens=feature_lens,
-        states=states,
+        x=features, x_lens=feature_lens, states=states,
     )

     if params.decoding_method == "greedy_search":
         greedy_search(
-            model=model,
-            streams=streams,
-            encoder_out=encoder_out,
+            model=model, streams=streams, encoder_out=encoder_out,
         )
     elif params.decoding_method == "modified_beam_search":
         modified_beam_search(
@@ -725,10 +705,7 @@ def decode_dataset(

     while len(streams) > 0:
         finished_streams = decode_one_chunk(
-            model=model,
-            streams=streams,
-            params=params,
-            decoding_graph=decoding_graph,
+            model=model, streams=streams, params=params, decoding_graph=decoding_graph,
         )

         for i in sorted(finished_streams, reverse=True):
@@ -848,10 +825,7 @@ def main():
     sp.load(bpe_model)

     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
+    graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)

     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
@@ -979,9 +953,7 @@ def main():
         )

         save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
+            params=params, test_set_name=test_set, results_dict=results_dict,
         )

     logging.info("Done!")
diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py
index c67aad202..7eec51c24 100755
--- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py
@@ -103,38 +103,23 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--encoder-dim",
-        type=int,
-        default=512,
-        help="Encoder output dimesion.",
+        "--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
     )

     parser.add_argument(
-        "--decoder-dim",
-        type=int,
-        default=512,
-        help="Decoder output dimension.",
+        "--decoder-dim", type=int, default=512, help="Decoder output dimension.",
     )

     parser.add_argument(
-        "--joiner-dim",
-        type=int,
-        default=512,
-        help="Joiner output dimension.",
+        "--joiner-dim", type=int, default=512, help="Joiner output dimension.",
     )

     parser.add_argument(
-        "--dim-feedforward",
-        type=int,
-        default=2048,
-        help="Dimension of feed forward.",
+        "--dim-feedforward", type=int, default=2048, help="Dimension of feed forward.",
     )

     parser.add_argument(
-        "--rnn-hidden-size",
-        type=int,
-        default=1024,
-        help="Hidden dim for LSTM layers.",
+        "--rnn-hidden-size", type=int, default=1024, help="Hidden dim for LSTM layers.",
     )

     parser.add_argument(
@@ -171,10 +156,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )

     parser.add_argument(
@@ -192,10 +174,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=40,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=40, help="Number of epochs to train.",
     )

     parser.add_argument(
@@ -670,7 +649,7 @@ def compute_loss(
             f"simple_loss: {simple_loss}\n"
             f"pruned_loss: {pruned_loss}"
         )
-        display_and_save_batch(batch, params=params, sp=sp)
+        display_and_save_batch(batch, params=params)

     simple_loss = simple_loss[simple_loss_is_finite]
     pruned_loss = pruned_loss[pruned_loss_is_finite]
@@ -834,7 +813,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, sp=sp)
+            display_and_save_batch(batch, params=params)
             raise

         if params.print_diagnostics and batch_idx == 30:
@@ -846,9 +825,7 @@ def train_one_epoch(
             and params.batch_idx_train % params.average_period == 0
         ):
             update_averaged_model(
-                params=params,
-                model_cur=model,
-                model_avg=model_avg,
+                params=params, model_cur=model, model_avg=model_avg,
             )

         if (
@@ -870,9 +847,7 @@ def train_one_epoch(
             )
             del params.cur_batch_idx
             remove_checkpoints(
-                out_dir=params.exp_dir,
-                topk=params.keep_last_k,
-                rank=rank,
+                out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
             )

         if batch_idx % params.log_interval == 0 and not params.print_diagnostics:
@@ -960,10 +935,7 @@ def run(rank, world_size, args):
     sp.load(bpe_model)

     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
+    graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)

     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
@@ -1014,7 +986,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)