pre commit hook

This commit is contained in:
Desh Raj 2022-05-14 10:41:06 -04:00
parent 02b4b469a2
commit 4fc1638959
2 changed files with 27 additions and 13 deletions

View File

@ -19,14 +19,13 @@ import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Optional, Dict, Any
from tqdm import tqdm
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutMix,
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
@ -35,7 +34,7 @@ from lhotse.dataset import (
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
from icefall.utils import str2bool
@ -177,13 +176,17 @@ class SPGISpeechAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
@ -205,7 +208,9 @@ class SPGISpeechAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
@ -222,7 +227,9 @@ class SPGISpeechAsrDataModule:
if self.args.on_the_fly_feats:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
)
else:
@ -275,7 +282,9 @@ class SPGISpeechAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
)
else:
validate = K2SpeechRecognitionDataset(
@ -319,7 +328,9 @@ class SPGISpeechAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get SPGISpeech train cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:

View File

@ -119,7 +119,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -192,7 +193,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()