mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge b52b5c683f4f7daef450eebf60f651100b916cdc into 5379c8e9fa13f6f2364b4a0db89fa3074266fb58
This commit is contained in:
commit
86f45caafc
287
egs/librispeech/ASR/pruned_transducer_stateless6/aug.py
Normal file
287
egs/librispeech/ASR/pruned_transducer_stateless6/aug.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
import math
|
||||||
|
import random
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SpecAugment(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
SpecAugment performs three augmentations:
|
||||||
|
- time warping of the feature matrix
|
||||||
|
- masking of ranges of features (frequency bands)
|
||||||
|
- masking of ranges of frames (time)
|
||||||
|
|
||||||
|
The current implementation works with batches, but processes each example separately
|
||||||
|
in a loop rather than simultaneously to achieve different augmentation parameters for
|
||||||
|
each example.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
time_warp_factor: Optional[int] = 80,
|
||||||
|
num_feature_masks: int = 2,
|
||||||
|
features_mask_size: int = 27,
|
||||||
|
num_frame_masks: int = 10,
|
||||||
|
frames_mask_size: int = 100,
|
||||||
|
max_frames_mask_fraction: float = 0.15,
|
||||||
|
p=0.9,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
SpecAugment's constructor.
|
||||||
|
|
||||||
|
:param time_warp_factor: parameter for the time warping; larger values mean more warping.
|
||||||
|
Set to ``None``, or less than ``1``, to disable.
|
||||||
|
:param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable.
|
||||||
|
:param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
|
||||||
|
This is the ``F`` parameter from the SpecAugment paper.
|
||||||
|
:param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable.
|
||||||
|
:param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
|
||||||
|
This is the ``T`` parameter from the SpecAugment paper.
|
||||||
|
:param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
|
||||||
|
of the utterance (or supervision segment).
|
||||||
|
This is the parameter denoted by ``p`` in the SpecAugment paper.
|
||||||
|
:param p: the probability of applying this transform.
|
||||||
|
It is different from ``p`` in the SpecAugment paper!
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert 0 <= p <= 1
|
||||||
|
assert num_feature_masks >= 0
|
||||||
|
assert num_frame_masks > 0
|
||||||
|
assert features_mask_size > 0
|
||||||
|
assert frames_mask_size > 0
|
||||||
|
self.time_warp_factor = time_warp_factor
|
||||||
|
self.num_feature_masks = num_feature_masks
|
||||||
|
self.features_mask_size = features_mask_size
|
||||||
|
self.num_frame_masks = num_frame_masks
|
||||||
|
self.frames_mask_size = frames_mask_size
|
||||||
|
self.max_frames_mask_fraction = max_frames_mask_fraction
|
||||||
|
self.p = p
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
features: torch.Tensor,
|
||||||
|
supervision_segments: Optional[torch.IntTensor] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Computes SpecAugment for a batch of feature matrices.
|
||||||
|
|
||||||
|
Since the batch will usually already be padded, the user can optionally
|
||||||
|
provide a ``supervision_segments`` tensor that will be used to apply SpecAugment
|
||||||
|
only to selected areas of the input. The format of this input is described below.
|
||||||
|
|
||||||
|
:param features: a batch of feature matrices with shape ``(B, T, F)``.
|
||||||
|
:param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of
|
||||||
|
supervision segments that exist in ``features`` -- there may be either
|
||||||
|
less or more than the batch size.
|
||||||
|
The second dimension encoder three kinds of information:
|
||||||
|
the sequence index of the corresponding feature matrix in `features`,
|
||||||
|
the start frame index, and the number of frames for each segment.
|
||||||
|
:return: an augmented tensor of shape ``(B, T, F)``.
|
||||||
|
"""
|
||||||
|
assert len(features.shape) == 3, (
|
||||||
|
"SpecAugment only supports batches of "
|
||||||
|
"single-channel feature matrices."
|
||||||
|
)
|
||||||
|
features = features.clone()
|
||||||
|
# 1 (True) represents masked area;
|
||||||
|
# 0 (False) represents original un-masked area.
|
||||||
|
time_masked_area = torch.zeros_like(features)
|
||||||
|
if supervision_segments is None:
|
||||||
|
# No supervisions - apply spec augment to full feature matrices.
|
||||||
|
for sequence_idx in range(features.size(0)):
|
||||||
|
(
|
||||||
|
features[sequence_idx],
|
||||||
|
time_masked_area[sequence_idx],
|
||||||
|
) = self._forward_single(features[sequence_idx])
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Supervisions provided - we will apply time warping only on the supervised areas.
|
||||||
|
for sequence_idx, start_frame, num_frames in supervision_segments:
|
||||||
|
end_frame = start_frame + num_frames
|
||||||
|
(
|
||||||
|
features[sequence_idx, start_frame:end_frame],
|
||||||
|
time_masked_area[sequence_idx, start_frame:end_frame],
|
||||||
|
) = self._forward_single(
|
||||||
|
features[sequence_idx, start_frame:end_frame],
|
||||||
|
warp=True,
|
||||||
|
mask=False,
|
||||||
|
)
|
||||||
|
# ... and then time-mask the full feature matrices. Note that in this mode,
|
||||||
|
# it might happen that masks are applied to different sequences/examples
|
||||||
|
# than the time warping.
|
||||||
|
for sequence_idx in range(features.size(0)):
|
||||||
|
(
|
||||||
|
features[sequence_idx],
|
||||||
|
time_masked_area[sequence_idx],
|
||||||
|
) = self._forward_single(
|
||||||
|
features[sequence_idx], warp=False, mask=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return features, time_masked_area
|
||||||
|
|
||||||
|
def _forward_single(
|
||||||
|
self, features: torch.Tensor, warp: bool = True, mask: bool = True
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Apply SpecAugment to a single feature matrix of shape (T, F).
|
||||||
|
"""
|
||||||
|
time_masked_area = torch.zeros_like(features)
|
||||||
|
if random.random() > self.p:
|
||||||
|
# Randomly choose whether this transform is applied
|
||||||
|
# No augmentation, no masked area.
|
||||||
|
return features, time_masked_area
|
||||||
|
if warp:
|
||||||
|
if self.time_warp_factor is not None and self.time_warp_factor >= 1:
|
||||||
|
features = time_warp(features, factor=self.time_warp_factor)
|
||||||
|
if mask:
|
||||||
|
mean = features.mean()
|
||||||
|
# Frequency masking
|
||||||
|
features, _ = mask_along_axis_optimized(
|
||||||
|
features,
|
||||||
|
mask_size=self.features_mask_size,
|
||||||
|
mask_times=self.num_feature_masks,
|
||||||
|
mask_value=mean,
|
||||||
|
axis=2,
|
||||||
|
)
|
||||||
|
# Time masking
|
||||||
|
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(
|
||||||
|
0
|
||||||
|
)
|
||||||
|
num_frame_masks = min(
|
||||||
|
self.num_frame_masks,
|
||||||
|
math.ceil(max_tot_mask_frames / self.frames_mask_size),
|
||||||
|
)
|
||||||
|
max_mask_frames = min(
|
||||||
|
self.frames_mask_size, max_tot_mask_frames // num_frame_masks
|
||||||
|
)
|
||||||
|
features, time_masked_area = mask_along_axis_optimized(
|
||||||
|
features,
|
||||||
|
mask_size=max_mask_frames,
|
||||||
|
mask_times=num_frame_masks,
|
||||||
|
mask_value=mean,
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return features, time_masked_area
|
||||||
|
|
||||||
|
def state_dict(self) -> Dict:
|
||||||
|
return dict(
|
||||||
|
time_warp_factor=self.time_warp_factor,
|
||||||
|
num_feature_masks=self.num_feature_masks,
|
||||||
|
features_mask_size=self.features_mask_size,
|
||||||
|
num_frame_masks=self.num_frame_masks,
|
||||||
|
frames_mask_size=self.frames_mask_size,
|
||||||
|
max_frames_mask_fraction=self.max_frames_mask_fraction,
|
||||||
|
p=self.p,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Dict):
|
||||||
|
self.time_warp_factor = state_dict.get(
|
||||||
|
"time_warp_factor", self.time_warp_factor
|
||||||
|
)
|
||||||
|
self.num_feature_masks = state_dict.get(
|
||||||
|
"num_feature_masks", self.num_feature_masks
|
||||||
|
)
|
||||||
|
self.features_mask_size = state_dict.get(
|
||||||
|
"features_mask_size", self.features_mask_size
|
||||||
|
)
|
||||||
|
self.num_frame_masks = state_dict.get(
|
||||||
|
"num_frame_masks", self.num_frame_masks
|
||||||
|
)
|
||||||
|
self.frames_mask_size = state_dict.get(
|
||||||
|
"frames_mask_size", self.frames_mask_size
|
||||||
|
)
|
||||||
|
self.max_frames_mask_fraction = state_dict.get(
|
||||||
|
"max_frames_mask_fraction", self.max_frames_mask_fraction
|
||||||
|
)
|
||||||
|
self.p = state_dict.get("p", self.p)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_along_axis_optimized(
|
||||||
|
features: torch.Tensor,
|
||||||
|
mask_size: int,
|
||||||
|
mask_times: int,
|
||||||
|
mask_value: float,
|
||||||
|
axis: int,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Apply Frequency and Time masking along axis.
|
||||||
|
Frequency and Time masking as described in the SpecAugment paper.
|
||||||
|
|
||||||
|
:param features: input tensor of shape ``(T, F)``
|
||||||
|
:mask_size: the width size for masking.
|
||||||
|
:mask_times: the number of masking regions.
|
||||||
|
:mask_value: Value to assign to the masked regions.
|
||||||
|
:axis: Axis to apply masking on (1 -> time, 2 -> frequency)
|
||||||
|
"""
|
||||||
|
if axis not in [1, 2]:
|
||||||
|
raise ValueError("Only Frequency and Time masking are supported!")
|
||||||
|
|
||||||
|
# 1 (True) represents masked area;
|
||||||
|
# 0 (False) represents original un-masked area.
|
||||||
|
masked_area = torch.zeros_like(features)
|
||||||
|
features = features.unsqueeze(0)
|
||||||
|
masked_area = masked_area.unsqueeze(0)
|
||||||
|
features = features.reshape([-1] + list(features.size()[-2:]))
|
||||||
|
|
||||||
|
values = torch.randint(int(0), int(mask_size), (1, mask_times))
|
||||||
|
min_values = torch.rand(1, mask_times) * (features.size(axis) - values)
|
||||||
|
mask_starts = (min_values.long()).squeeze()
|
||||||
|
mask_ends = (min_values.long() + values.long()).squeeze()
|
||||||
|
|
||||||
|
if axis == 1:
|
||||||
|
if mask_times == 1:
|
||||||
|
features[:, mask_starts:mask_ends] = mask_value
|
||||||
|
return features.squeeze(0), masked_area
|
||||||
|
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
|
||||||
|
features[:, mask_start:mask_end] = mask_value
|
||||||
|
masked_area[:, mask_start:mask_end] = 1
|
||||||
|
else:
|
||||||
|
if mask_times == 1:
|
||||||
|
features[:, :, mask_starts:mask_ends] = mask_value
|
||||||
|
masked_area[:, :, mask_starts:mask_ends] = 1
|
||||||
|
return features.squeeze(0), masked_area
|
||||||
|
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
|
||||||
|
features[:, :, mask_start:mask_end] = mask_value
|
||||||
|
masked_area[:, :, mask_start:mask_end] = 1
|
||||||
|
|
||||||
|
features = features.squeeze(0)
|
||||||
|
masked_area = masked_area.squeeze(0)
|
||||||
|
return features, masked_area
|
||||||
|
|
||||||
|
|
||||||
|
def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Time warping as described in the SpecAugment paper.
|
||||||
|
Implementation based on Espresso:
|
||||||
|
https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51
|
||||||
|
|
||||||
|
:param features: input tensor of shape ``(T, F)``
|
||||||
|
:param factor: time warping parameter.
|
||||||
|
:return: a warped tensor of shape ``(T, F)``
|
||||||
|
"""
|
||||||
|
t = features.size(0)
|
||||||
|
if t - factor <= factor + 1:
|
||||||
|
return features
|
||||||
|
center = np.random.randint(factor + 1, t - factor)
|
||||||
|
warped = np.random.randint(center - factor, center + factor + 1)
|
||||||
|
if warped == center:
|
||||||
|
return features
|
||||||
|
features = features.unsqueeze(0).unsqueeze(0)
|
||||||
|
left = torch.nn.functional.interpolate(
|
||||||
|
features[:, :, :center, :],
|
||||||
|
size=(warped, features.size(3)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
right = torch.nn.functional.interpolate(
|
||||||
|
features[:, :, center:, :],
|
||||||
|
size=(t - warped, features.size(3)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
return torch.cat((left, right), dim=2).squeeze(0).squeeze(0)
|
||||||
@ -23,7 +23,7 @@ from scaling import ScaledLinear
|
|||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
from quantization.prediction import JointCodebookLoss
|
from multi_quantization.prediction import JointCodebookLoss
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
@ -41,6 +41,8 @@ class Transducer(nn.Module):
|
|||||||
joiner_dim: int,
|
joiner_dim: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
num_codebooks: int = 0,
|
num_codebooks: int = 0,
|
||||||
|
masked_scale: float = 1.0,
|
||||||
|
unmasked_scale: float = 1.0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -60,6 +62,10 @@ class Transducer(nn.Module):
|
|||||||
contains unnormalized probs, i.e., not processed by log-softmax.
|
contains unnormalized probs, i.e., not processed by log-softmax.
|
||||||
num_codebooks:
|
num_codebooks:
|
||||||
Used by distillation loss.
|
Used by distillation loss.
|
||||||
|
masked_scale:
|
||||||
|
scale of codebook loss of masked area.
|
||||||
|
unmasked_scale:
|
||||||
|
scale of codebook loss of unmasked area.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
@ -75,8 +81,12 @@ class Transducer(nn.Module):
|
|||||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
||||||
if num_codebooks > 0:
|
if num_codebooks > 0:
|
||||||
self.codebook_loss_net = JointCodebookLoss(
|
self.codebook_loss_net = JointCodebookLoss(
|
||||||
predictor_channels=encoder_dim, num_codebooks=num_codebooks
|
predictor_channels=encoder_dim,
|
||||||
|
num_codebooks=num_codebooks,
|
||||||
|
reduction="none",
|
||||||
)
|
)
|
||||||
|
self.masked_scale = masked_scale
|
||||||
|
self.unmasked_scale = unmasked_scale
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -88,6 +98,7 @@ class Transducer(nn.Module):
|
|||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
codebook_indexes: torch.Tensor = None,
|
codebook_indexes: torch.Tensor = None,
|
||||||
|
time_masked_area: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -113,6 +124,8 @@ class Transducer(nn.Module):
|
|||||||
warmup > 1 "are fully warmed up" and all modules will be active.
|
warmup > 1 "are fully warmed up" and all modules will be active.
|
||||||
codebook_indexes:
|
codebook_indexes:
|
||||||
codebook_indexes extracted from a teacher model.
|
codebook_indexes extracted from a teacher model.
|
||||||
|
time_masked_area:
|
||||||
|
masked area by SpecAugment, 1 represents masked.
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return the transducer loss.
|
||||||
|
|
||||||
@ -140,6 +153,22 @@ class Transducer(nn.Module):
|
|||||||
codebook_loss = self.codebook_loss_net(
|
codebook_loss = self.codebook_loss_net(
|
||||||
middle_layer_output, codebook_indexes
|
middle_layer_output, codebook_indexes
|
||||||
)
|
)
|
||||||
|
codebook_loss = codebook_loss.reshape(codebook_indexes.shape)
|
||||||
|
target_t = codebook_loss.shape[1]
|
||||||
|
time_masked_area = time_masked_area.bool()
|
||||||
|
time_masked_area = time_masked_area[
|
||||||
|
:, : target_t * 4 : 4, 0 # noqa E203
|
||||||
|
]
|
||||||
|
assert time_masked_area.shape == codebook_loss.shape[:-1]
|
||||||
|
time_masked_area = time_masked_area.unsqueeze(2).to(
|
||||||
|
codebook_loss.device
|
||||||
|
)
|
||||||
|
masked_loss = (time_masked_area * codebook_loss).sum()
|
||||||
|
unmasked_loss = (~time_masked_area * codebook_loss).sum()
|
||||||
|
codebook_loss = (
|
||||||
|
self.masked_scale * masked_loss
|
||||||
|
+ self.unmasked_scale * unmasked_loss
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# when codebook index is not available.
|
# when codebook index is not available.
|
||||||
codebook_loss = None
|
codebook_loss = None
|
||||||
|
|||||||
@ -177,6 +177,18 @@ def get_parser():
|
|||||||
changed.""",
|
changed.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--masked-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--unmasked-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr-batches",
|
"--lr-batches",
|
||||||
type=float,
|
type=float,
|
||||||
@ -378,6 +390,8 @@ def get_params() -> AttributeDict:
|
|||||||
# two successive codebook_index are concatenated together.
|
# two successive codebook_index are concatenated together.
|
||||||
# Detailed in function Transducer::concat_sucessive_codebook_indexes.
|
# Detailed in function Transducer::concat_sucessive_codebook_indexes.
|
||||||
"num_codebooks": 16, # used to construct distillation loss
|
"num_codebooks": 16, # used to construct distillation loss
|
||||||
|
"masked_scale": 1.0,
|
||||||
|
"unmasked_scale": 1.0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
num_codebooks=params.num_codebooks
|
num_codebooks=params.num_codebooks
|
||||||
if params.enable_distiallation
|
if params.enable_distiallation
|
||||||
else 0,
|
else 0,
|
||||||
|
masked_scale=params.masked_scale,
|
||||||
|
unmasked_scale=params.unmasked_scale,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -602,7 +618,7 @@ def compute_loss(
|
|||||||
if isinstance(model, DDP)
|
if isinstance(model, DDP)
|
||||||
else next(model.parameters()).device
|
else next(model.parameters()).device
|
||||||
)
|
)
|
||||||
feature = batch["inputs"]
|
feature, time_masked_area = batch["inputs"]
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
feature = feature.to(device)
|
feature = feature.to(device)
|
||||||
@ -631,6 +647,7 @@ def compute_loss(
|
|||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
codebook_indexes=codebook_indexes,
|
codebook_indexes=codebook_indexes,
|
||||||
|
time_masked_area=time_masked_area,
|
||||||
)
|
)
|
||||||
# after the main warmup step, we keep pruned_loss_scale small
|
# after the main warmup step, we keep pruned_loss_scale small
|
||||||
# for the same amount of time (model_warm_step), to avoid
|
# for the same amount of time (model_warm_step), to avoid
|
||||||
@ -1089,7 +1106,9 @@ def main():
|
|||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(
|
||||||
|
f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}"
|
||||||
|
)
|
||||||
|
|
||||||
world_size = args.world_size
|
world_size = args.world_size
|
||||||
assert world_size >= 1
|
assert world_size >= 1
|
||||||
|
|||||||
@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SingleCutSampler,
|
||||||
SpecAugment,
|
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
AudioSamples,
|
AudioSamples,
|
||||||
@ -41,6 +40,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
|||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from aug import SpecAugment
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
@ -183,6 +183,12 @@ class LibriSpeechAsrDataModule:
|
|||||||
help="When enabled, use SpecAugment for training dataset.",
|
help="When enabled, use SpecAugment for training dataset.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--spec-aug-max-frames-mask-fraction",
|
||||||
|
type=float,
|
||||||
|
default=0.15,
|
||||||
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--spec-aug-time-warp-factor",
|
"--spec-aug-time-warp-factor",
|
||||||
type=int,
|
type=int,
|
||||||
@ -272,6 +278,7 @@ class LibriSpeechAsrDataModule:
|
|||||||
features_mask_size=27,
|
features_mask_size=27,
|
||||||
num_feature_masks=2,
|
num_feature_masks=2,
|
||||||
frames_mask_size=100,
|
frames_mask_size=100,
|
||||||
|
max_frames_mask_fraction=self.args.spec_aug_max_frames_mask_fraction,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user