diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e5fcc5893..a5ab012e3 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -220,7 +220,7 @@ class LibriSpeechAsrDataModule: time_warp_factor=self.args.spec_aug_time_warp_factor, num_frame_masks=10, features_mask_size=27, - num_feature_masks=2, + num_feature_masks=10, frames_mask_size=100, max_frames_mask_fraction=0.4, ) @@ -521,7 +521,7 @@ class SpecAugment(torch.nn.Module): _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) - + features = mask_along_axis( features.unsqueeze(0), mask_param=max_mask_frames, @@ -591,4 +591,4 @@ def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: mode="bicubic", align_corners=False, ) - return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) \ No newline at end of file + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0)