change asr_datamodule.py

This commit is contained in:
luomingshuang 2022-03-07 15:59:47 +08:00
parent 7f9e426878
commit c26a7e4dc4

View File

@ -17,6 +17,7 @@
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
@ -209,10 +210,22 @@ class TedLiumAsrDataModule:
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set num_frame_masks for SpecAugment according to Lhotse's version
num_frame_masks = (
2
if (
inspect.signature(SpecAugment.__init__)
.parameters["num_frame_masks"]
.default
== 1
)
else 10
)
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=10,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
@ -229,7 +242,6 @@ class TedLiumAsrDataModule:
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
@ -268,7 +280,7 @@ class TedLiumAsrDataModule:
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
# print(train)
train_dl = DataLoader(
train,
sampler=train_sampler,
@ -341,6 +353,7 @@ class TedLiumAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
    """Load and return the training cuts manifest.

    Reads ``cuts_train.json.gz`` from ``self.args.manifest_dir`` via
    ``load_manifest``. The result is memoized by ``lru_cache``, so repeated
    calls on the same instance reuse the first loaded object.

    Returns:
        Whatever ``load_manifest`` yields for the training manifest
        (annotated as a ``CutSet``).
    """
    # NOTE(review): `manifest_dir` is combined with `/`, so it is presumably
    # a pathlib.Path - confirm against the argument parser.
    manifest_dir = self.args.manifest_dir
    logging.info("About to get train cuts")
    # Report the directory through logging rather than a bare print() so the
    # message follows the configured log handlers like the rest of the module.
    logging.info(f"Manifest dir: {manifest_dir}")
    return load_manifest(manifest_dir / "cuts_train.json.gz")
@lru_cache()