Fix num_time_masks code; revert 0.8 to 0.9

This commit is contained in:
Daniel Povey 2022-02-10 15:53:11 +08:00
parent c170c53006
commit 4cd2c02fff
2 changed files with 6 additions and 7 deletions

View File

@ -223,7 +223,7 @@ class LibriSpeechAsrDataModule:
num_feature_masks=2,
frames_mask_size=100,
max_frames_mask_fraction=0.2,
p=0.8
p=0.9
)
)
else:
@ -518,11 +518,10 @@ class SpecAugment(torch.nn.Module):
mask_value=mean,
axis=2,
).squeeze(0)
for _ in range(self.num_frame_masks):
_max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
for _ in range(num_frame_masks):
features = mask_along_axis(
features.unsqueeze(0),
mask_param=max_mask_frames,

View File

@ -109,7 +109,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/exp-100h-relu-specaugmod_p0.8_0.2",
default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2_fix",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved