predicted masked codebook indexes only

This commit is contained in:
Guo Liyong 2022-05-30 14:33:48 +08:00
parent 0c33543ce7
commit 90024c308f
4 changed files with 67 additions and 22 deletions

View File

@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SingleCutSampler,
SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples, AudioSamples,
@ -41,6 +40,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from aug import SpecAugment
from icefall.utils import str2bool from icefall.utils import str2bool

View File

@ -1,6 +1,6 @@
import math import math
import random import random
from typing import Dict, Optional from typing import Dict, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -65,7 +65,7 @@ class SpecAugment(torch.nn.Module):
supervision_segments: Optional[torch.IntTensor] = None, supervision_segments: Optional[torch.IntTensor] = None,
*args, *args,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Computes SpecAugment for a batch of feature matrices. Computes SpecAugment for a batch of feature matrices.
@ -87,19 +87,25 @@ class SpecAugment(torch.nn.Module):
"single-channel feature matrices." "single-channel feature matrices."
) )
features = features.clone() features = features.clone()
# 1 (True) represents masked area;
# 0 (False) represents original un-masked area.
time_masked_area = torch.zeros_like(features)
if supervision_segments is None: if supervision_segments is None:
# No supervisions - apply spec augment to full feature matrices. # No supervisions - apply spec augment to full feature matrices.
for sequence_idx in range(features.size(0)): for sequence_idx in range(features.size(0)):
features[sequence_idx] = self._forward_single( (
features[sequence_idx] features[sequence_idx],
) time_masked_area[sequence_idx],
) = self._forward_single(features[sequence_idx])
else: else:
# Supervisions provided - we will apply time warping only on the supervised areas. # Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments: for sequence_idx, start_frame, num_frames in supervision_segments:
end_frame = start_frame + num_frames end_frame = start_frame + num_frames
features[ (
sequence_idx, start_frame:end_frame features[sequence_idx, start_frame:end_frame],
] = self._forward_single( time_masked_area[sequence_idx, start_frame:end_frame],
) = self._forward_single(
features[sequence_idx, start_frame:end_frame], features[sequence_idx, start_frame:end_frame],
warp=True, warp=True,
mask=False, mask=False,
@ -108,27 +114,33 @@ class SpecAugment(torch.nn.Module):
# it might happen that masks are applied to different sequences/examples # it might happen that masks are applied to different sequences/examples
# than the time warping. # than the time warping.
for sequence_idx in range(features.size(0)): for sequence_idx in range(features.size(0)):
features[sequence_idx] = self._forward_single( (
features[sequence_idx],
time_masked_area[sequence_idx],
) = self._forward_single(
features[sequence_idx], warp=False, mask=True features[sequence_idx], warp=False, mask=True
) )
return features
return features, time_masked_area
def _forward_single( def _forward_single(
self, features: torch.Tensor, warp: bool = True, mask: bool = True self, features: torch.Tensor, warp: bool = True, mask: bool = True
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Apply SpecAugment to a single feature matrix of shape (T, F). Apply SpecAugment to a single feature matrix of shape (T, F).
""" """
time_masked_area = torch.zeros_like(features)
if random.random() > self.p: if random.random() > self.p:
# Randomly choose whether this transform is applied # Randomly choose whether this transform is applied
return features # No augmentation, no masked area.
return features, time_masked_area
if warp: if warp:
if self.time_warp_factor is not None and self.time_warp_factor >= 1: if self.time_warp_factor is not None and self.time_warp_factor >= 1:
features = time_warp(features, factor=self.time_warp_factor) features = time_warp(features, factor=self.time_warp_factor)
if mask: if mask:
mean = features.mean() mean = features.mean()
# Frequency masking # Frequency masking
features = mask_along_axis_optimized( features, _ = mask_along_axis_optimized(
features, features,
mask_size=self.features_mask_size, mask_size=self.features_mask_size,
mask_times=self.num_feature_masks, mask_times=self.num_feature_masks,
@ -146,7 +158,7 @@ class SpecAugment(torch.nn.Module):
max_mask_frames = min( max_mask_frames = min(
self.frames_mask_size, max_tot_mask_frames // num_frame_masks self.frames_mask_size, max_tot_mask_frames // num_frame_masks
) )
features = mask_along_axis_optimized( features, time_masked_area = mask_along_axis_optimized(
features, features,
mask_size=max_mask_frames, mask_size=max_mask_frames,
mask_times=num_frame_masks, mask_times=num_frame_masks,
@ -154,7 +166,7 @@ class SpecAugment(torch.nn.Module):
axis=1, axis=1,
) )
return features return features, time_masked_area
def state_dict(self) -> Dict: def state_dict(self) -> Dict:
return dict( return dict(
@ -195,7 +207,7 @@ def mask_along_axis_optimized(
mask_times: int, mask_times: int,
mask_value: float, mask_value: float,
axis: int, axis: int,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Apply Frequency and Time masking along axis. Apply Frequency and Time masking along axis.
Frequency and Time masking as described in the SpecAugment paper. Frequency and Time masking as described in the SpecAugment paper.
@ -209,7 +221,11 @@ def mask_along_axis_optimized(
if axis not in [1, 2]: if axis not in [1, 2]:
raise ValueError("Only Frequency and Time masking are supported!") raise ValueError("Only Frequency and Time masking are supported!")
# 1 (True) represents masked area;
# 0 (False) represents original un-masked area.
masked_area = torch.zeros_like(features)
features = features.unsqueeze(0) features = features.unsqueeze(0)
masked_area = masked_area.unsqueeze(0)
features = features.reshape([-1] + list(features.size()[-2:])) features = features.reshape([-1] + list(features.size()[-2:]))
values = torch.randint(int(0), int(mask_size), (1, mask_times)) values = torch.randint(int(0), int(mask_size), (1, mask_times))
@ -220,18 +236,22 @@ def mask_along_axis_optimized(
if axis == 1: if axis == 1:
if mask_times == 1: if mask_times == 1:
features[:, mask_starts:mask_ends] = mask_value features[:, mask_starts:mask_ends] = mask_value
return features.squeeze(0) return features.squeeze(0), masked_area
for (mask_start, mask_end) in zip(mask_starts, mask_ends): for (mask_start, mask_end) in zip(mask_starts, mask_ends):
features[:, mask_start:mask_end] = mask_value features[:, mask_start:mask_end] = mask_value
masked_area[:, mask_start:mask_end] = 1
else: else:
if mask_times == 1: if mask_times == 1:
features[:, :, mask_starts:mask_ends] = mask_value features[:, :, mask_starts:mask_ends] = mask_value
return features.squeeze(0) masked_area[:, :, mask_starts:mask_ends] = 1
return features.squeeze(0), masked_area
for (mask_start, mask_end) in zip(mask_starts, mask_ends): for (mask_start, mask_end) in zip(mask_starts, mask_ends):
features[:, :, mask_start:mask_end] = mask_value features[:, :, mask_start:mask_end] = mask_value
masked_area[:, :, mask_start:mask_end] = 1
features = features.squeeze(0) features = features.squeeze(0)
return features masked_area = masked_area.squeeze(0)
return features, masked_area
def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor:

View File

@ -75,7 +75,9 @@ class Transducer(nn.Module):
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if num_codebooks > 0: if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss( self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim, num_codebooks=num_codebooks predictor_channels=encoder_dim,
num_codebooks=num_codebooks,
reduction="none",
) )
def forward( def forward(
@ -88,6 +90,8 @@ class Transducer(nn.Module):
lm_scale: float = 0.0, lm_scale: float = 0.0,
warmup: float = 1.0, warmup: float = 1.0,
codebook_indexes: torch.Tensor = None, codebook_indexes: torch.Tensor = None,
time_masked_area: torch.Tensor = None,
masked_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -113,6 +117,11 @@ class Transducer(nn.Module):
warmup > 1 "are fully warmed up" and all modules will be active. warmup > 1 "are fully warmed up" and all modules will be active.
codebook_indexes: codebook_indexes:
codebook_indexes extracted from a teacher model. codebook_indexes extracted from a teacher model.
time_masked_area:
masked area by SpecAugment, 1 represents masked.
masked_scale:
scale of codebook loss of masked area.
the unmasked_scale = 1 - masked_scale
Returns: Returns:
Return the transducer loss. Return the transducer loss.
@ -140,6 +149,21 @@ class Transducer(nn.Module):
codebook_loss = self.codebook_loss_net( codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes middle_layer_output, codebook_indexes
) )
codebook_loss = codebook_loss.reshape(codebook_indexes.shape)
target_t = codebook_loss.shape[1]
time_masked_area = time_masked_area.bool()
time_masked_area = time_masked_area[
:, : target_t * 4 : 4, 0 # noqa E203
]
assert time_masked_area.shape == codebook_loss.shape[:-1]
time_masked_area = time_masked_area.unsqueeze(2).to(
codebook_loss.device
)
masked_loss = (time_masked_area * codebook_loss).sum()
unmasked_loss = (~time_masked_area * codebook_loss).sum()
codebook_loss = (
masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss
)
else: else:
# when codebook index is not available. # when codebook index is not available.
codebook_loss = None codebook_loss = None

View File

@ -602,7 +602,7 @@ def compute_loss(
if isinstance(model, DDP) if isinstance(model, DDP)
else next(model.parameters()).device else next(model.parameters()).device
) )
feature = batch["inputs"] feature, time_masked_area = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
feature = feature.to(device) feature = feature.to(device)
@ -631,6 +631,7 @@ def compute_loss(
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
warmup=warmup, warmup=warmup,
codebook_indexes=codebook_indexes, codebook_indexes=codebook_indexes,
time_masked_area=time_masked_area,
) )
# after the main warmup step, we keep pruned_loss_scale small # after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid # for the same amount of time (model_warm_step), to avoid