diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py deleted file mode 120000 index 07f39b451..000000000 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../transducer/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py new file mode 100644 index 000000000..8dd1459ca --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -0,0 +1,428 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz") + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index 38aff8834..b4a9af55a 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -18,36 +18,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless2/decode.py \ +./pruned2_knowledge/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned2_knowledge/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./pruned_transducer_stateless2/decode.py \ +./pruned2_knowledge/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned2_knowledge/exp \ --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless2/decode.py \ +./pruned2_knowledge/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned2_knowledge/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search -./pruned_transducer_stateless2/decode.py \ +./pruned2_knowledge/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned2_knowledge/exp \ --max-duration 1500 \ --decoding-method fast_beam_search \ --beam 4 \ @@ -124,7 +124,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned2_knowledge/exp", help="The experiment dir", ) diff --git a/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py b/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py deleted file mode 120000 index aa5d0217a..000000000 --- a/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py b/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py new file mode 100644 index 000000000..257facce4 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py @@ -0,0 +1,43 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn as nn + + +class EncoderInterface(nn.Module): + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (batch_size, input_seq_len, num_features) + containing the input features. + x_lens: + A tensor of shape (batch_size,) containing the number of frames + in `x` before padding. + Returns: + Return a tuple containing two tensors: + - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) + containing unnormalized probabilities, i.e., the output of a + linear layer. + - encoder_out_lens, a tensor of shape (batch_size,) containing + the number of frames in `encoder_out` before padding. + """ + raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index b5757ee8c..96d1a30fb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -20,23 +20,23 @@ # to a single one using model averaging. """ Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ +./pruned2_knowledge/export.py \ + --exp-dir ./pruned2_knowledge/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 20 \ --avg 10 It will generate a file exp_dir/pretrained.pt -To use the generated file with `pruned_transducer_stateless2/decode.py`, +To use the generated file with `pruned2_knowledge/decode.py`, you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ + ./pruned2_knowledge/decode.py \ + --exp-dir ./pruned2_knowledge/exp \ --epoch 9999 \ --avg 1 \ --max-duration 100 \ @@ -80,7 +80,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned2_knowledge/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 80617847a..311ba3b2b 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -21,22 +21,22 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless2/train.py \ +./pruned2_knowledge/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned2_knowledge/exp \ --full-libri 1 \ --max-duration 300 # For mix precision training: -./pruned_transducer_stateless2/train.py \ +./pruned2_knowledge/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ --use_fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned2_knowledge/exp \ --full-libri 1 \ --max-duration 550 @@ -138,7 +138,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned2_knowledge/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved