diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 16daf2f1b..fe0d0a872 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -133,6 +133,15 @@ class AsrDataModule: help="Path to directory with train/valid/test cuts.", ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet", + ) + def train_dataloaders( self, cuts_train: CutSet, @@ -240,3 +249,56 @@ class AsrDataModule: persistent_workers=False, ) return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py index 8281e1fb5..919c19a86 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py @@ -34,6 +34,8 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + decoder_giga: nn.Module, + joiner_giga: nn.Module, ): """ Args: @@ -50,20 +52,30 @@ class Transducer(nn.Module): It has two inputs with shapes: (N, T, C) and (N, U, C). Its output shape is (N, T, U, C). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. + decoder_giga: + The decoder for the GigaSpeech dataset. + joiner_giga: + The joiner for the GigaSpeech dataset. """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) assert hasattr(decoder, "blank_id") + assert hasattr(decoder_giga, "blank_id") self.encoder = encoder + self.decoder = decoder self.joiner = joiner + self.decoder_giga = decoder_giga + self.joiner_giga = joiner_giga + def forward( self, x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + libri: bool = True, modified_transducer_prob: float = 0.0, ) -> torch.Tensor: """ @@ -76,6 +88,9 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + libri: + True to use the decoder and joiner for the LibriSpeech dataset. + False to use the decoder and joiner for the GigaSpeech dataset. modified_transducer_prob: The probability to use modified transducer loss. Returns: @@ -100,10 +115,17 @@ class Transducer(nn.Module): sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y_padded.to(torch.int64) - decoder_out = self.decoder(sos_y_padded) + if libri: + decoder = self.decoder + joiner = self.joiner + else: + decoder = self.decoder_giga + joiner = self.joiner_giga + + decoder_out = decoder(sos_y_padded) # +1 here since a blank is prepended to each utterance. - logits = self.joiner( + logits = joiner( encoder_out=encoder_out, decoder_out=decoder_out, encoder_out_len=x_lens, diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 544f6e9b1..8db8bc920 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -21,11 +21,11 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./transducer_stateless/train.py \ +./transducer_stateless_multi_datasets/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir transducer_stateless/exp \ + --exp-dir transducer_stateless_multi_datasets/exp \ --full-libri 1 \ --max-duration 250 \ --lr-factor 2.5 @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import random from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -43,12 +44,15 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from conformer import Conformer from decoder import Decoder +from gigaspeech import GigaSpeech from joiner import Joiner +from lhotse import CutSet, load_manifest from lhotse.cut import Cut from lhotse.utils import fix_random_seed +from librispeech import LibriSpeech from model import Transducer from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP @@ -82,6 +86,14 @@ def get_parser(): help="Master port to use for DDP training.", ) + parser.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + parser.add_argument( "--tensorboard", type=str2bool, @@ -109,7 +121,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp", + default="transducer_stateless_multi_datasets/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -259,13 +271,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) + decoder = get_decoder_model(params) joiner = get_joiner_model(params) + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + model = Transducer( encoder=encoder, decoder=decoder, joiner=joiner, + decoder_giga=decoder_giga, + joiner_giga=joiner_giga, ) return model @@ -357,6 +375,17 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. + """ + return c.supervisions[0].custom is None + + def compute_loss( params: AttributeDict, model: nn.Module, @@ -389,6 +418,8 @@ def compute_loss( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) + libri = is_libri(supervisions["cut"][0]) + texts = batch["supervisions"]["text"] y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y).to(device) @@ -398,6 +429,7 @@ def compute_loss( x=feature, x_lens=feature_lens, y=y, + libri=libri, modified_transducer_prob=params.modified_transducer_prob, ) @@ -452,7 +484,9 @@ def train_one_epoch( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + rng: random.Random, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, ) -> None: @@ -473,6 +507,8 @@ def train_one_epoch( Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. + rng: + For select which dataset to use. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -482,7 +518,27 @@ def train_one_epoch( tot_loss = MetricsTracker() - for batch_idx, batch in enumerate(train_dl): + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [0.8, 0.2] + + iter_libri = iter(train_dl) + iter_giga = iter(giga_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + + try: + batch = next(dl) + except StopIteration: + break + + batch_idx += 1 + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -544,6 +600,25 @@ def train_one_epoch( params.best_train_loss = params.train_loss +def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + num_in_total = len(cuts) + cuts = cuts.filter(remove_short_and_long_utt) + + num_left = len(cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info(f"After removing short and long utterances: {num_left}") + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + + return cuts + + def run(rank, world_size, args): """ Args: @@ -562,7 +637,9 @@ def run(rank, world_size, args): params.valid_interval = 800 params.warm_step = 8000 - fix_random_seed(42) + seed = 42 + fix_random_seed(seed) + rng = random.Random(seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -599,7 +676,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) model.device = device optimizer = Noam( @@ -613,45 +690,66 @@ def run(rank, world_size, args): logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) - librispeech = LibriSpeechAsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - return 1.0 <= c.duration <= 20.0 + train_cuts = filter_short_and_long_utterances(train_cuts) - num_in_total = len(train_cuts) + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + # XL 10k hours + # L 2.5k hours + # M 1k hours + # S 250 hours + # XS 10 hours + # DEV 12 hours + # Test 40 hours + # train_giga_cuts = gigaspeech.train_M_cuts() + train_giga_cuts = gigaspeech.train_S_cuts() + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) + if args.enable_musan: + cuts_musan = load_manifest( + Path(args.manifest_dir) / "cuts_musan.json.gz" + ) + else: + cuts_musan = None - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 + asr_datamodule = AsrDataModule(args) - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + dynamic_bucketing=False, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) - train_dl = librispeech.train_dataloaders(train_cuts) + giga_train_dl = asr_datamodule.train_dataloaders( + train_giga_cuts, + dynamic_bucketing=True, + on_the_fly_feats=True, + cuts_musan=cuts_musan, + ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + for dl in [train_dl, giga_train_dl]: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=dl, + optimizer=optimizer, + sp=sp, + params=params, + ) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) + giga_train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate if tb_writer is not None: @@ -671,7 +769,9 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, train_dl=train_dl, + giga_train_dl=giga_train_dl, valid_dl=valid_dl, + rng=rng, tb_writer=tb_writer, world_size=world_size, ) @@ -731,7 +831,7 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir)