mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
pre commit hook
This commit is contained in:
parent
02b4b469a2
commit
4fc1638959
@ -19,14 +19,13 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||||
from lhotse.dataset import (
|
from lhotse.dataset import (
|
||||||
CutMix,
|
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
|
CutMix,
|
||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
@ -35,7 +34,7 @@ from lhotse.dataset import (
|
|||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch
|
from tqdm import tqdm
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
@ -177,13 +176,17 @@ class SPGISpeechAsrDataModule:
|
|||||||
The state dict for the training sampler.
|
The state dict for the training sampler.
|
||||||
"""
|
"""
|
||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
|
cuts_musan = load_manifest(
|
||||||
|
self.args.manifest_dir / "cuts_musan.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
transforms = []
|
transforms = []
|
||||||
if self.args.enable_musan:
|
if self.args.enable_musan:
|
||||||
logging.info("Enable MUSAN")
|
logging.info("Enable MUSAN")
|
||||||
transforms.append(
|
transforms.append(
|
||||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
CutMix(
|
||||||
|
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Disable MUSAN")
|
logging.info("Disable MUSAN")
|
||||||
@ -205,7 +208,9 @@ class SPGISpeechAsrDataModule:
|
|||||||
input_transforms = []
|
input_transforms = []
|
||||||
if self.args.enable_spec_aug:
|
if self.args.enable_spec_aug:
|
||||||
logging.info("Enable SpecAugment")
|
logging.info("Enable SpecAugment")
|
||||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
logging.info(
|
||||||
|
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||||
|
)
|
||||||
input_transforms.append(
|
input_transforms.append(
|
||||||
SpecAugment(
|
SpecAugment(
|
||||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||||
@ -222,7 +227,9 @@ class SPGISpeechAsrDataModule:
|
|||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(
|
||||||
|
Fbank(FbankConfig(num_mel_bins=80))
|
||||||
|
),
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -275,7 +282,9 @@ class SPGISpeechAsrDataModule:
|
|||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(
|
||||||
|
Fbank(FbankConfig(num_mel_bins=80))
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
@ -319,7 +328,9 @@ class SPGISpeechAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get SPGISpeech train cuts")
|
logging.info("About to get SPGISpeech train cuts")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def dev_cuts(self) -> CutSet:
|
def dev_cuts(self) -> CutSet:
|
||||||
|
@ -119,7 +119,8 @@ def get_parser():
|
|||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -192,7 +193,9 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user