pre commit hook

Desh Raj 2022-05-14 10:41:06 -04:00
parent 02b4b469a2
commit 4fc1638959
2 changed files with 27 additions and 13 deletions

View File

@@ -19,14 +19,13 @@ import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import Optional, Dict, Any
-from tqdm import tqdm
+from typing import Any, Dict, Optional
+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
-    CutMix,
     CutConcatenate,
+    CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
@@ -35,7 +34,7 @@ from lhotse.dataset import (
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
-import torch
+from tqdm import tqdm
 
 from icefall.utils import str2bool
@@ -177,13 +176,17 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
@@ -205,7 +208,9 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@ -222,7 +227,9 @@ class SPGISpeechAsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
) )
else: else:
@@ -275,7 +282,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -319,7 +328,9 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
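The hunks above only re-wrap existing calls; the data-loading logic itself is unchanged. For orientation, here is a minimal sketch of how the objects touched in this file (MUSAN CutMix, SpecAugment, on-the-fly Fbank features, K2SpeechRecognitionDataset, DynamicBucketingSampler) are typically wired into a training DataLoader in Lhotse-based recipes. The parameter values for CutMix, SpecAugment, and Fbank mirror the diff; the manifest paths, max_duration, and num_workers are illustrative assumptions, not values taken from this file.

from lhotse import Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader

# Hypothetical manifest locations; the recipe reads them from --manifest-dir.
cuts_train = load_manifest_lazy("data/manifests/cuts_train_shuf.jsonl.gz")
cuts_musan = load_manifest("data/manifests/cuts_musan.jsonl.gz")

# Cut-level transform: mix MUSAN noise into ~50% of cuts at 10-20 dB SNR.
cut_transforms = [
    CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
# Feature-level transform: SpecAugment (only time warping shown here).
input_transforms = [SpecAugment(time_warp_factor=80)]

train_dataset = K2SpeechRecognitionDataset(
    cut_transforms=cut_transforms,
    input_transforms=input_transforms,
    # 80-dim log-Mel filterbanks computed on the fly instead of precomputed.
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
)
# Batch by total audio duration rather than a fixed batch size (assumed 200 s cap).
train_sampler = DynamicBucketingSampler(cuts_train, max_duration=200.0, shuffle=True)
train_loader = DataLoader(
    train_dataset, sampler=train_sampler, batch_size=None, num_workers=2
)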

View File

@ -119,7 +119,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
return parser return parser
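In the hunk above, the help text is split using Python's implicit concatenation of adjacent string literals, so the rendered help message is identical before and after the re-wrapping. A minimal sketch (the parser description is illustrative):

import argparse

parser = argparse.ArgumentParser(description="illustrative parser")
parser.add_argument(
    "--context-size",
    type=int,
    default=2,
    help="The context size in the decoder. 1 means bigram; "
    "2 means tri-gram",  # adjacent literals join into one help string
)
args = parser.parse_args(["--context-size", "3"])
print(args.context_size)  # prints: 3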
@@ -192,7 +193,9 @@ def main():
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
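The last hunk only wraps the log-format assignment in parentheses; the configuration itself is unchanged. A short sketch of the same logging setup (the logged message is illustrative):

import logging

formatter = (
    "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
# Example output: <timestamp> INFO [<filename>:<lineno>] decoding started
logging.info("decoding started")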