Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 16:14:17 +00:00
pre commit hook

commit 4fc1638959
parent 02b4b469a2
@@ -19,14 +19,13 @@ import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import Optional, Dict, Any
-
-from tqdm import tqdm
+from typing import Any, Dict, Optional
 
+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
-    CutMix,
     CutConcatenate,
+    CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
@@ -35,7 +34,7 @@ from lhotse.dataset import (
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
-import torch
+from tqdm import tqdm
 
 from icefall.utils import str2bool
 
@@ -177,13 +176,17 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
@@ -205,7 +208,9 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -222,7 +227,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
             )
         else:
@@ -275,7 +282,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -319,7 +328,9 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
@@ -119,7 +119,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -192,7 +193,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
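
For context, a data module like the one reformatted above is normally driven from a recipe script. The sketch below is a hypothetical usage example, not part of this commit; it assumes the usual icefall conventions (an add_arguments() classmethod, a train_dataloaders() method taking a CutSet, and the module living in asr_datamodule.py next to the calling script).

# Hypothetical usage sketch of SPGISpeechAsrDataModule; names and signatures
# follow common icefall recipe conventions and are assumptions, not taken
# verbatim from this diff.
import argparse

from asr_datamodule import SPGISpeechAsrDataModule  # assumed module path


def run():
    parser = argparse.ArgumentParser()
    SPGISpeechAsrDataModule.add_arguments(parser)  # assumed classmethod
    args = parser.parse_args()

    spgispeech = SPGISpeechAsrDataModule(args)

    # train_cuts() lazily loads cuts_train_shuf.jsonl.gz, as in the hunk above.
    train_cuts = spgispeech.train_cuts()
    train_dl = spgispeech.train_dataloaders(train_cuts)

    for batch in train_dl:
        # Batches follow lhotse's K2SpeechRecognitionDataset layout:
        # "inputs" holds the feature tensor, "supervisions" the metadata.
        print(batch["inputs"].shape)
        break


if __name__ == "__main__":
    run()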