diff --git a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py index e5fcc5893..7570faccc 120000 --- a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py @@ -218,11 +218,9 @@ class LibriSpeechAsrDataModule: input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=10, features_mask_size=27, num_feature_masks=2, frames_mask_size=100, - max_frames_mask_fraction=0.4, ) ) else: @@ -411,9 +409,9 @@ class SpecAugment(torch.nn.Module): time_warp_factor: Optional[int] = 80, num_feature_masks: int = 1, features_mask_size: int = 13, - num_frame_masks: int = 1, + num_frame_masks: int = 10, frames_mask_size: int = 70, - max_frames_mask_fraction: float = 0.2, + max_frames_mask_fraction: float = 0.4, p=0.5, ): """ @@ -425,10 +423,11 @@ class SpecAugment(torch.nn.Module): :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). This is the ``F`` parameter from the SpecAugment paper. :param num_frame_masks: how many frame (temporal) masks should be applied. Set to ``0`` to disable. + This is the maximum (it's also constrained by max_frames_mask_fraction). :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). This is the ``T`` parameter from the SpecAugment paper. :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length - of the utterance (or supervision segment). + of the utterance (or supervision segment). It is the total masked fraction including all masked regions. This is the parameter denoted by ``p`` in the SpecAugment paper. :param p: the probability of applying this transform. It is different from ``p`` in the SpecAugment paper! @@ -517,11 +516,12 @@ class SpecAugment(torch.nn.Module): mask_value=mean, axis=2, ).squeeze(0) - for _ in range(self.num_frame_masks): - _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) - num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) - max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) - + + _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) + max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) + + for _ in range(num_frame_masks): features = mask_along_axis( features.unsqueeze(0), mask_param=max_mask_frames,