mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
Fix num_time_masks code; revert 0.8 to 0.9
This commit is contained in:
parent
c170c53006
commit
4cd2c02fff
@ -223,7 +223,7 @@ class LibriSpeechAsrDataModule:
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
max_frames_mask_fraction=0.2,
|
||||
p=0.8
|
||||
p=0.9
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -518,11 +518,10 @@ class SpecAugment(torch.nn.Module):
|
||||
mask_value=mean,
|
||||
axis=2,
|
||||
).squeeze(0)
|
||||
for _ in range(self.num_frame_masks):
|
||||
_max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
||||
num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size))
|
||||
max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks)
|
||||
|
||||
for _ in range(num_frame_masks):
|
||||
features = mask_along_axis(
|
||||
features.unsqueeze(0),
|
||||
mask_param=max_mask_frames,
|
||||
|
@ -109,7 +109,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="transducer_stateless/exp-100h-relu-specaugmod_p0.8_0.2",
|
||||
default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2_fix",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
Loading…
x
Reference in New Issue
Block a user