Fix num_time_masks code; revert 0.8 to 0.9

This commit is contained in:
Daniel Povey 2022-02-10 15:53:11 +08:00
parent c170c53006
commit 4cd2c02fff
2 changed files with 6 additions and 7 deletions

View File

@ -223,7 +223,7 @@ class LibriSpeechAsrDataModule:
num_feature_masks=2, num_feature_masks=2,
frames_mask_size=100, frames_mask_size=100,
max_frames_mask_fraction=0.2, max_frames_mask_fraction=0.2,
p=0.8 p=0.9
) )
) )
else: else:
@ -518,11 +518,10 @@ class SpecAugment(torch.nn.Module):
mask_value=mean, mask_value=mean,
axis=2, axis=2,
).squeeze(0) ).squeeze(0)
for _ in range(self.num_frame_masks): _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
_max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) for _ in range(num_frame_masks):
features = mask_along_axis( features = mask_along_axis(
features.unsqueeze(0), features.unsqueeze(0),
mask_param=max_mask_frames, mask_param=max_mask_frames,

View File

@ -109,7 +109,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/exp-100h-relu-specaugmod_p0.8_0.2", default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2_fix",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved