diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py index e83009d4a..03d0d1a88 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py @@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, - SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -41,6 +40,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader +from aug import SpecAugment from icefall.utils import str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py index c60d328c5..0746d0036 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py @@ -1,6 +1,6 @@ import math import random -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import numpy as np import torch @@ -65,7 +65,7 @@ class SpecAugment(torch.nn.Module): supervision_segments: Optional[torch.IntTensor] = None, *args, **kwargs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes SpecAugment for a batch of feature matrices. @@ -87,19 +87,25 @@ class SpecAugment(torch.nn.Module): "single-channel feature matrices." ) features = features.clone() + # 1 (True) represents masked area; + # 0 (False) represents original un-masked area. + time_masked_area = torch.zeros_like(features) if supervision_segments is None: # No supervisions - apply spec augment to full feature matrices. for sequence_idx in range(features.size(0)): - features[sequence_idx] = self._forward_single( - features[sequence_idx] - ) + ( + features[sequence_idx], + time_masked_area[sequence_idx], + ) = self._forward_single(features[sequence_idx]) + else: # Supervisions provided - we will apply time warping only on the supervised areas. for sequence_idx, start_frame, num_frames in supervision_segments: end_frame = start_frame + num_frames - features[ - sequence_idx, start_frame:end_frame - ] = self._forward_single( + ( + features[sequence_idx, start_frame:end_frame], + time_masked_area[sequence_idx, start_frame:end_frame], + ) = self._forward_single( features[sequence_idx, start_frame:end_frame], warp=True, mask=False, @@ -108,27 +114,33 @@ class SpecAugment(torch.nn.Module): # it might happen that masks are applied to different sequences/examples # than the time warping. for sequence_idx in range(features.size(0)): - features[sequence_idx] = self._forward_single( + ( + features[sequence_idx], + time_masked_area[sequence_idx], + ) = self._forward_single( features[sequence_idx], warp=False, mask=True ) - return features + + return features, time_masked_area def _forward_single( self, features: torch.Tensor, warp: bool = True, mask: bool = True - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply SpecAugment to a single feature matrix of shape (T, F). """ + time_masked_area = torch.zeros_like(features) if random.random() > self.p: # Randomly choose whether this transform is applied - return features + # No augmentation, no masked area. + return features, time_masked_area if warp: if self.time_warp_factor is not None and self.time_warp_factor >= 1: features = time_warp(features, factor=self.time_warp_factor) if mask: mean = features.mean() # Frequency masking - features = mask_along_axis_optimized( + features, _ = mask_along_axis_optimized( features, mask_size=self.features_mask_size, mask_times=self.num_feature_masks, @@ -146,7 +158,7 @@ class SpecAugment(torch.nn.Module): max_mask_frames = min( self.frames_mask_size, max_tot_mask_frames // num_frame_masks ) - features = mask_along_axis_optimized( + features, time_masked_area = mask_along_axis_optimized( features, mask_size=max_mask_frames, mask_times=num_frame_masks, @@ -154,7 +166,7 @@ class SpecAugment(torch.nn.Module): axis=1, ) - return features + return features, time_masked_area def state_dict(self) -> Dict: return dict( @@ -195,7 +207,7 @@ def mask_along_axis_optimized( mask_times: int, mask_value: float, axis: int, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply Frequency and Time masking along axis. Frequency and Time masking as described in the SpecAugment paper. @@ -209,7 +221,11 @@ def mask_along_axis_optimized( if axis not in [1, 2]: raise ValueError("Only Frequency and Time masking are supported!") + # 1 (True) represents masked area; + # 0 (False) represents original un-masked area. + masked_area = torch.zeros_like(features) features = features.unsqueeze(0) + masked_area = masked_area.unsqueeze(0) features = features.reshape([-1] + list(features.size()[-2:])) values = torch.randint(int(0), int(mask_size), (1, mask_times)) @@ -220,18 +236,22 @@ def mask_along_axis_optimized( if axis == 1: if mask_times == 1: features[:, mask_starts:mask_ends] = mask_value - return features.squeeze(0) + return features.squeeze(0), masked_area for (mask_start, mask_end) in zip(mask_starts, mask_ends): features[:, mask_start:mask_end] = mask_value + masked_area[:, mask_start:mask_end] = 1 else: if mask_times == 1: features[:, :, mask_starts:mask_ends] = mask_value - return features.squeeze(0) + masked_area[:, :, mask_starts:mask_ends] = 1 + return features.squeeze(0), masked_area for (mask_start, mask_end) in zip(mask_starts, mask_ends): features[:, :, mask_start:mask_end] = mask_value + masked_area[:, :, mask_start:mask_end] = 1 features = features.squeeze(0) - return features + masked_area = masked_area.squeeze(0) + return features, masked_area def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 66bb33e8d..5102f357e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -75,7 +75,9 @@ class Transducer(nn.Module): self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if num_codebooks > 0: self.codebook_loss_net = JointCodebookLoss( - predictor_channels=encoder_dim, num_codebooks=num_codebooks + predictor_channels=encoder_dim, + num_codebooks=num_codebooks, + reduction="none", ) def forward( @@ -88,6 +90,8 @@ class Transducer(nn.Module): lm_scale: float = 0.0, warmup: float = 1.0, codebook_indexes: torch.Tensor = None, + time_masked_area: torch.Tensor = None, + masked_scale: float = 1.0, ) -> torch.Tensor: """ Args: @@ -113,6 +117,11 @@ class Transducer(nn.Module): warmup > 1 "are fully warmed up" and all modules will be active. codebook_indexes: codebook_indexes extracted from a teacher model. + time_masked_area: + masked area by SpecAugment, 1 represents masked. + masked_scale: + scale of codebook loss of masked area. + the unmasked_scale = 1 - masked_scale Returns: Return the transducer loss. @@ -140,6 +149,21 @@ class Transducer(nn.Module): codebook_loss = self.codebook_loss_net( middle_layer_output, codebook_indexes ) + codebook_loss = codebook_loss.reshape(codebook_indexes.shape) + target_t = codebook_loss.shape[1] + time_masked_area = time_masked_area.bool() + time_masked_area = time_masked_area[ + :, : target_t * 4 : 4, 0 # noqa E203 + ] + assert time_masked_area.shape == codebook_loss.shape[:-1] + time_masked_area = time_masked_area.unsqueeze(2).to( + codebook_loss.device + ) + masked_loss = (time_masked_area * codebook_loss).sum() + unmasked_loss = (~time_masked_area * codebook_loss).sum() + codebook_loss = ( + masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss + ) else: # when codebook index is not available. codebook_loss = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index feb58f457..dbf87ff48 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -602,7 +602,7 @@ def compute_loss( if isinstance(model, DDP) else next(model.parameters()).device ) - feature = batch["inputs"] + feature, time_masked_area = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) @@ -631,6 +631,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, codebook_indexes=codebook_indexes, + time_masked_area=time_masked_area, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid