From b67c57de928a0bd6ef743ad20d4f9e4875c44087 Mon Sep 17 00:00:00 2001 From: luomingshuang <739314837@qq.com> Date: Mon, 7 Mar 2022 16:32:52 +0800 Subject: [PATCH] change for asr_datamodule.py --- .../transducer_stateless/asr_datamodule.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index 02b383fec..537a9042b 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -210,17 +210,17 @@ class TedLiumAsrDataModule: logging.info( f"Time warp factor: {self.args.spec_aug_time_warp_factor}" ) - # Judge num_frame_masks according to Lhotse's version - num_frame_masks = ( - 2 - if ( - inspect.signature(SpecAugment.__init__) - .parameters["num_frame_masks"] - .default - == 1 - ) - else 10 - ) + # Design the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + else: + num_frame_masks = 10 logging.info(f"Num frame mask: {num_frame_masks}") input_transforms.append( SpecAugment(