diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index 7b35166a2..416d2d373 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -17,6 +17,7 @@ import argparse +import inspect import logging from functools import lru_cache from pathlib import Path @@ -209,10 +210,22 @@ class TedLiumAsrDataModule: logging.info( f"Time warp factor: {self.args.spec_aug_time_warp_factor}" ) + # Choose num_frame_masks for SpecAugment according to Lhotse's version + num_frame_masks = ( + 2 + if ( + inspect.signature(SpecAugment.__init__) + .parameters["num_frame_masks"] + .default + == 1 + ) + else 10 + ) + logging.info(f"Num frame mask: {num_frame_masks}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=10, + num_frame_masks=num_frame_masks, features_mask_size=27, num_feature_masks=2, frames_mask_size=100, @@ -229,7 +242,6 @@ class TedLiumAsrDataModule: input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) - if self.args.on_the_fly_feats: # NOTE: the PerturbSpeed transform should be added only if we # remove it from data prep stage. @@ -268,7 +280,7 @@ class TedLiumAsrDataModule: shuffle=self.args.shuffle, ) logging.info("About to create train dataloader") - + # print(train) train_dl = DataLoader( train, sampler=train_sampler, @@ -341,6 +353,7 @@ class TedLiumAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") + print(self.args.manifest_dir) return load_manifest(self.args.manifest_dir / "cuts_train.json.gz") @lru_cache()