Update asr_datamodule.py

Mingshuang Luo 2022-02-10 16:08:22 +08:00
parent 7472ef7d0e
commit ecfb28da20


@@ -218,11 +218,9 @@ class LibriSpeechAsrDataModule:
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=10,
                     features_mask_size=27,
                     num_feature_masks=2,
                     frames_mask_size=100,
-                    max_frames_mask_fraction=0.4,
                 )
             )
         else:
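
For context (not part of this commit), the input_transforms list assembled above is what the datamodule hands to the dataset, which then applies SpecAugment to each batch of features on the fly. A minimal, hypothetical sketch, assuming lhotse's K2SpeechRecognitionDataset and its input_transforms argument, with fixed values standing in for self.args:

    from lhotse.dataset import K2SpeechRecognitionDataset, SpecAugment

    # Illustration only: mirror the post-commit call above with explicit values
    # (80 stands in for self.args.spec_aug_time_warp_factor).
    input_transforms = [
        SpecAugment(
            time_warp_factor=80,
            features_mask_size=27,
            num_feature_masks=2,
            frames_mask_size=100,
        )
    ]
    # The dataset applies every transform in input_transforms to the batched features.
    train_dataset = K2SpeechRecognitionDataset(input_transforms=input_transforms)
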
@@ -411,9 +409,9 @@ class SpecAugment(torch.nn.Module):
         time_warp_factor: Optional[int] = 80,
         num_feature_masks: int = 1,
         features_mask_size: int = 13,
-        num_frame_masks: int = 1,
+        num_frame_masks: int = 10,
         frames_mask_size: int = 70,
-        max_frames_mask_fraction: float = 0.2,
+        max_frames_mask_fraction: float = 0.4,
         p=0.5,
     ):
         """
@@ -425,10 +423,11 @@ class SpecAugment(torch.nn.Module):
         :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
             This is the ``F`` parameter from the SpecAugment paper.
         :param num_frame_masks: how many frame (temporal) masks should be applied. Set to ``0`` to disable.
+            This is the maximum (it's also constrained by max_frames_mask_fraction).
         :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
             This is the ``T`` parameter from the SpecAugment paper.
         :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
-            of the utterance (or supervision segment).
+            of the utterance (or supervision segment). It is the total masked fraction including all masked regions.
             This is the parameter denoted by ``p`` in the SpecAugment paper.
         :param p: the probability of applying this transform.
             It is different from ``p`` in the SpecAugment paper!
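
As a worked example of the constraint described above (it mirrors the budget computation shown in the next hunk; the utterance length is illustrative only):

    import math

    num_frames = 1000                # e.g. a 10 s utterance at 100 frames per second
    max_frames_mask_fraction = 0.4   # new default: total time-mask budget = 40% of the frames
    frames_mask_size = 70            # ``T`` from the SpecAugment paper
    num_frame_masks = 10             # new default: upper bound on the number of time masks

    _max_tot_mask_frames = max_frames_mask_fraction * num_frames                        # 400.0
    n_masks = min(num_frame_masks, math.ceil(_max_tot_mask_frames / frames_mask_size))  # min(10, 6) = 6
    max_mask_frames = min(frames_mask_size, _max_tot_mask_frames // n_masks)            # min(70, 66.0) = 66.0
    # At most 6 time masks of up to 66 frames each: 6 * 66 = 396 frames <= 400 (40% of 1000).
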
@@ -517,11 +516,12 @@ class SpecAugment(torch.nn.Module):
                 mask_value=mean,
                 axis=2,
             ).squeeze(0)
-        for _ in range(self.num_frame_masks):
+        _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
+        num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
+        max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
+        for _ in range(num_frame_masks):
             features = mask_along_axis(
                 features.unsqueeze(0),
+                mask_param=max_mask_frames,