diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
new file mode 100644
index 000000000..fe0d0a872
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
@@ -0,0 +1,304 @@
+# Copyright 2021 Piotr Żelasko
+#           2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet, Fbank, FbankConfig
+from lhotse.dataset import (
+    BucketingSampler,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import (
+    OnTheFlyFeatures,
+    PrecomputedFeatures,
+)
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class AsrDataModule:
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the BucketingSampler "
+            "and DynamicBucketingSampler "
+            "(you might want to increase it for larger datasets).",
+        )
+
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for the training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it "
+            "with the training dataset.",
+        )
+
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available. Used only for the dev/test CutSets.",
+        )
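For reference, str2bool (imported above from icefall.utils, not part of this diff) is the usual argparse helper behind flags such as --shuffle True. A minimal sketch of such a helper, assuming the conventional truthy/falsy spellings:

    import argparse

    def str2bool(v):
        # Map common string spellings of booleans to bool for argparse.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")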
" + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + dynamic_bucketing: bool, + on_the_fly_feats: bool, + cuts_musan: Optional[CutSet] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + Cuts for training. + cuts_musan: + If not None, it is the cuts for mixing. + dynamic_bucketing: + True to use DynamicBucketingSampler; + False to use BucketingSampler. + on_the_fly_feats: + True to use OnTheFlyFeatures; + False to use PrecomputedFeatures. + """ + transforms = [] + if cuts_musan is not None: + logging.info("Enable MUSAN") + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") + + input_transforms = [] + + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+        transforms = []
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = BucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=self.args.return_cuts,
+        )
+        sampler = BucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
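The lhotse samplers emit fully formed batches, which is why every DataLoader above passes batch_size=None. A sketch of the fields one batch carries (key names follow K2SpeechRecognitionDataset and the usage in train.py below; shapes are illustrative):

    batch = next(iter(valid_dl))
    features = batch["inputs"]               # (N, T, 80) fbank features
    supervisions = batch["supervisions"]
    texts = supervisions["text"]             # list of N transcripts
    num_frames = supervisions["num_frames"]  # (N,) valid frame counts
    cuts = supervisions.get("cut")           # present when --return-cuts=True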
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py
new file mode 100644
index 000000000..286771d7d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py
@@ -0,0 +1,75 @@
+# Copyright 2021 Piotr Żelasko
+#           2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest
+
+
+class GigaSpeech:
+    def __init__(self, manifest_dir: str):
+        """
+        Args:
+          manifest_dir:
+            It is expected to contain the following files::
+
+                - cuts_XL_raw.jsonl.gz
+                - cuts_L_raw.jsonl.gz
+                - cuts_M_raw.jsonl.gz
+                - cuts_S_raw.jsonl.gz
+                - cuts_XS_raw.jsonl.gz
+                - cuts_DEV.jsonl.gz
+                - cuts_TEST.jsonl.gz
+        """
+        self.manifest_dir = Path(manifest_dir)
+
+    def train_XL_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_XL_raw.jsonl.gz"
+        logging.info(f"About to get train-XL cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def train_L_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_L_raw.jsonl.gz"
+        logging.info(f"About to get train-L cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def train_M_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_M_raw.jsonl.gz"
+        logging.info(f"About to get train-M cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def train_S_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_S_raw.jsonl.gz"
+        logging.info(f"About to get train-S cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def train_XS_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_XS_raw.jsonl.gz"
+        logging.info(f"About to get train-XS cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def test_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_TEST.jsonl.gz"
+        logging.info(f"About to get TEST cuts from {f}")
+        return load_manifest(f)
+
+    def dev_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_DEV.jsonl.gz"
+        logging.info(f"About to get DEV cuts from {f}")
+        return load_manifest(f)
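The training subsets are opened with CutSet.from_jsonl_lazy, which streams cuts from the gzipped JSONL on demand instead of materializing the whole manifest; for the 10k-hour XL subset this keeps memory bounded, while the small DEV/TEST sets are loaded eagerly. A small sketch of the difference (the manifest path assumes the layout documented above):

    from itertools import islice

    gigaspeech = GigaSpeech(manifest_dir="data/fbank")

    cuts = gigaspeech.train_XL_cuts()  # lazy: cuts are parsed on demand
    for cut in islice(cuts, 3):        # only the first 3 entries are read
        print(cut.id, cut.duration)

    dev_cuts = gigaspeech.dev_cuts()   # eager: load_manifest reads it all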
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py
new file mode 100644
index 000000000..00b7c8334
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py
@@ -0,0 +1,74 @@
+# Copyright 2021 Piotr Żelasko
+#           2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest
+
+
+class LibriSpeech:
+    def __init__(self, manifest_dir: str):
+        """
+        Args:
+          manifest_dir:
+            It is expected to contain the following files::
+
+                - cuts_dev-clean.json.gz
+                - cuts_dev-other.json.gz
+                - cuts_test-clean.json.gz
+                - cuts_test-other.json.gz
+                - cuts_train-clean-100.json.gz
+                - cuts_train-clean-360.json.gz
+                - cuts_train-other-500.json.gz
+        """
+        self.manifest_dir = Path(manifest_dir)
+
+    def train_clean_100_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_train-clean-100.json.gz"
+        logging.info(f"About to get train-clean-100 cuts from {f}")
+        return load_manifest(f)
+
+    def train_clean_360_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_train-clean-360.json.gz"
+        logging.info(f"About to get train-clean-360 cuts from {f}")
+        return load_manifest(f)
+
+    def train_other_500_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_train-other-500.json.gz"
+        logging.info(f"About to get train-other-500 cuts from {f}")
+        return load_manifest(f)
+
+    def test_clean_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_test-clean.json.gz"
+        logging.info(f"About to get test-clean cuts from {f}")
+        return load_manifest(f)
+
+    def test_other_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_test-other.json.gz"
+        logging.info(f"About to get test-other cuts from {f}")
+        return load_manifest(f)
+
+    def dev_clean_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_dev-clean.json.gz"
+        logging.info(f"About to get dev-clean cuts from {f}")
+        return load_manifest(f)
+
+    def dev_other_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_dev-other.json.gz"
+        logging.info(f"About to get dev-other cuts from {f}")
+        return load_manifest(f)
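These accessors mirror the GigaSpeech wrapper above, and subsets are combined with CutSet addition; this is how train.py below assembles the full 960h training set:

    librispeech = LibriSpeech(manifest_dir="data/fbank")

    train_cuts = librispeech.train_clean_100_cuts()
    train_cuts += librispeech.train_clean_360_cuts()  # added when --full-libri
    train_cuts += librispeech.train_other_500_cuts()  # added when --full-libri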
""" super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) assert hasattr(decoder, "blank_id") @@ -63,16 +81,26 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner + self.decoder_giga = decoder_giga + self.joiner_giga = joiner_giga + self.simple_am_proj = ScaledLinear( encoder_dim, vocab_size, initial_speed=0.5 ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + if decoder_giga is not None: + self.simple_am_proj_giga = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj_giga = ScaledLinear(decoder_dim, vocab_size) + def forward( self, x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + libri: bool = True, prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, @@ -88,6 +116,9 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + libri: + True to use the decoder and joiner for the LibriSpeech dataset. + False to use the decoder and joiner for the GigaSpeech dataset. prune_range: The prune range for rnnt loss, it means how many symbols(context) we are considering for each frame to compute the loss. @@ -115,21 +146,32 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(x_lens > 0) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(encoder_out_lens > 0) + + if libri: + decoder = self.decoder + simple_lm_proj = self.simple_lm_proj + simple_am_proj = self.simple_am_proj + joiner = self.joiner + else: + decoder = self.decoder_giga + simple_lm_proj = self.simple_lm_proj_giga + simple_am_proj = self.simple_am_proj_giga + joiner = self.joiner_giga # Now for the decoder, i.e., the prediction network row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] - blank_id = self.decoder.blank_id + blank_id = decoder.blank_id sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) + decoder_out = decoder(sos_y_padded) # Note: y does not start with SOS # y_padded : [B, S] @@ -140,10 +182,10 @@ class Transducer(nn.Module): (x.size(0), 4), dtype=torch.int64, device=x.device ) boundary[:, 2] = y_lens - boundary[:, 3] = x_lens + boundary[:, 3] = encoder_out_lens - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) + lm = simple_lm_proj(decoder_out) + am = simple_am_proj(encoder_out) with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( @@ -169,8 +211,8 @@ class Transducer(nn.Module): # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), + am=joiner.encoder_proj(encoder_out), + lm=joiner.decoder_proj(decoder_out), ranges=ranges, ) @@ -178,7 +220,7 @@ class Transducer(nn.Module): # project_input=False since we applied the decoder's input projections # prior to do_rnnt_pruning (this is an optimization for speed). 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index 80617847a..7e3155018 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -21,22 +21,26 @@ Usage:
 
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-./pruned_transducer_stateless2/train.py \
+cd egs/librispeech/ASR/
+./prepare.sh
+./prepare_giga_speech.sh
+
+./pruned_transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir pruned_transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 300
 
 # For mix precision training:
 
-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
   --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir pruned_transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 550
 
@@ -45,6 +49,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -56,13 +61,16 @@
 import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AsrDataModule
 from conformer import Conformer
 from decoder import Decoder
+from gigaspeech import GigaSpeech
 from joiner import Joiner
+from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from librispeech import LibriSpeech
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -109,6 +117,14 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--full-libri",
+        type=str2bool,
+        default=True,
+        help="When enabled, use 960h LibriSpeech. "
+        "Otherwise, use 100h subset.",
+    )
+
     parser.add_argument(
         "--num-epochs",
         type=int,
@@ -122,7 +138,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
        If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        transducer_stateless3/exp/epoch-{start_epoch-1}.pt
        """,
    )
 
@@ -138,7 +154,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless3/exp",
         help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -156,7 +172,8 @@ def get_parser():
     parser.add_argument(
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -170,7 +187,7 @@ def get_parser():
     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=6,
+        default=4,
         help="""Number of epochs that affects how rapidly the learning rate
        decreases.
        """,
     )
@@ -262,6 +279,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--giga-prob",
+        type=float,
+        default=0.5,
+        help="The probability to select a batch from the GigaSpeech dataset.",
+    )
+
     return parser
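The --lr-epochs default drops from 6 to 4, i.e. the epoch-dependent decay kicks in sooner. For orientation, the Eden scheduler from optim.py (not part of this diff) scales the learning rate roughly as below; treat this as a paraphrase for intuition, not the authoritative implementation, and the lr_batches default of 5000 as an assumption:

    def eden_lr(initial_lr, batch, epoch, lr_batches=5000.0, lr_epochs=4.0):
        # Paraphrase of icefall's Eden schedule, for intuition only.
        batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
        epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
        return initial_lr * batch_factor * epoch_factor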
""", ) @@ -262,6 +279,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--giga-prob", + type=float, + default=0.5, + help="The probability to select a batch from the GigaSpeech dataset", + ) + return parser @@ -377,10 +401,15 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder = get_decoder_model(params) joiner = get_joiner_model(params) + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + model = Transducer( encoder=encoder, decoder=decoder, joiner=joiner, + decoder_giga=decoder_giga, + joiner_giga=joiner_giga, encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, @@ -448,9 +477,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -500,6 +526,17 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. + """ + return c.supervisions[0].custom is None + + def compute_loss( params: AttributeDict, model: nn.Module, @@ -535,6 +572,8 @@ def compute_loss( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) + libri = is_libri(supervisions["cut"][0]) + texts = batch["supervisions"]["text"] y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y).to(device) @@ -544,6 +583,7 @@ def compute_loss( x=feature, x_lens=feature_lens, y=y, + libri=libri, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, @@ -621,7 +661,9 @@ def train_one_epoch( scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + rng: random.Random, scaler: GradScaler, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -644,8 +686,12 @@ def train_one_epoch( The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. + giga_train_dl: + Dataloader for the GigaSpeech training dataset. valid_dl: Dataloader for the validation dataset. + rng: + For selecting which dataset to use. scaler: The scaler used for mix precision training. 
@@ -621,7 +661,9 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
+    giga_train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    rng: random.Random,
     scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -644,8 +686,12 @@ def train_one_epoch(
         The learning rate scheduler, we call step() every step.
       train_dl:
         Dataloader for the training dataset.
+      giga_train_dl:
+        Dataloader for the GigaSpeech training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      rng:
+        For selecting which dataset to use.
       scaler:
         The scaler used for mix precision training.
       tb_writer:
@@ -658,18 +704,36 @@ def train_one_epoch(
     """
     model.train()
 
+    libri_tot_loss = MetricsTracker()
+    giga_tot_loss = MetricsTracker()
     tot_loss = MetricsTracker()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
+    # index 0: for LibriSpeech
+    # index 1: for GigaSpeech
+    # This sets the probabilities for choosing which dataset to use.
+    dl_weights = [1 - params.giga_prob, params.giga_prob]
 
-    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
+    iter_libri = iter(train_dl)
+    iter_giga = iter(giga_train_dl)
+
+    batch_idx = 0
+
+    while True:
+        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+        dl = iter_libri if idx == 0 else iter_giga
+
+        try:
+            batch = next(dl)
+        except StopIteration:
+            break
+
+        batch_idx += 1
 
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
+        libri = is_libri(batch["supervisions"]["cut"][0])
+
         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
@@ -682,6 +746,17 @@ def train_one_epoch(
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
+        if libri:
+            libri_tot_loss = (
+                libri_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "libri"  # for logging only
+        else:
+            giga_tot_loss = (
+                giga_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "giga"
+
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
@@ -697,7 +772,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -709,7 +783,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -720,8 +793,11 @@ def train_one_epoch(
             cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
-                f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], "
+                f"libri_tot_loss[{libri_tot_loss}], "
+                f"giga_tot_loss[{giga_tot_loss}], "
+                f"batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}"
             )
@@ -731,11 +807,19 @@ def train_one_epoch(
                 )
 
                 loss_info.write_summary(
-                    tb_writer, "train/current_", params.batch_idx_train
+                    tb_writer,
+                    f"train/current_{prefix}_",
+                    params.batch_idx_train,
                 )
                 tot_loss.write_summary(
                     tb_writer, "train/tot_", params.batch_idx_train
                 )
+                libri_tot_loss.write_summary(
+                    tb_writer, "train/libri_tot_", params.batch_idx_train
+                )
+                giga_tot_loss.write_summary(
+                    tb_writer, "train/giga_tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
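The dataset mixing above, distilled: each step draws a batch from one corpus with probability giga_prob, and the epoch ends as soon as either iterator is exhausted. A standalone sketch with the dataloaders stubbed as plain lists:

    import random

    rng = random.Random(42)
    giga_prob = 0.5
    iter_libri = iter(["libri-1", "libri-2", "libri-3"])
    iter_giga = iter(["giga-1", "giga-2", "giga-3"])

    while True:
        idx = rng.choices((0, 1), weights=[1 - giga_prob, giga_prob], k=1)[0]
        it = iter_libri if idx == 0 else iter_giga
        try:
            batch = next(it)
        except StopIteration:
            break  # one corpus is exhausted; end the epoch
        print(batch)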
@@ -760,6 +844,23 @@ def train_one_epoch(
     params.best_train_loss = params.train_loss
 
 
+def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold.
+        return 1.0 <= c.duration <= 20.0
+
+    cuts = cuts.filter(remove_short_and_long_utt)
+
+    return cuts
+
+
 def run(rank, world_size, args):
     """
     Args:
@@ -778,6 +879,7 @@ def run(rank, world_size, args):
     params.valid_interval = 1600
 
     fix_random_seed(params.seed)
+    rng = random.Random(params.seed)
     if world_size > 1:
         setup_dist(rank, world_size, params.master_port)
 
@@ -814,7 +916,7 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     model.device = device
 
     optimizer = Eve(model.parameters(), lr=params.initial_lr)
@@ -839,45 +941,65 @@ def run(rank, world_size, args):
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
         train_cuts += librispeech.train_other_500_cuts()
 
-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        return 1.0 <= c.duration <= 20.0
+    train_cuts = filter_short_and_long_utterances(train_cuts)
 
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
-
-    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
-        # We only load the sampler's state dict when it loads a checkpoint
-        # saved in the middle of an epoch
-        sampler_state_dict = checkpoints["sampler"]
+    gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
+    # XL   10k hours
+    # L    2.5k hours
+    # M    1k hours
+    # S    250 hours
+    # XS   10 hours
+    # DEV  12 hours
+    # Test 40 hours
+    if params.full_libri:
+        logging.info("Using the XL subset of GigaSpeech (10k hours)")
+        train_giga_cuts = gigaspeech.train_XL_cuts()
     else:
-        sampler_state_dict = None
+        logging.info("Using the S subset of GigaSpeech (250 hours)")
+        train_giga_cuts = gigaspeech.train_S_cuts()
 
-    train_dl = librispeech.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
+    train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
+
+    if args.enable_musan:
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "cuts_musan.json.gz"
+        )
+    else:
+        cuts_musan = None
+
+    asr_datamodule = AsrDataModule(args)
+
+    train_dl = asr_datamodule.train_dataloaders(
+        train_cuts,
+        dynamic_bucketing=False,
+        on_the_fly_feats=False,
+        cuts_musan=cuts_musan,
+    )
+
+    giga_train_dl = asr_datamodule.train_dataloaders(
+        train_giga_cuts,
+        dynamic_bucketing=True,
+        on_the_fly_feats=True,
+        cuts_musan=cuts_musan,
     )
 
     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    # It's time consuming to include `giga_train_dl` here
+    #  for dl in [train_dl, giga_train_dl]:
+    for dl in [train_dl]:
         scan_pessimistic_batches_for_oom(
             model=model,
-            train_dl=train_dl,
+            train_dl=dl,
             optimizer=optimizer,
             sp=sp,
             params=params,
@@ -905,7 +1027,9 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
+            giga_train_dl=giga_train_dl,
             valid_dl=valid_dl,
+            rng=rng,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -978,10 +1102,12 @@ def scan_pessimistic_batches_for_oom(
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    AsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
+    assert 0 <= args.giga_prob < 1, args.giga_prob
+
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1: