mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
change asr_datamodule.py
This commit is contained in:
parent
7f9e426878
commit
c26a7e4dc4
@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -209,10 +210,22 @@ class TedLiumAsrDataModule:
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||||
)
|
)
|
||||||
|
# Judge num_frame_masks for SpecAugment according to Lhotse'version
|
||||||
|
num_frame_masks = (
|
||||||
|
2
|
||||||
|
if (
|
||||||
|
inspect.signature(SpecAugment.__init__)
|
||||||
|
.parameters["num_frame_masks"]
|
||||||
|
.default
|
||||||
|
== 1
|
||||||
|
)
|
||||||
|
else 10
|
||||||
|
)
|
||||||
|
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||||
input_transforms.append(
|
input_transforms.append(
|
||||||
SpecAugment(
|
SpecAugment(
|
||||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||||
num_frame_masks=10,
|
num_frame_masks=num_frame_masks,
|
||||||
features_mask_size=27,
|
features_mask_size=27,
|
||||||
num_feature_masks=2,
|
num_feature_masks=2,
|
||||||
frames_mask_size=100,
|
frames_mask_size=100,
|
||||||
@ -229,7 +242,6 @@ class TedLiumAsrDataModule:
|
|||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
# NOTE: the PerturbSpeed transform should be added only if we
|
# NOTE: the PerturbSpeed transform should be added only if we
|
||||||
# remove it from data prep stage.
|
# remove it from data prep stage.
|
||||||
@ -268,7 +280,7 @@ class TedLiumAsrDataModule:
|
|||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
# print(train)
|
||||||
train_dl = DataLoader(
|
train_dl = DataLoader(
|
||||||
train,
|
train,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
@ -341,6 +353,7 @@ class TedLiumAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info("About to get train cuts")
|
||||||
|
print(self.args.manifest_dir)
|
||||||
return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
|
return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user