mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 14:14:19 +00:00
change asr_datamodule.py
This commit is contained in:
parent
7f9e426878
commit
c26a7e4dc4
@ -17,6 +17,7 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
@ -209,10 +210,22 @@ class TedLiumAsrDataModule:
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
# Judge num_frame_masks for SpecAugment according to Lhotse'version
|
||||
num_frame_masks = (
|
||||
2
|
||||
if (
|
||||
inspect.signature(SpecAugment.__init__)
|
||||
.parameters["num_frame_masks"]
|
||||
.default
|
||||
== 1
|
||||
)
|
||||
else 10
|
||||
)
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=10,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
@ -229,7 +242,6 @@ class TedLiumAsrDataModule:
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
@ -268,7 +280,7 @@ class TedLiumAsrDataModule:
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
# print(train)
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
@ -341,6 +353,7 @@ class TedLiumAsrDataModule:
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
print(self.args.manifest_dir)
|
||||
return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
|
Loading…
x
Reference in New Issue
Block a user