Merge b52b5c683f4f7daef450eebf60f651100b916cdc into 5379c8e9fa13f6f2364b4a0db89fa3074266fb58

This commit is contained in:
LIyong.Guo 2022-06-16 18:22:21 -04:00 committed by GitHub
commit 86f45caafc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 347 additions and 5 deletions

View File

@ -0,0 +1,287 @@
import math
import random
from typing import Dict, Optional, Tuple
import numpy as np
import torch
class SpecAugment(torch.nn.Module):
"""
SpecAugment performs three augmentations:
- time warping of the feature matrix
- masking of ranges of features (frequency bands)
- masking of ranges of frames (time)
The current implementation works with batches, but processes each example separately
in a loop rather than simultaneously to achieve different augmentation parameters for
each example.
"""
def __init__(
self,
time_warp_factor: Optional[int] = 80,
num_feature_masks: int = 2,
features_mask_size: int = 27,
num_frame_masks: int = 10,
frames_mask_size: int = 100,
max_frames_mask_fraction: float = 0.15,
p=0.9,
):
"""
SpecAugment's constructor.
:param time_warp_factor: parameter for the time warping; larger values mean more warping.
Set to ``None``, or less than ``1``, to disable.
:param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable.
:param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
This is the ``F`` parameter from the SpecAugment paper.
:param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable.
:param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
This is the ``T`` parameter from the SpecAugment paper.
:param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
of the utterance (or supervision segment).
This is the parameter denoted by ``p`` in the SpecAugment paper.
:param p: the probability of applying this transform.
It is different from ``p`` in the SpecAugment paper!
"""
super().__init__()
assert 0 <= p <= 1
assert num_feature_masks >= 0
assert num_frame_masks > 0
assert features_mask_size > 0
assert frames_mask_size > 0
self.time_warp_factor = time_warp_factor
self.num_feature_masks = num_feature_masks
self.features_mask_size = features_mask_size
self.num_frame_masks = num_frame_masks
self.frames_mask_size = frames_mask_size
self.max_frames_mask_fraction = max_frames_mask_fraction
self.p = p
def forward(
self,
features: torch.Tensor,
supervision_segments: Optional[torch.IntTensor] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes SpecAugment for a batch of feature matrices.
Since the batch will usually already be padded, the user can optionally
provide a ``supervision_segments`` tensor that will be used to apply SpecAugment
only to selected areas of the input. The format of this input is described below.
:param features: a batch of feature matrices with shape ``(B, T, F)``.
:param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of
supervision segments that exist in ``features`` -- there may be either
less or more than the batch size.
The second dimension encoder three kinds of information:
the sequence index of the corresponding feature matrix in `features`,
the start frame index, and the number of frames for each segment.
:return: an augmented tensor of shape ``(B, T, F)``.
"""
assert len(features.shape) == 3, (
"SpecAugment only supports batches of "
"single-channel feature matrices."
)
features = features.clone()
# 1 (True) represents masked area;
# 0 (False) represents original un-masked area.
time_masked_area = torch.zeros_like(features)
if supervision_segments is None:
# No supervisions - apply spec augment to full feature matrices.
for sequence_idx in range(features.size(0)):
(
features[sequence_idx],
time_masked_area[sequence_idx],
) = self._forward_single(features[sequence_idx])
else:
# Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments:
end_frame = start_frame + num_frames
(
features[sequence_idx, start_frame:end_frame],
time_masked_area[sequence_idx, start_frame:end_frame],
) = self._forward_single(
features[sequence_idx, start_frame:end_frame],
warp=True,
mask=False,
)
# ... and then time-mask the full feature matrices. Note that in this mode,
# it might happen that masks are applied to different sequences/examples
# than the time warping.
for sequence_idx in range(features.size(0)):
(
features[sequence_idx],
time_masked_area[sequence_idx],
) = self._forward_single(
features[sequence_idx], warp=False, mask=True
)
return features, time_masked_area
def _forward_single(
self, features: torch.Tensor, warp: bool = True, mask: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply SpecAugment to a single feature matrix of shape (T, F).
"""
time_masked_area = torch.zeros_like(features)
if random.random() > self.p:
# Randomly choose whether this transform is applied
# No augmentation, no masked area.
return features, time_masked_area
if warp:
if self.time_warp_factor is not None and self.time_warp_factor >= 1:
features = time_warp(features, factor=self.time_warp_factor)
if mask:
mean = features.mean()
# Frequency masking
features, _ = mask_along_axis_optimized(
features,
mask_size=self.features_mask_size,
mask_times=self.num_feature_masks,
mask_value=mean,
axis=2,
)
# Time masking
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(
0
)
num_frame_masks = min(
self.num_frame_masks,
math.ceil(max_tot_mask_frames / self.frames_mask_size),
)
max_mask_frames = min(
self.frames_mask_size, max_tot_mask_frames // num_frame_masks
)
features, time_masked_area = mask_along_axis_optimized(
features,
mask_size=max_mask_frames,
mask_times=num_frame_masks,
mask_value=mean,
axis=1,
)
return features, time_masked_area
def state_dict(self) -> Dict:
return dict(
time_warp_factor=self.time_warp_factor,
num_feature_masks=self.num_feature_masks,
features_mask_size=self.features_mask_size,
num_frame_masks=self.num_frame_masks,
frames_mask_size=self.frames_mask_size,
max_frames_mask_fraction=self.max_frames_mask_fraction,
p=self.p,
)
def load_state_dict(self, state_dict: Dict):
self.time_warp_factor = state_dict.get(
"time_warp_factor", self.time_warp_factor
)
self.num_feature_masks = state_dict.get(
"num_feature_masks", self.num_feature_masks
)
self.features_mask_size = state_dict.get(
"features_mask_size", self.features_mask_size
)
self.num_frame_masks = state_dict.get(
"num_frame_masks", self.num_frame_masks
)
self.frames_mask_size = state_dict.get(
"frames_mask_size", self.frames_mask_size
)
self.max_frames_mask_fraction = state_dict.get(
"max_frames_mask_fraction", self.max_frames_mask_fraction
)
self.p = state_dict.get("p", self.p)
def mask_along_axis_optimized(
features: torch.Tensor,
mask_size: int,
mask_times: int,
mask_value: float,
axis: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply Frequency and Time masking along axis.
Frequency and Time masking as described in the SpecAugment paper.
:param features: input tensor of shape ``(T, F)``
:mask_size: the width size for masking.
:mask_times: the number of masking regions.
:mask_value: Value to assign to the masked regions.
:axis: Axis to apply masking on (1 -> time, 2 -> frequency)
"""
if axis not in [1, 2]:
raise ValueError("Only Frequency and Time masking are supported!")
# 1 (True) represents masked area;
# 0 (False) represents original un-masked area.
masked_area = torch.zeros_like(features)
features = features.unsqueeze(0)
masked_area = masked_area.unsqueeze(0)
features = features.reshape([-1] + list(features.size()[-2:]))
values = torch.randint(int(0), int(mask_size), (1, mask_times))
min_values = torch.rand(1, mask_times) * (features.size(axis) - values)
mask_starts = (min_values.long()).squeeze()
mask_ends = (min_values.long() + values.long()).squeeze()
if axis == 1:
if mask_times == 1:
features[:, mask_starts:mask_ends] = mask_value
return features.squeeze(0), masked_area
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
features[:, mask_start:mask_end] = mask_value
masked_area[:, mask_start:mask_end] = 1
else:
if mask_times == 1:
features[:, :, mask_starts:mask_ends] = mask_value
masked_area[:, :, mask_starts:mask_ends] = 1
return features.squeeze(0), masked_area
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
features[:, :, mask_start:mask_end] = mask_value
masked_area[:, :, mask_start:mask_end] = 1
features = features.squeeze(0)
masked_area = masked_area.squeeze(0)
return features, masked_area
def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor:
"""
Time warping as described in the SpecAugment paper.
Implementation based on Espresso:
https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51
:param features: input tensor of shape ``(T, F)``
:param factor: time warping parameter.
:return: a warped tensor of shape ``(T, F)``
"""
t = features.size(0)
if t - factor <= factor + 1:
return features
center = np.random.randint(factor + 1, t - factor)
warped = np.random.randint(center - factor, center + factor + 1)
if warped == center:
return features
features = features.unsqueeze(0).unsqueeze(0)
left = torch.nn.functional.interpolate(
features[:, :, :center, :],
size=(warped, features.size(3)),
mode="bicubic",
align_corners=False,
)
right = torch.nn.functional.interpolate(
features[:, :, center:, :],
size=(t - warped, features.size(3)),
mode="bicubic",
align_corners=False,
)
return torch.cat((left, right), dim=2).squeeze(0).squeeze(0)

View File

@ -23,7 +23,7 @@ from scaling import ScaledLinear
from icefall.utils import add_sos
from quantization.prediction import JointCodebookLoss
from multi_quantization.prediction import JointCodebookLoss
class Transducer(nn.Module):
@ -41,6 +41,8 @@ class Transducer(nn.Module):
joiner_dim: int,
vocab_size: int,
num_codebooks: int = 0,
masked_scale: float = 1.0,
unmasked_scale: float = 1.0,
):
"""
Args:
@ -60,6 +62,10 @@ class Transducer(nn.Module):
contains unnormalized probs, i.e., not processed by log-softmax.
num_codebooks:
Used by distillation loss.
masked_scale:
scale of codebook loss of masked area.
unmasked_scale:
scale of codebook loss of unmasked area.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@ -75,8 +81,12 @@ class Transducer(nn.Module):
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim, num_codebooks=num_codebooks
predictor_channels=encoder_dim,
num_codebooks=num_codebooks,
reduction="none",
)
self.masked_scale = masked_scale
self.unmasked_scale = unmasked_scale
def forward(
self,
@ -88,6 +98,7 @@ class Transducer(nn.Module):
lm_scale: float = 0.0,
warmup: float = 1.0,
codebook_indexes: torch.Tensor = None,
time_masked_area: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
@ -113,6 +124,8 @@ class Transducer(nn.Module):
warmup > 1 "are fully warmed up" and all modules will be active.
codebook_indexes:
codebook_indexes extracted from a teacher model.
time_masked_area:
masked area by SpecAugment, 1 represents masked.
Returns:
Return the transducer loss.
@ -140,6 +153,22 @@ class Transducer(nn.Module):
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
codebook_loss = codebook_loss.reshape(codebook_indexes.shape)
target_t = codebook_loss.shape[1]
time_masked_area = time_masked_area.bool()
time_masked_area = time_masked_area[
:, : target_t * 4 : 4, 0 # noqa E203
]
assert time_masked_area.shape == codebook_loss.shape[:-1]
time_masked_area = time_masked_area.unsqueeze(2).to(
codebook_loss.device
)
masked_loss = (time_masked_area * codebook_loss).sum()
unmasked_loss = (~time_masked_area * codebook_loss).sum()
codebook_loss = (
self.masked_scale * masked_loss
+ self.unmasked_scale * unmasked_loss
)
else:
# when codebook index is not available.
codebook_loss = None

View File

@ -177,6 +177,18 @@ def get_parser():
changed.""",
)
parser.add_argument(
"--masked-scale",
type=float,
default=1.0,
)
parser.add_argument(
"--unmasked-scale",
type=float,
default=1.0,
)
parser.add_argument(
"--lr-batches",
type=float,
@ -378,6 +390,8 @@ def get_params() -> AttributeDict:
# two successive codebook_index are concatenated together.
# Detailed in function Transducer::concat_sucessive_codebook_indexes.
"num_codebooks": 16, # used to construct distillation loss
"masked_scale": 1.0,
"unmasked_scale": 1.0,
}
)
@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
num_codebooks=params.num_codebooks
if params.enable_distiallation
else 0,
masked_scale=params.masked_scale,
unmasked_scale=params.unmasked_scale,
)
return model
@ -602,7 +618,7 @@ def compute_loss(
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
feature, time_masked_area = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
@ -631,6 +647,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
codebook_indexes=codebook_indexes,
time_masked_area=time_masked_area,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
@ -1089,7 +1106,9 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.exp_dir = Path(
f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}"
)
world_size = args.world_size
assert world_size >= 1

View File

@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
@ -41,6 +40,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from aug import SpecAugment
from icefall.utils import str2bool
@ -183,6 +183,12 @@ class LibriSpeechAsrDataModule:
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-max-frames-mask-fraction",
type=float,
default=0.15,
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
@ -272,6 +278,7 @@ class LibriSpeechAsrDataModule:
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
max_frames_mask_fraction=self.args.spec_aug_max_frames_mask_fraction,
)
)
else: