From 4fc1638959b4c053b3a4e81e24a06c4b818a7d34 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Sat, 14 May 2022 10:41:06 -0400 Subject: [PATCH] pre commit hook --- .../asr_datamodule.py | 33 ++++++++++++------- .../pruned_transducer_stateless2/export.py | 7 ++-- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 0d76b7d4d..f165f6e60 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -19,14 +19,13 @@ import argparse import logging from functools import lru_cache from pathlib import Path -from typing import Optional, Dict, Any - -from tqdm import tqdm +from typing import Any, Dict, Optional +import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse.dataset import ( - CutMix, CutConcatenate, + CutMix, DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, @@ -35,7 +34,7 @@ from lhotse.dataset import ( from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader -import torch +from tqdm import tqdm from icefall.utils import str2bool @@ -177,13 +176,17 @@ class SPGISpeechAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") @@ -205,7 +208,9 @@ class SPGISpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -222,7 +227,9 @@ class SPGISpeechAsrDataModule: if self.args.on_the_fly_feats: train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, ) else: @@ -275,7 +282,9 @@ class SPGISpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), ) else: validate = K2SpeechRecognitionDataset( @@ -319,7 +328,9 @@ class SPGISpeechAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get SPGISpeech train cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" + ) @lru_cache() def dev_cuts(self) -> CutSet: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py index a5eca6e2d..6119ecf2c 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py @@ -119,7 +119,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -192,7 +193,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main()