mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update asr_datamodule.py
This commit is contained in:
parent
7472ef7d0e
commit
ecfb28da20
@ -218,11 +218,9 @@ class LibriSpeechAsrDataModule:
|
|||||||
input_transforms.append(
|
input_transforms.append(
|
||||||
SpecAugment(
|
SpecAugment(
|
||||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||||
num_frame_masks=10,
|
|
||||||
features_mask_size=27,
|
features_mask_size=27,
|
||||||
num_feature_masks=2,
|
num_feature_masks=2,
|
||||||
frames_mask_size=100,
|
frames_mask_size=100,
|
||||||
max_frames_mask_fraction=0.4,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -411,9 +409,9 @@ class SpecAugment(torch.nn.Module):
|
|||||||
time_warp_factor: Optional[int] = 80,
|
time_warp_factor: Optional[int] = 80,
|
||||||
num_feature_masks: int = 1,
|
num_feature_masks: int = 1,
|
||||||
features_mask_size: int = 13,
|
features_mask_size: int = 13,
|
||||||
num_frame_masks: int = 1,
|
num_frame_masks: int = 10,
|
||||||
frames_mask_size: int = 70,
|
frames_mask_size: int = 70,
|
||||||
max_frames_mask_fraction: float = 0.2,
|
max_frames_mask_fraction: float = 0.4,
|
||||||
p=0.5,
|
p=0.5,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -425,10 +423,11 @@ class SpecAugment(torch.nn.Module):
|
|||||||
:param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
|
:param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
|
||||||
This is the ``F`` parameter from the SpecAugment paper.
|
This is the ``F`` parameter from the SpecAugment paper.
|
||||||
:param num_frame_masks: how many frame (temporal) masks should be applied. Set to ``0`` to disable.
|
:param num_frame_masks: how many frame (temporal) masks should be applied. Set to ``0`` to disable.
|
||||||
|
This is the maximum (it's also constrained by max_frames_mask_fraction).
|
||||||
:param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
|
:param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
|
||||||
This is the ``T`` parameter from the SpecAugment paper.
|
This is the ``T`` parameter from the SpecAugment paper.
|
||||||
:param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
|
:param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
|
||||||
of the utterance (or supervision segment).
|
of the utterance (or supervision segment). It is the total masked fraction including all masked regions.
|
||||||
This is the parameter denoted by ``p`` in the SpecAugment paper.
|
This is the parameter denoted by ``p`` in the SpecAugment paper.
|
||||||
:param p: the probability of applying this transform.
|
:param p: the probability of applying this transform.
|
||||||
It is different from ``p`` in the SpecAugment paper!
|
It is different from ``p`` in the SpecAugment paper!
|
||||||
@ -517,11 +516,12 @@ class SpecAugment(torch.nn.Module):
|
|||||||
mask_value=mean,
|
mask_value=mean,
|
||||||
axis=2,
|
axis=2,
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
for _ in range(self.num_frame_masks):
|
|
||||||
_max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
_max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
||||||
num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
|
num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
|
||||||
max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
|
max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
|
||||||
|
|
||||||
|
for _ in range(num_frame_masks):
|
||||||
features = mask_along_axis(
|
features = mask_along_axis(
|
||||||
features.unsqueeze(0),
|
features.unsqueeze(0),
|
||||||
mask_param=max_mask_frames,
|
mask_param=max_mask_frames,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user