change asr_datamodule.py

This commit is contained in:
luomingshuang 2022-03-07 15:59:47 +08:00
parent 7f9e426878
commit c26a7e4dc4

View File

@ -17,6 +17,7 @@
import argparse import argparse
import inspect
import logging import logging
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
@ -209,10 +210,22 @@ class TedLiumAsrDataModule:
logging.info( logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}" f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
) )
# Choose num_frame_masks for SpecAugment according to Lhotse's version:
# use 2 when SpecAugment's default num_frame_masks is 1, otherwise 10.
num_frame_masks = (
2
if (
inspect.signature(SpecAugment.__init__)
.parameters["num_frame_masks"]
.default
== 1
)
else 10
)
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append( input_transforms.append(
SpecAugment( SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor, time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=10, num_frame_masks=num_frame_masks,
features_mask_size=27, features_mask_size=27,
num_feature_masks=2, num_feature_masks=2,
frames_mask_size=100, frames_mask_size=100,
@ -229,7 +242,6 @@ class TedLiumAsrDataModule:
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we # NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage. # remove it from data prep stage.
@ -268,7 +280,7 @@ class TedLiumAsrDataModule:
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
) )
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
# print(train)
train_dl = DataLoader( train_dl = DataLoader(
train, train,
sampler=train_sampler, sampler=train_sampler,
@ -341,6 +353,7 @@ class TedLiumAsrDataModule:
@lru_cache() @lru_cache()
def train_cuts(self) -> CutSet: def train_cuts(self) -> CutSet:
logging.info("About to get train cuts") logging.info("About to get train cuts")
print(self.args.manifest_dir)
return load_manifest(self.args.manifest_dir / "cuts_train.json.gz") return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
@lru_cache() @lru_cache()