mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Update asr_datamodule.py
This commit is contained in:
parent
7472ef7d0e
commit
ecfb28da20
@ -218,11 +218,9 @@ class LibriSpeechAsrDataModule:
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=10,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
max_frames_mask_fraction=0.4,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -411,9 +409,9 @@ class SpecAugment(torch.nn.Module):
|
||||
time_warp_factor: Optional[int] = 80,
|
||||
num_feature_masks: int = 1,
|
||||
features_mask_size: int = 13,
|
||||
num_frame_masks: int = 1,
|
||||
num_frame_masks: int = 10,
|
||||
frames_mask_size: int = 70,
|
||||
max_frames_mask_fraction: float = 0.2,
|
||||
max_frames_mask_fraction: float = 0.4,
|
||||
p=0.5,
|
||||
):
|
||||
"""
|
||||
@ -425,10 +423,11 @@ class SpecAugment(torch.nn.Module):
|
||||
:param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
|
||||
This is the ``F`` parameter from the SpecAugment paper.
|
||||
:param num_frame_masks: how many frame (temporal) masks should be applied. Set to ``0`` to disable.
|
||||
This is the maximum (it's also constrained by max_frames_mask_fraction).
|
||||
:param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
|
||||
This is the ``T`` parameter from the SpecAugment paper.
|
||||
:param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
|
||||
of the utterance (or supervision segment).
|
||||
of the utterance (or supervision segment). It is the total masked fraction including all masked regions.
|
||||
This is the parameter denoted by ``p`` in the SpecAugment paper.
|
||||
:param p: the probability of applying this transform.
|
||||
It is different from ``p`` in the SpecAugment paper!
|
||||
@ -517,11 +516,12 @@ class SpecAugment(torch.nn.Module):
|
||||
mask_value=mean,
|
||||
axis=2,
|
||||
).squeeze(0)
|
||||
for _ in range(self.num_frame_masks):
|
||||
_max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
||||
num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
|
||||
max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
|
||||
|
||||
|
||||
_max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
||||
num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
|
||||
max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
|
||||
|
||||
for _ in range(num_frame_masks):
|
||||
features = mask_along_axis(
|
||||
features.unsqueeze(0),
|
||||
mask_param=max_mask_frames,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user