From 03fa9957a6826cb68a5a515ae05f74c059853545 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 30 May 2022 10:47:19 +0800 Subject: [PATCH 1/6] direct copy from lhotse --- .../ASR/pruned_transducer_stateless6/aug.py | 261 ++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless6/aug.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py new file mode 100644 index 000000000..202005d98 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py @@ -0,0 +1,261 @@ +import bisect +import math +import random +from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union + +import numpy as np +import torch + +from lhotse import CutSet +from lhotse.augmentation import dereverb_wpe_torch +from lhotse.utils import Pathlike + + +class SpecAugment(torch.nn.Module): + """ + SpecAugment performs three augmentations: + - time warping of the feature matrix + - masking of ranges of features (frequency bands) + - masking of ranges of frames (time) + + The current implementation works with batches, but processes each example separately + in a loop rather than simultaneously to achieve different augmentation parameters for + each example. + """ + + def __init__( + self, + time_warp_factor: Optional[int] = 80, + num_feature_masks: int = 2, + features_mask_size: int = 27, + num_frame_masks: int = 10, + frames_mask_size: int = 100, + max_frames_mask_fraction: float = 0.15, + p=0.9, + ): + """ + SpecAugment's constructor. + + :param time_warp_factor: parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. + :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). + This is the ``F`` parameter from the SpecAugment paper. + :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable. + :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). + This is the ``T`` parameter from the SpecAugment paper. + :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length + of the utterance (or supervision segment). + This is the parameter denoted by ``p`` in the SpecAugment paper. + :param p: the probability of applying this transform. + It is different from ``p`` in the SpecAugment paper! + """ + super().__init__() + assert 0 <= p <= 1 + assert num_feature_masks >= 0 + assert num_frame_masks > 0 + assert features_mask_size > 0 + assert frames_mask_size > 0 + self.time_warp_factor = time_warp_factor + self.num_feature_masks = num_feature_masks + self.features_mask_size = features_mask_size + self.num_frame_masks = num_frame_masks + self.frames_mask_size = frames_mask_size + self.max_frames_mask_fraction = max_frames_mask_fraction + self.p = p + + def forward( + self, + features: torch.Tensor, + supervision_segments: Optional[torch.IntTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Computes SpecAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. 
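+        For example, a ``supervision_segments`` row ``[1, 30, 100]`` requests
+        augmentation of frames ``30``-``130`` of the feature matrix at batch
+        index ``1``.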
+
+        :param features: a batch of feature matrices with shape ``(B, T, F)``.
+        :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of
+            supervision segments that exist in ``features`` -- there may be either
+            fewer or more than the batch size.
+            The second dimension encodes three kinds of information:
+            the sequence index of the corresponding feature matrix in `features`,
+            the start frame index, and the number of frames for each segment.
+        :return: an augmented tensor of shape ``(B, T, F)``.
+        """
+        assert len(features.shape) == 3, (
+            "SpecAugment only supports batches of " "single-channel feature matrices."
+        )
+        features = features.clone()
+        if supervision_segments is None:
+            # No supervisions - apply spec augment to full feature matrices.
+            for sequence_idx in range(features.size(0)):
+                features[sequence_idx] = self._forward_single(features[sequence_idx])
+        else:
+            # Supervisions provided - we will apply time warping only on the supervised areas.
+            for sequence_idx, start_frame, num_frames in supervision_segments:
+                end_frame = start_frame + num_frames
+                features[sequence_idx, start_frame:end_frame] = self._forward_single(
+                    features[sequence_idx, start_frame:end_frame], warp=True, mask=False
+                )
+            # ... and then time-mask the full feature matrices. Note that in this mode,
+            # it might happen that masks are applied to different sequences/examples
+            # than the time warping.
+            for sequence_idx in range(features.size(0)):
+                features[sequence_idx] = self._forward_single(
+                    features[sequence_idx], warp=False, mask=True
+                )
+        return features
+
+    def _forward_single(
+        self, features: torch.Tensor, warp: bool = True, mask: bool = True
+    ) -> torch.Tensor:
+        """
+        Apply SpecAugment to a single feature matrix of shape (T, F).
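+
+        :param warp: when true, apply time warping (controlled by ``time_warp_factor``).
+        :param mask: when true, apply frequency and time masking.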
+ """ + if random.random() > self.p: + # Randomly choose whether this transform is applied + return features + if warp: + if self.time_warp_factor is not None and self.time_warp_factor >= 1: + features = time_warp(features, factor=self.time_warp_factor) + if mask: + mean = features.mean() + # Frequency masking + features = mask_along_axis_optimized( + features, + mask_size=self.features_mask_size, + mask_times=self.num_feature_masks, + mask_value=mean, + axis=2, + ) + # Time masking + max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min( + self.num_frame_masks, + math.ceil(max_tot_mask_frames / self.frames_mask_size), + ) + max_mask_frames = min( + self.frames_mask_size, max_tot_mask_frames // num_frame_masks + ) + features = mask_along_axis_optimized( + features, + mask_size=max_mask_frames, + mask_times=num_frame_masks, + mask_value=mean, + axis=1, + ) + + return features + + def state_dict(self) -> Dict: + return dict( + time_warp_factor=self.time_warp_factor, + num_feature_masks=self.num_feature_masks, + features_mask_size=self.features_mask_size, + num_frame_masks=self.num_frame_masks, + frames_mask_size=self.frames_mask_size, + max_frames_mask_fraction=self.max_frames_mask_fraction, + p=self.p, + ) + + def load_state_dict(self, state_dict: Dict): + self.time_warp_factor = state_dict.get( + "time_warp_factor", self.time_warp_factor + ) + self.num_feature_masks = state_dict.get( + "num_feature_masks", self.num_feature_masks + ) + self.features_mask_size = state_dict.get( + "features_mask_size", self.features_mask_size + ) + self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.frames_mask_size = state_dict.get( + "frames_mask_size", self.frames_mask_size + ) + self.max_frames_mask_fraction = state_dict.get( + "max_frames_mask_fraction", self.max_frames_mask_fraction + ) + self.p = state_dict.get("p", self.p) + + +def mask_along_axis_optimized( + features: torch.Tensor, + mask_size: int, + mask_times: int, + mask_value: float, + axis: int, +) -> torch.Tensor: + """ + Apply Frequency and Time masking along axis. + Frequency and Time masking as described in the SpecAugment paper. + + :param features: input tensor of shape ``(T, F)`` + :mask_size: the width size for masking. + :mask_times: the number of masking regions. + :mask_value: Value to assign to the masked regions. + :axis: Axis to apply masking on (1 -> time, 2 -> frequency) + """ + if axis not in [1, 2]: + raise ValueError("Only Frequency and Time masking are supported!") + + features = features.unsqueeze(0) + features = features.reshape([-1] + list(features.size()[-2:])) + + values = torch.randint(int(0), int(mask_size), (1, mask_times)) + min_values = torch.rand(1, mask_times) * (features.size(axis) - values) + mask_starts = (min_values.long()).squeeze() + mask_ends = (min_values.long() + values.long()).squeeze() + + if axis == 1: + if mask_times == 1: + features[:, mask_starts:mask_ends] = mask_value + return features.squeeze(0) + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, mask_start:mask_end] = mask_value + else: + if mask_times == 1: + features[:, :, mask_starts:mask_ends] = mask_value + return features.squeeze(0) + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, :, mask_start:mask_end] = mask_value + + features = features.squeeze(0) + return features + + +def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: + """ + Time warping as described in the SpecAugment paper. 
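+    A warp point is sampled near the center of the time axis; the segment on
+    one side of it is stretched and the other compressed via bicubic
+    interpolation.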
+ Implementation based on Espresso: + https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51 + + :param features: input tensor of shape ``(T, F)`` + :param factor: time warping parameter. + :return: a warped tensor of shape ``(T, F)`` + """ + t = features.size(0) + if t - factor <= factor + 1: + return features + center = np.random.randint(factor + 1, t - factor) + warped = np.random.randint(center - factor, center + factor + 1) + if warped == center: + return features + features = features.unsqueeze(0).unsqueeze(0) + left = torch.nn.functional.interpolate( + features[:, :, :center, :], + size=(warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + right = torch.nn.functional.interpolate( + features[:, :, center:, :], + size=(t - warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) From 496fb963c8387fb9466ee898c4f1c7aa27115e1d Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 30 May 2022 10:49:08 +0800 Subject: [PATCH 2/6] icefall format --- .../ASR/pruned_transducer_stateless6/aug.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py index 202005d98..c60d328c5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py @@ -1,15 +1,10 @@ -import bisect import math import random -from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union +from typing import Dict, Optional import numpy as np import torch -from lhotse import CutSet -from lhotse.augmentation import dereverb_wpe_torch -from lhotse.utils import Pathlike - class SpecAugment(torch.nn.Module): """ @@ -88,19 +83,26 @@ class SpecAugment(torch.nn.Module): :return: an augmented tensor of shape ``(B, T, F)``. """ assert len(features.shape) == 3, ( - "SpecAugment only supports batches of " "single-channel feature matrices." + "SpecAugment only supports batches of " + "single-channel feature matrices." ) features = features.clone() if supervision_segments is None: # No supervisions - apply spec augment to full feature matrices. for sequence_idx in range(features.size(0)): - features[sequence_idx] = self._forward_single(features[sequence_idx]) + features[sequence_idx] = self._forward_single( + features[sequence_idx] + ) else: # Supervisions provided - we will apply time warping only on the supervised areas. for sequence_idx, start_frame, num_frames in supervision_segments: end_frame = start_frame + num_frames - features[sequence_idx, start_frame:end_frame] = self._forward_single( - features[sequence_idx, start_frame:end_frame], warp=True, mask=False + features[ + sequence_idx, start_frame:end_frame + ] = self._forward_single( + features[sequence_idx, start_frame:end_frame], + warp=True, + mask=False, ) # ... and then time-mask the full feature matrices. 
Note that in this mode, # it might happen that masks are applied to different sequences/examples @@ -134,7 +136,9 @@ class SpecAugment(torch.nn.Module): axis=2, ) # Time masking - max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + max_tot_mask_frames = self.max_frames_mask_fraction * features.size( + 0 + ) num_frame_masks = min( self.num_frame_masks, math.ceil(max_tot_mask_frames / self.frames_mask_size), @@ -173,7 +177,9 @@ class SpecAugment(torch.nn.Module): self.features_mask_size = state_dict.get( "features_mask_size", self.features_mask_size ) - self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.num_frame_masks = state_dict.get( + "num_frame_masks", self.num_frame_masks + ) self.frames_mask_size = state_dict.get( "frames_mask_size", self.frames_mask_size ) From 0c33543ce7fa644193f490c79edd34fad5870bfa Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 30 May 2022 14:31:20 +0800 Subject: [PATCH 3/6] copy asr_datamodule.py to stateless6 --- .../asr_datamodule.py | 446 +++++++++++++++++- 1 file changed, 445 insertions(+), 1 deletion(-) mode change 120000 => 100644 egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py new file mode 100644 index 000000000..e83009d4a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py @@ -0,0 +1,445 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. 
+ It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. 
" + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return 
load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz") + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") From 90024c308f1ac7a6e2177eb82db898669e2f7cb6 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 30 May 2022 14:33:48 +0800 Subject: [PATCH 4/6] predicted masked codebook indexes only --- .../asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless6/aug.py | 58 +++++++++++++------ .../ASR/pruned_transducer_stateless6/model.py | 26 ++++++++- .../ASR/pruned_transducer_stateless6/train.py | 3 +- 4 files changed, 67 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py index e83009d4a..03d0d1a88 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py @@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, - SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -41,6 +40,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader +from aug import SpecAugment from icefall.utils import str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py index c60d328c5..0746d0036 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py @@ -1,6 +1,6 @@ import math import random -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import numpy as np import torch @@ -65,7 +65,7 @@ class SpecAugment(torch.nn.Module): supervision_segments: Optional[torch.IntTensor] = None, *args, **kwargs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes SpecAugment for a batch of feature matrices. @@ -87,19 +87,25 @@ class SpecAugment(torch.nn.Module): "single-channel feature matrices." ) features = features.clone() + # 1 (True) represents masked area; + # 0 (False) represents original un-masked area. + time_masked_area = torch.zeros_like(features) if supervision_segments is None: # No supervisions - apply spec augment to full feature matrices. for sequence_idx in range(features.size(0)): - features[sequence_idx] = self._forward_single( - features[sequence_idx] - ) + ( + features[sequence_idx], + time_masked_area[sequence_idx], + ) = self._forward_single(features[sequence_idx]) + else: # Supervisions provided - we will apply time warping only on the supervised areas. 
for sequence_idx, start_frame, num_frames in supervision_segments: end_frame = start_frame + num_frames - features[ - sequence_idx, start_frame:end_frame - ] = self._forward_single( + ( + features[sequence_idx, start_frame:end_frame], + time_masked_area[sequence_idx, start_frame:end_frame], + ) = self._forward_single( features[sequence_idx, start_frame:end_frame], warp=True, mask=False, @@ -108,27 +114,33 @@ class SpecAugment(torch.nn.Module): # it might happen that masks are applied to different sequences/examples # than the time warping. for sequence_idx in range(features.size(0)): - features[sequence_idx] = self._forward_single( + ( + features[sequence_idx], + time_masked_area[sequence_idx], + ) = self._forward_single( features[sequence_idx], warp=False, mask=True ) - return features + + return features, time_masked_area def _forward_single( self, features: torch.Tensor, warp: bool = True, mask: bool = True - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply SpecAugment to a single feature matrix of shape (T, F). """ + time_masked_area = torch.zeros_like(features) if random.random() > self.p: # Randomly choose whether this transform is applied - return features + # No augmentation, no masked area. + return features, time_masked_area if warp: if self.time_warp_factor is not None and self.time_warp_factor >= 1: features = time_warp(features, factor=self.time_warp_factor) if mask: mean = features.mean() # Frequency masking - features = mask_along_axis_optimized( + features, _ = mask_along_axis_optimized( features, mask_size=self.features_mask_size, mask_times=self.num_feature_masks, @@ -146,7 +158,7 @@ class SpecAugment(torch.nn.Module): max_mask_frames = min( self.frames_mask_size, max_tot_mask_frames // num_frame_masks ) - features = mask_along_axis_optimized( + features, time_masked_area = mask_along_axis_optimized( features, mask_size=max_mask_frames, mask_times=num_frame_masks, @@ -154,7 +166,7 @@ class SpecAugment(torch.nn.Module): axis=1, ) - return features + return features, time_masked_area def state_dict(self) -> Dict: return dict( @@ -195,7 +207,7 @@ def mask_along_axis_optimized( mask_times: int, mask_value: float, axis: int, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply Frequency and Time masking along axis. Frequency and Time masking as described in the SpecAugment paper. @@ -209,7 +221,11 @@ def mask_along_axis_optimized( if axis not in [1, 2]: raise ValueError("Only Frequency and Time masking are supported!") + # 1 (True) represents masked area; + # 0 (False) represents original un-masked area. 
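+    # The mask has the same shape as ``features``; downstream code (model.py
+    # in this series) subsamples it to the encoder frame rate before weighting
+    # the codebook loss.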
+ masked_area = torch.zeros_like(features) features = features.unsqueeze(0) + masked_area = masked_area.unsqueeze(0) features = features.reshape([-1] + list(features.size()[-2:])) values = torch.randint(int(0), int(mask_size), (1, mask_times)) @@ -220,18 +236,22 @@ def mask_along_axis_optimized( if axis == 1: if mask_times == 1: features[:, mask_starts:mask_ends] = mask_value - return features.squeeze(0) + return features.squeeze(0), masked_area for (mask_start, mask_end) in zip(mask_starts, mask_ends): features[:, mask_start:mask_end] = mask_value + masked_area[:, mask_start:mask_end] = 1 else: if mask_times == 1: features[:, :, mask_starts:mask_ends] = mask_value - return features.squeeze(0) + masked_area[:, :, mask_starts:mask_ends] = 1 + return features.squeeze(0), masked_area for (mask_start, mask_end) in zip(mask_starts, mask_ends): features[:, :, mask_start:mask_end] = mask_value + masked_area[:, :, mask_start:mask_end] = 1 features = features.squeeze(0) - return features + masked_area = masked_area.squeeze(0) + return features, masked_area def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 66bb33e8d..5102f357e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -75,7 +75,9 @@ class Transducer(nn.Module): self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if num_codebooks > 0: self.codebook_loss_net = JointCodebookLoss( - predictor_channels=encoder_dim, num_codebooks=num_codebooks + predictor_channels=encoder_dim, + num_codebooks=num_codebooks, + reduction="none", ) def forward( @@ -88,6 +90,8 @@ class Transducer(nn.Module): lm_scale: float = 0.0, warmup: float = 1.0, codebook_indexes: torch.Tensor = None, + time_masked_area: torch.Tensor = None, + masked_scale: float = 1.0, ) -> torch.Tensor: """ Args: @@ -113,6 +117,11 @@ class Transducer(nn.Module): warmup > 1 "are fully warmed up" and all modules will be active. codebook_indexes: codebook_indexes extracted from a teacher model. + time_masked_area: + masked area by SpecAugment, 1 represents masked. + masked_scale: + scale of codebook loss of masked area. + the unmasked_scale = 1 - masked_scale Returns: Return the transducer loss. @@ -140,6 +149,21 @@ class Transducer(nn.Module): codebook_loss = self.codebook_loss_net( middle_layer_output, codebook_indexes ) + codebook_loss = codebook_loss.reshape(codebook_indexes.shape) + target_t = codebook_loss.shape[1] + time_masked_area = time_masked_area.bool() + time_masked_area = time_masked_area[ + :, : target_t * 4 : 4, 0 # noqa E203 + ] + assert time_masked_area.shape == codebook_loss.shape[:-1] + time_masked_area = time_masked_area.unsqueeze(2).to( + codebook_loss.device + ) + masked_loss = (time_masked_area * codebook_loss).sum() + unmasked_loss = (~time_masked_area * codebook_loss).sum() + codebook_loss = ( + masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss + ) else: # when codebook index is not available. 
codebook_loss = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index feb58f457..dbf87ff48 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -602,7 +602,7 @@ def compute_loss( if isinstance(model, DDP) else next(model.parameters()).device ) - feature = batch["inputs"] + feature, time_masked_area = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) @@ -631,6 +631,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, codebook_indexes=codebook_indexes, + time_masked_area=time_masked_area, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid From c381b491f1cfa1e3fb6d505f66c17f9eccc1c1f2 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 4 Jun 2022 21:01:07 +0800 Subject: [PATCH 5/6] different weight for masked/unmasked region --- .../ASR/pruned_transducer_stateless6/model.py | 17 ++++++++++------ .../ASR/pruned_transducer_stateless6/train.py | 20 ++++++++++++++++++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 5102f357e..305049d69 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ from scaling import ScaledLinear from icefall.utils import add_sos -from quantization.prediction import JointCodebookLoss +from multi_quantization.prediction import JointCodebookLoss class Transducer(nn.Module): @@ -41,6 +41,8 @@ class Transducer(nn.Module): joiner_dim: int, vocab_size: int, num_codebooks: int = 0, + masked_scale: float = 1.0, + unmasked_scale: float = 1.0, ): """ Args: @@ -60,6 +62,10 @@ class Transducer(nn.Module): contains unnormalized probs, i.e., not processed by log-softmax. num_codebooks: Used by distillation loss. + masked_scale: + scale of codebook loss of masked area. + unmasked_scale: + scale of codebook loss of unmasked area. """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) @@ -79,6 +85,8 @@ class Transducer(nn.Module): num_codebooks=num_codebooks, reduction="none", ) + self.masked_scale = masked_scale + self.unmasked_scale = unmasked_scale def forward( self, @@ -91,7 +99,6 @@ class Transducer(nn.Module): warmup: float = 1.0, codebook_indexes: torch.Tensor = None, time_masked_area: torch.Tensor = None, - masked_scale: float = 1.0, ) -> torch.Tensor: """ Args: @@ -119,9 +126,6 @@ class Transducer(nn.Module): codebook_indexes extracted from a teacher model. time_masked_area: masked area by SpecAugment, 1 represents masked. - masked_scale: - scale of codebook loss of masked area. - the unmasked_scale = 1 - masked_scale Returns: Return the transducer loss. @@ -162,7 +166,8 @@ class Transducer(nn.Module): masked_loss = (time_masked_area * codebook_loss).sum() unmasked_loss = (~time_masked_area * codebook_loss).sum() codebook_loss = ( - masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss + self.masked_scale * masked_loss + + self.unmasked_scale * unmasked_loss ) else: # when codebook index is not available. 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index dbf87ff48..3cec4326e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -177,6 +177,18 @@ def get_parser(): changed.""", ) + parser.add_argument( + "--masked-scale", + type=float, + default=1.0, + ) + + parser.add_argument( + "--unmasked-scale", + type=float, + default=1.0, + ) + parser.add_argument( "--lr-batches", type=float, @@ -378,6 +390,8 @@ def get_params() -> AttributeDict: # two successive codebook_index are concatenated together. # Detailed in function Transducer::concat_sucessive_codebook_indexes. "num_codebooks": 16, # used to construct distillation loss + "masked_scale": 1.0, + "unmasked_scale": 1.0, } ) @@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: num_codebooks=params.num_codebooks if params.enable_distiallation else 0, + masked_scale=params.masked_scale, + unmasked_scale=params.unmasked_scale, ) return model @@ -1090,7 +1106,9 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) + args.exp_dir = Path( + f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}" + ) world_size = args.world_size assert world_size >= 1 From b52b5c683f4f7daef450eebf60f651100b916cdc Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 4 Jun 2022 21:05:07 +0800 Subject: [PATCH 6/6] config spec-aug-max-frames-mask-fraction --- .../asr_datamodule.py | 446 +----------------- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 9 +- 2 files changed, 9 insertions(+), 446 deletions(-) mode change 100644 => 120000 egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py deleted file mode 100644 index 03d0d1a88..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - BucketingSampler, - CutConcatenate, - CutMix, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SingleCutSampler, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from aug import SpecAugment -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) - transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using BucketingSampler.") - train_sampler = BucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - bucket_method="equal_duration", - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = BucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = BucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz") - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git 
a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e83009d4a..b391b2815 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, - SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -41,6 +40,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader +from aug import SpecAugment from icefall.utils import str2bool @@ -183,6 +183,12 @@ class LibriSpeechAsrDataModule: help="When enabled, use SpecAugment for training dataset.", ) + group.add_argument( + "--spec-aug-max-frames-mask-fraction", + type=float, + default=0.15, + ) + group.add_argument( "--spec-aug-time-warp-factor", type=int, @@ -272,6 +278,7 @@ class LibriSpeechAsrDataModule: features_mask_size=27, num_feature_masks=2, frames_mask_size=100, + max_frames_mask_fraction=self.args.spec_aug_max_frames_mask_fraction, ) ) else:
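
A minimal standalone sketch of the masked/unmasked weighting that patches 4 and 5
introduce in model.py. All shapes and scale values below are illustrative
stand-ins (random tensors instead of real encoder outputs and codebook losses);
only the 4x mask subsampling mirrors the ``[:, : target_t * 4 : 4, 0]`` slice
actually used in model.py:

    import torch

    # Illustrative shapes: batch of 2, 400 input frames, 80 mel bins, 16 codebooks.
    B, T_in, F, C = 2, 400, 80, 16
    T_out = T_in // 4  # encoder output length after 4x subsampling

    # Stand-ins for the per-position codebook loss (reduction="none") and the
    # time mask returned by the modified SpecAugment alongside the features.
    codebook_loss = torch.rand(B, T_out, C)
    time_masked_area = torch.zeros(B, T_in, F)
    time_masked_area[0, 100:200] = 1  # pretend frames 100-200 of example 0 were masked

    # Subsample the mask to the encoder frame rate, as model.py does.
    masked = time_masked_area.bool()[:, : T_out * 4 : 4, 0].unsqueeze(2)

    masked_scale, unmasked_scale = 1.0, 0.3  # illustrative values
    masked_loss = (masked * codebook_loss).sum()
    unmasked_loss = (~masked * codebook_loss).sum()
    codebook_loss_total = masked_scale * masked_loss + unmasked_scale * unmasked_loss

With patch 5 applied, the two scales are exposed on train.py as --masked-scale
and --unmasked-scale (for example, ./pruned_transducer_stateless6/train.py
--masked-scale 1.0 --unmasked-scale 0.3; the values here are illustrative, not
tuned defaults).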