diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index 02b383fec..537a9042b 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -210,17 +210,17 @@ class TedLiumAsrDataModule: logging.info( f"Time warp factor: {self.args.spec_aug_time_warp_factor}" ) - # Judge num_frame_masks according to Lhotse's version - num_frame_masks = ( - 2 - if ( - inspect.signature(SpecAugment.__init__) - .parameters["num_frame_masks"] - .default - == 1 - ) - else 10 - ) + # Design the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + else: + num_frame_masks = 10 logging.info(f"Num frame mask: {num_frame_masks}") input_transforms.append( SpecAugment(