From d5252a41574da40be9fc20f3fc048b7b69d60702 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Tue, 21 Nov 2023 18:49:51 +0800
Subject: [PATCH] black

---
 egs/libriheavy/ASR/local/train_bpe_model.py   |  8 ++-
 .../ASR/zipformer/asr_datamodule.py           | 20 +++++--
 egs/libriheavy/ASR/zipformer/decode.py        | 21 ++++++--
 egs/libriheavy/ASR/zipformer/train.py         | 52 +++++++++++++++----
 4 files changed, 78 insertions(+), 23 deletions(-)

diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py
index 714dc87bf..19caf43ab 100755
--- a/egs/libriheavy/ASR/local/train_bpe_model.py
+++ b/egs/libriheavy/ASR/local/train_bpe_model.py
@@ -56,11 +56,15 @@ def get_args():
     )
 
     parser.add_argument(
-        "--transcript", type=str, help="Training transcript.",
+        "--transcript",
+        type=str,
+        help="Training transcript.",
     )
 
     parser.add_argument(
-        "--vocab-size", type=int, help="Vocabulary size for BPE training",
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
     )
 
     return parser.parse_args()
diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
index ca9cd29cf..df761c1b8 100644
--- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
@@ -215,7 +215,9 @@ class LibriHeavyAsrDataModule:
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
         Args:
@@ -357,10 +359,13 @@ class LibriHeavyAsrDataModule:
             )
         else:
             validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms, return_cuts=self.args.return_cuts,
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
             )
         valid_sampler = DynamicBucketingSampler(
-            cuts_valid, max_duration=self.args.max_duration, shuffle=False,
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
@@ -382,11 +387,16 @@ class LibriHeavyAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts, max_duration=self.args.max_duration, shuffle=False,
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
-            test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
         )
         return test_dl
 
diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py
index 8227fccdb..1928e2635 100644
--- a/egs/libriheavy/ASR/zipformer/decode.py
+++ b/egs/libriheavy/ASR/zipformer/decode.py
@@ -174,7 +174,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
+        "--exp-dir",
+        type=str,
+        default="zipformer/exp",
+        help="The experiment dir",
     )
 
     parser.add_argument(
@@ -349,7 +352,9 @@ def decode_one_batch(
         pad_len = 30
         feature_lens += pad_len
         feature = torch.nn.functional.pad(
-            feature, pad=(0, 0, 0, pad_len), value=LOG_EPS,
+            feature,
+            pad=(0, 0, 0, pad_len),
+            value=LOG_EPS,
         )
 
     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
@@ -399,7 +404,9 @@ def decode_one_batch(
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -427,7 +434,9 @@ def decode_one_batch(
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model, encoder_out=encoder_out_i, beam=params.beam_size,
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
                 )
             else:
                 raise ValueError(
@@ -773,7 +782,9 @@ def main():
         )
 
         save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict,
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
         )
 
     logging.info("Done!")
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
index b9e0dffd7..577640735 100644
--- a/egs/libriheavy/ASR/zipformer/train.py
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -255,7 +255,10 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--use-ctc", type=str2bool, default=False, help="If True, use CTC head.",
+        "--use-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use CTC head.",
     )
 
 
@@ -265,7 +268,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -283,7 +289,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs", type=int, default=30, help="Number of epochs to train.",
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -391,7 +400,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--ctc-loss-scale", type=float, default=0.2, help="Scale for CTC loss.",
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
     )
 
     parser.add_argument(
@@ -853,7 +865,11 @@ def compute_validation_loss(
 
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
-            params=params, model=model, sp=sp, batch=batch, is_training=False,
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -943,7 +959,11 @@ def train_one_epoch(
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
-                    params=params, model=model, sp=sp, batch=batch, is_training=True,
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -974,7 +994,9 @@ def train_one_epoch(
                 and params.batch_idx_train % params.average_period == 0
             ):
                 update_averaged_model(
-                    params=params, model_cur=model, model_avg=model_avg,
+                    params=params,
+                    model_cur=model,
+                    model_avg=model_avg,
                 )
 
             if (
@@ -994,7 +1016,9 @@ def train_one_epoch(
                     rank=rank,
                 )
                 remove_checkpoints(
-                    out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
+                    out_dir=params.exp_dir,
+                    topk=params.keep_last_k,
+                    rank=rank,
                 )
 
             if batch_idx % 100 == 0 and params.use_fp16:
@@ -1156,7 +1180,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1297,7 +1321,9 @@ def run(rank, world_size, args):
 
 
 def display_and_save_batch(
-    batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor,
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
 
@@ -1344,7 +1370,11 @@ def scan_pessimistic_batches_for_oom(
     try:
         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
-                params=params, model=model, sp=sp, batch=batch, is_training=True,
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
             )
         loss.backward()
         optimizer.zero_grad()