From 35ecd7e5629630242d28aa35004c8394ff7b1f91 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Sun, 6 Feb 2022 21:59:54 +0800
Subject: [PATCH 1/2] Fix torch.nn.Embedding error for torch below 1.8.0 (#198)

---
 egs/librispeech/ASR/transducer/beam_search.py           | 4 +++-
 egs/librispeech/ASR/transducer/model.py                 | 1 +
 egs/librispeech/ASR/transducer_lstm/beam_search.py      | 4 +++-
 egs/librispeech/ASR/transducer_lstm/model.py            | 1 +
 egs/librispeech/ASR/transducer_stateless/beam_search.py | 2 +-
 egs/librispeech/ASR/transducer_stateless/model.py       | 1 +
 6 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index f45d06ce9..11032f31a 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py
index fa0b2dd68..8305248c9 100644
--- a/egs/librispeech/ASR/transducer/model.py
+++ b/egs/librispeech/ASR/transducer/model.py
@@ -99,6 +99,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out, _ = self.decoder(sos_y_padded)
 
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index dfc22fcf8..3531a9633 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py
index cb9afd8a2..31843b60e 100644
--- a/egs/librispeech/ASR/transducer_lstm/model.py
+++ b/egs/librispeech/ASR/transducer_lstm/model.py
@@ -101,6 +101,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=sos_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out, _ = self.decoder(sos_y_padded)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 341c74fab..1cce48235 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -48,7 +48,7 @@ def greedy_search(
     device = model.device
 
     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
     ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 7aac290d9..17b5f63e5 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -93,6 +93,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out = self.decoder(sos_y_padded)
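Background for the patch above: before torch 1.8.0, torch.nn.Embedding accepts only
int64 (Long) indices and raises a RuntimeError for smaller integer dtypes, and the
token-id tensors in these recipes (built from k2 ragged tensors, which typically use
int32) could reach the decoder's embedding layer with such a dtype. A minimal sketch
of the failure mode and of the cast applied above; the vocabulary and embedding sizes
are illustrative, not taken from the recipes:

    import torch
    import torch.nn as nn

    # Illustrative sizes only.
    embedding = nn.Embedding(num_embeddings=500, embedding_dim=512)

    good = torch.tensor([[3]], dtype=torch.int64)
    bad = torch.tensor([[3]], dtype=torch.int32)

    embedding(good)                 # works on every torch version
    # embedding(bad)                # RuntimeError on torch < 1.8.0: indices must be Long
    embedding(bad.to(torch.int64))  # the cast this patch applies before the decoder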
From 3323cabf467324b5d8bc3b1247a37724cd778ed0 Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Tue, 8 Feb 2022 14:25:31 +0800
Subject: [PATCH 2/2] Experiments based on SpecAugment change

---
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       | 213 +++++++++++++++++-
 1 file changed, 211 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index e075a2d03..e5fcc5893 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -28,7 +28,6 @@ from lhotse.dataset import (
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
-    SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from torch.utils.data import DataLoader
@@ -219,10 +218,11 @@ class LibriSpeechAsrDataModule:
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=2,
+                    num_frame_masks=10,
                     features_mask_size=27,
                     num_feature_masks=2,
                     frames_mask_size=100,
+                    max_frames_mask_fraction=0.4,
                 )
             )
         else:
@@ -383,3 +383,212 @@ class LibriSpeechAsrDataModule:
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
         return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")
+
+
+import math
+import random
+import numpy as np
+from typing import Optional, Dict
+
+import torch
+
+from lhotse import CutSet
+
+
+class SpecAugment(torch.nn.Module):
+    """
+    SpecAugment performs three augmentations:
+    - time warping of the feature matrix
+    - masking of ranges of features (frequency bands)
+    - masking of ranges of frames (time)
+
+    The current implementation works with batches, but processes each example
+    separately in a loop rather than simultaneously, so that each example can
+    receive different augmentation parameters.
+    """
+
+    def __init__(
+        self,
+        time_warp_factor: Optional[int] = 80,
+        num_feature_masks: int = 1,
+        features_mask_size: int = 13,
+        num_frame_masks: int = 1,
+        frames_mask_size: int = 70,
+        max_frames_mask_fraction: float = 0.2,
+        p: float = 0.5,
+    ):
+        """
+        SpecAugment's constructor.
+
+        :param time_warp_factor: parameter for the time warping; larger values
+            mean more warping. Set to ``None``, or less than ``1``, to disable.
+        :param num_feature_masks: how many feature masks should be applied.
+            Set to ``0`` to disable.
+        :param features_mask_size: the width of the feature mask (expressed in
+            the number of masked feature bins). This is the ``F`` parameter
+            from the SpecAugment paper.
+        :param num_frame_masks: how many frame (temporal) masks should be
+            applied. Set to ``0`` to disable.
+        :param frames_mask_size: the width of the frame (temporal) masks
+            (expressed in the number of masked frames). This is the ``T``
+            parameter from the SpecAugment paper.
+        :param max_frames_mask_fraction: limits the size of the frame
+            (temporal) mask to this value times the length of the utterance
+            (or supervision segment). This is the parameter denoted by ``p``
+            in the SpecAugment paper.
+        :param p: the probability of applying this transform.
+            It is different from ``p`` in the SpecAugment paper!
+        """
+        super().__init__()
+        assert 0 <= p <= 1
+        assert num_feature_masks >= 0
+        assert num_frame_masks >= 0
+        assert features_mask_size > 0
+        assert frames_mask_size > 0
+        self.time_warp_factor = time_warp_factor
+        self.num_feature_masks = num_feature_masks
+        self.features_mask_size = features_mask_size
+        self.num_frame_masks = num_frame_masks
+        self.frames_mask_size = frames_mask_size
+        self.max_frames_mask_fraction = max_frames_mask_fraction
+        self.p = p
+
+    def forward(
+        self,
+        features: torch.Tensor,
+        supervision_segments: Optional[torch.IntTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Computes SpecAugment for a batch of feature matrices.
+
+        Since the batch will usually already be padded, the user can optionally
+        provide a ``supervision_segments`` tensor that will be used to apply
+        SpecAugment only to selected areas of the input. The format of this
+        input is described below.
+
+        :param features: a batch of feature matrices with shape ``(B, T, F)``.
+        :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S``
+            is the number of supervision segments that exist in ``features`` --
+            there may be either fewer or more of them than the batch size.
+            The second dimension encodes three kinds of information: the
+            sequence index of the corresponding feature matrix in
+            ``features``, the start frame index, and the number of frames
+            for each segment.
+        :return: an augmented tensor of shape ``(B, T, F)``.
+        """
+        assert len(features.shape) == 3, (
+            "SpecAugment only supports batches of single-channel feature matrices."
+        )
+        features = features.clone()
+        if supervision_segments is None:
+            # No supervisions - apply spec augment to full feature matrices.
+            for sequence_idx in range(features.size(0)):
+                features[sequence_idx] = self._forward_single(features[sequence_idx])
+        else:
+            # Supervisions provided - we will apply time warping only on the
+            # supervised areas.
+            for sequence_idx, start_frame, num_frames in supervision_segments:
+                end_frame = start_frame + num_frames
+                features[sequence_idx, start_frame:end_frame] = self._forward_single(
+                    features[sequence_idx, start_frame:end_frame],
+                    warp=True,
+                    mask=False,
+                )
+            # ... and then time-mask the full feature matrices. Note that in
+            # this mode, it might happen that masks are applied to different
+            # sequences/examples than the time warping.
+            for sequence_idx in range(features.size(0)):
+                features[sequence_idx] = self._forward_single(
+                    features[sequence_idx], warp=False, mask=True
+                )
+        return features
+
+    def _forward_single(
+        self, features: torch.Tensor, warp: bool = True, mask: bool = True
+    ) -> torch.Tensor:
+        """
+        Apply SpecAugment to a single feature matrix of shape (T, F).
+        """
+        if random.random() > self.p:
+            # Randomly choose whether this transform is applied
+            return features
+        if warp:
+            if self.time_warp_factor is not None and self.time_warp_factor >= 1:
+                features = time_warp(features, factor=self.time_warp_factor)
+        if mask:
+            from torchaudio.functional import mask_along_axis
+
+            mean = features.mean()
+            for _ in range(self.num_feature_masks):
+                features = mask_along_axis(
+                    features.unsqueeze(0),
+                    mask_param=self.features_mask_size,
+                    mask_value=mean,
+                    axis=2,
+                ).squeeze(0)
+            # Limit the total number of masked frames to
+            # ``max_frames_mask_fraction`` of the utterance length: cap the
+            # number of masks and the width of each mask before looping,
+            # rather than recomputing the caps inside the loop.
+            max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
+            num_frame_masks = min(
+                self.num_frame_masks,
+                math.ceil(max_tot_mask_frames / self.frames_mask_size),
+            )
+            max_mask_frames = min(
+                self.frames_mask_size,
+                max_tot_mask_frames // max(num_frame_masks, 1),
+            )
+            for _ in range(num_frame_masks):
+                features = mask_along_axis(
+                    features.unsqueeze(0),
+                    mask_param=max_mask_frames,
+                    mask_value=mean,
+                    axis=1,
+                ).squeeze(0)
+        return features
+
+    def state_dict(self) -> Dict:
+        return dict(
+            time_warp_factor=self.time_warp_factor,
+            num_feature_masks=self.num_feature_masks,
+            features_mask_size=self.features_mask_size,
+            num_frame_masks=self.num_frame_masks,
+            frames_mask_size=self.frames_mask_size,
+            max_frames_mask_fraction=self.max_frames_mask_fraction,
+            p=self.p,
+        )
+
+    def load_state_dict(self, state_dict: Dict):
+        self.time_warp_factor = state_dict.get(
+            "time_warp_factor", self.time_warp_factor
+        )
+        self.num_feature_masks = state_dict.get(
+            "num_feature_masks", self.num_feature_masks
+        )
+        self.features_mask_size = state_dict.get(
+            "features_mask_size", self.features_mask_size
+        )
+        self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks)
+        self.frames_mask_size = state_dict.get(
+            "frames_mask_size", self.frames_mask_size
+        )
+        self.max_frames_mask_fraction = state_dict.get(
+            "max_frames_mask_fraction", self.max_frames_mask_fraction
+        )
+        self.p = state_dict.get("p", self.p)
+
+
+def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor:
+    """
+    Time warping as described in the SpecAugment paper.
+    Implementation based on Espresso:
+    https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51
+
+    :param features: input tensor of shape ``(T, F)``
+    :param factor: time warping parameter.
+    :return: a warped tensor of shape ``(T, F)``
+    """
+    t = features.size(0)
+    if t - factor <= factor + 1:
+        return features
+    center = np.random.randint(factor + 1, t - factor)
+    warped = np.random.randint(center - factor, center + factor + 1)
+    if warped == center:
+        return features
+    features = features.unsqueeze(0).unsqueeze(0)
+    left = torch.nn.functional.interpolate(
+        features[:, :, :center, :],
+        size=(warped, features.size(3)),
+        mode="bicubic",
+        align_corners=False,
+    )
+    right = torch.nn.functional.interpolate(
+        features[:, :, center:, :],
+        size=(t - warped, features.size(3)),
+        mode="bicubic",
+        align_corners=False,
+    )
+    return torch.cat((left, right), dim=2).squeeze(0).squeeze(0)
\ No newline at end of file