From 2f0d3d7ae37381653829c98aa2a416c8bc569608 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Tue, 21 Nov 2023 18:42:17 +0800
Subject: [PATCH] black

---
 egs/libriheavy/ASR/local/norm_text.py        |  3 +-
 egs/libriheavy/ASR/local/prepare_manifest.py |  6 +-
 egs/libriheavy/ASR/local/train_bpe_model.py  | 12 +--
 .../ASR/zipformer/asr_datamodule.py          | 44 +++++-----
 egs/libriheavy/ASR/zipformer/train.py        | 85 ++++++-------------
 5 files changed, 61 insertions(+), 89 deletions(-)

diff --git a/egs/libriheavy/ASR/local/norm_text.py b/egs/libriheavy/ASR/local/norm_text.py
index 99f59a320..c2fc0d92d 100755
--- a/egs/libriheavy/ASR/local/norm_text.py
+++ b/egs/libriheavy/ASR/local/norm_text.py
@@ -19,6 +19,7 @@ import argparse
 import codecs
 import sys
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -52,6 +53,6 @@ def main():
         print(remove_punc_to_upper(line))
         line = f.readline()
 
-    
+
 if __name__ == "__main__":
     main()
diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py
index 720455e20..42f392cae 100755
--- a/egs/libriheavy/ASR/local/prepare_manifest.py
+++ b/egs/libriheavy/ASR/local/prepare_manifest.py
@@ -20,11 +20,13 @@ import json
 import sys
 from pathlib import Path
 
+
 def simple_cleanup(text: str) -> str:
     table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
     text = text.translate(table)
     return text.strip()
 
+
 # Assign text of the supervisions and remove unnecessary entries.
 def main():
     assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
@@ -33,7 +35,9 @@ def main():
     with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
         for line in fin:
             cut = json.loads(line)
-            cut["supervisions"][0]["text"] = simple_cleanup(cut["supervisions"][0]["custom"]["texts"][0])
+            cut["supervisions"][0]["text"] = simple_cleanup(
+                cut["supervisions"][0]["custom"]["texts"][0]
+            )
             del cut["supervisions"][0]["custom"]
             del cut["custom"]
             fout.write((json.dumps(cut) + "\n").encode())
diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py
index 4da3e097e..714dc87bf 100755
--- a/egs/libriheavy/ASR/local/train_bpe_model.py
+++ b/egs/libriheavy/ASR/local/train_bpe_model.py
@@ -44,8 +44,8 @@ def get_args():
 
     parser.add_argument(
         "--byte-fallback",
-        action='store_true',
-        help="""Whether to enable byte_fallback when training bpe."""
+        action="store_true",
+        help="""Whether to enable byte_fallback when training bpe.""",
     )
 
     parser.add_argument(
@@ -56,15 +56,11 @@ def get_args():
     )
 
     parser.add_argument(
-        "--transcript",
-        type=str,
-        help="Training transcript.",
+        "--transcript", type=str, help="Training transcript.",
     )
 
     parser.add_argument(
-        "--vocab-size",
-        type=int,
-        help="Vocabulary size for BPE training",
+        "--vocab-size", type=int, help="Vocabulary size for BPE training",
     )
 
     return parser.parse_args()
diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
index 1a6d833a6..ca9cd29cf 100644
--- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
@@ -215,9 +215,7 @@ class LibriHeavyAsrDataModule:
         )
 
     def train_dataloaders(
-        self,
-        cuts_train: CutSet,
-        sampler_state_dict: Optional[Dict[str, Any]] = None,
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
         Args:
@@ -359,13 +357,10 @@ class LibriHeavyAsrDataModule:
             )
         else:
             validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                return_cuts=self.args.return_cuts,
+                cut_transforms=transforms, return_cuts=self.args.return_cuts,
             )
         valid_sampler = DynamicBucketingSampler(
-            cuts_valid,
-            max_duration=self.args.max_duration,
-            shuffle=False,
+            cuts_valid, max_duration=self.args.max_duration, shuffle=False,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
@@ -387,45 +382,52 @@ class LibriHeavyAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts,
-            max_duration=self.args.max_duration,
-            shuffle=False,
+            cuts, max_duration=self.args.max_duration, shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
-            test,
-            batch_size=None,
-            sampler=sampler,
-            num_workers=self.args.num_workers,
+            test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
        )
         return test_dl
 
     @lru_cache()
     def train_small_cuts(self) -> CutSet:
         logging.info("About to get small subset cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz"
+        )
 
     @lru_cache()
     def train_medium_cuts(self) -> CutSet:
         logging.info("About to get medium subset cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
+        )
 
     @lru_cache()
     def train_large_cuts(self) -> CutSet:
         logging.info("About to get large subset cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_clean_cuts(self) -> CutSet:
         logging.info("About to get the test-clean cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz"
+        )
 
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get the test-other cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz"
+        )
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
index b60d7d43b..b9e0dffd7 100644
--- a/egs/libriheavy/ASR/zipformer/train.py
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -255,24 +255,17 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--use-ctc",
-        type=str2bool,
-        default=False,
-        help="If True, use CTC head.",
+        "--use-ctc", type=str2bool, default=False, help="If True, use CTC head.",
     )
 
 
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -290,10 +283,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=30,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=30, help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -401,10 +391,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--ctc-loss-scale",
-        type=float,
-        default=0.2,
-        help="Scale for CTC loss.",
+        "--ctc-loss-scale", type=float, default=0.2, help="Scale for CTC loss.",
     )
 
     parser.add_argument(
@@ -615,11 +602,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )
 
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -797,12 +784,12 @@ def compute_loss(
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
-    
+
     texts = batch["supervisions"]["text"]
-    
+
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
-    
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss = model(
             x=feature,
@@ -820,17 +807,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
@@ -867,11 +853,7 @@ def compute_validation_loss(
 
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=False,
+            params=params, model=model, sp=sp, batch=batch, is_training=False,
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -961,11 +943,7 @@ def train_one_epoch(
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    batch=batch,
-                    is_training=True,
+                    params=params, model=model, sp=sp, batch=batch, is_training=True,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -975,7 +953,9 @@ def train_one_epoch(
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
             # Use the number of hours of speech to adjust the learning rate
-            scheduler.step_epoch(params.batch_idx_train * params.max_duration * params.world_size / 3600)
+            scheduler.step_epoch(
+                params.batch_idx_train * params.max_duration * params.world_size / 3600
+            )
 
             scaler.step(optimizer)
             scaler.update()
@@ -994,9 +974,7 @@ def train_one_epoch(
             and params.batch_idx_train % params.average_period == 0
         ):
             update_averaged_model(
-                params=params,
-                model_cur=model,
-                model_avg=model_avg,
+                params=params, model_cur=model, model_avg=model_avg,
             )
 
         if (
@@ -1016,9 +994,7 @@ def train_one_epoch(
                 rank=rank,
             )
             remove_checkpoints(
-                out_dir=params.exp_dir,
-                topk=params.keep_last_k,
-                rank=rank,
+                out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
             )
 
         if batch_idx % 100 == 0 and params.use_fp16:
@@ -1180,14 +1156,13 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
     if params.inf_check:
         register_inf_check_hooks(model)
 
-
     def normalize_text(c: Cut):
         text = remove_punc_to_upper(c.supervisions[0].text)
         c.supervisions[0].text = text
@@ -1233,7 +1208,7 @@ def run(rank, world_size, args):
     libriheavy = LibriHeavyAsrDataModule(args)
 
     train_cuts = libriheavy.train_small_cuts()
-    if params.subset == 'M' or params.subset == 'L':
+    if params.subset == "M" or params.subset == "L":
         train_cuts += libriheavy.train_medium_cuts()
     if params.subset == "L":
         train_cuts += libriheavy.train_large_cuts()
@@ -1322,9 +1297,7 @@ def run(rank, world_size, args):
 
 
 def display_and_save_batch(
-    batch: dict,
-    params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
+    batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
 
@@ -1371,11 +1344,7 @@ def scan_pessimistic_batches_for_oom(
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    batch=batch,
-                    is_training=True,
+                    params=params, model=model, sp=sp, batch=batch, is_training=True,
                 )
             loss.backward()
             optimizer.zero_grad()