From 35f5a15a542b3a90ccb9f85b981bf13e0c7658b5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 10 Mar 2022 10:13:49 +0800 Subject: [PATCH] Use giga speech dataset as extra training data. --- .../asr_datamodule.py | 2 +- .../gigaspeech.py | 1 + .../librispeech.py | 1 + .../model.py | 30 ++- .../train.py | 233 +++++++++++++++--- .../train.py | 4 +- icefall/env.py | 1 + 7 files changed, 233 insertions(+), 39 deletions(-) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/gigaspeech.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/librispeech.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/asr_datamodule.py index 07f39b451..2339c5bc2 120000 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/asr_datamodule.py @@ -1 +1 @@ -../transducer/asr_datamodule.py \ No newline at end of file +../transducer_stateless_multi_datasets/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/gigaspeech.py new file mode 120000 index 000000000..c17aa814c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/gigaspeech.py @@ -0,0 +1 @@ +../transducer_stateless_multi_datasets/gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/librispeech.py new file mode 120000 index 000000000..3cbe3fd15 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/librispeech.py @@ -0,0 +1 @@ +../transducer_stateless_multi_datasets/librispeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py index ef0a9648c..1e716e2ab 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py @@ -15,6 +15,8 @@ # limitations under the License. +from typing import Optional + import k2 import torch import torch.nn as nn @@ -33,6 +35,8 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + decoder_giga: Optional[nn.Module] = None, + joiner_giga: Optional[nn.Module] = None, ): """ Args: @@ -49,20 +53,32 @@ class Transducer(nn.Module): It has two inputs with shapes: (N, T, U, C) and (N, T, U, C). Its output shape is also (N, T, U, C). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. + decoder_giga: + The decoder for the GigaSpeech dataset. + joiner_giga: + The joiner for the GigaSpeech dataset. 
""" super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) assert hasattr(decoder, "blank_id") + if decoder_giga is not None: + assert hasattr(decoder_giga, "blank_id") + self.encoder = encoder + self.decoder = decoder self.joiner = joiner + self.decoder_giga = decoder_giga + self.joiner_giga = joiner_giga + def forward( self, x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + libri: bool = True, prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, @@ -77,6 +93,9 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + libri: + True to use the decoder and joiner for the LibriSpeech dataset. + False to use the decoder and joiner for the GigaSpeech dataset. prune_range: The prune range for rnnt loss, it means how many symbols(context) we are considering for each frame to compute the loss. @@ -114,8 +133,15 @@ class Transducer(nn.Module): # sos_y_padded: [B, S + 1], start with SOS. sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + if libri: + decoder = self.decoder + joiner = self.joiner + else: + decoder = self.decoder_giga + joiner = self.joiner_giga + # decoder_out: [B, S + 1, C] - decoder_out = self.decoder(sos_y_padded) + decoder_out = decoder(sos_y_padded) # Note: y does not start with SOS # y_padded : [B, S] @@ -155,7 +181,7 @@ class Transducer(nn.Module): ) # logits : [B, T, prune_range, C] - logits = self.joiner(am_pruned, lm_pruned) + logits = joiner(am_pruned, lm_pruned) pruned_loss = k2.rnnt_loss_pruned( logits=logits, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py index 9bfbabf48..8755ba8df 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py @@ -19,20 +19,44 @@ """ Usage: +cd egs/librispeech/ASR/ +./prepare.sh +./prepare_giga_speech.sh + +# 100-hours +export CUDA_VISIBLE_DEVICES="0,1" + +./pruned_transducer_stateless_multi_datasets/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --start-epoch 0 \ + --exp-dir pruned_transducer_stateless_multi_datasets/exp-1 \ + --full-libri 0 \ + --max-duration 300 \ + --prune-range 5 \ + --lr-factor 1.0 \ + --lm-scale 0.25 + + +# 960 hours export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless/train.py \ +./pruned_transducer_stateless_multi_datasets/train.py \ --world-size 4 \ - --num-epochs 30 \ + --num-epochs 60 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless/exp \ + --exp-dir pruned_transducer_stateless_multi_datasets/exp-full \ --full-libri 1 \ - --max-duration 300 + --max-duration 300 \ + --prune-range 5 \ + --lr-factor 5.0 \ + --lm-scale 0.25 """ import argparse import logging +import random from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -42,12 +66,15 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from conformer import Conformer from decoder import Decoder +from gigaspeech import GigaSpeech from joiner import Joiner +from lhotse import CutSet, load_manifest from lhotse.cut import Cut from lhotse.utils import fix_random_seed +from librispeech import LibriSpeech from model import Transducer from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP @@ 
-89,6 +116,14 @@ def get_parser(): help="Master port to use for DDP training.", ) + parser.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + parser.add_argument( "--tensorboard", type=str2bool, @@ -116,7 +151,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless_multi_datasets/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -179,6 +214,13 @@ def get_parser(): "with this parameter before adding to the final loss.", ) + parser.add_argument( + "--giga-prob", + type=float, + default=0.2, + help="The probability to select a batch from the GigaSpeech dataset", + ) + parser.add_argument( "--seed", type=int, @@ -253,8 +295,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - # parameters for decoder - "embedding_dim": 512, # parameters for Noam "warm_step": 80000, # For the 100h subset, use 30000 "env_info": get_env_info(), @@ -302,13 +342,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) + decoder = get_decoder_model(params) joiner = get_joiner_model(params) + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + model = Transducer( encoder=encoder, decoder=decoder, joiner=joiner, + decoder_giga=decoder_giga, + joiner_giga=joiner_giga, ) return model @@ -400,6 +446,17 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. + """ + return c.supervisions[0].custom is None + + def compute_loss( params: AttributeDict, model: nn.Module, @@ -432,6 +489,8 @@ def compute_loss( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) + libri = is_libri(supervisions["cut"][0]) + texts = batch["supervisions"]["text"] y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y).to(device) @@ -441,6 +500,7 @@ def compute_loss( x=feature, x_lens=feature_lens, y=y, + libri=libri, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, @@ -500,7 +560,9 @@ def train_one_epoch( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + rng: random.Random, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, ) -> None: @@ -519,8 +581,12 @@ def train_one_epoch( The optimizer we are using. train_dl: Dataloader for the training dataset. + giga_train_dl: + Dataloader for the GigaSpeech training dataset. valid_dl: Dataloader for the validation dataset. + rng: + For selecting which dataset to use. tb_writer: Writer to write log messages to tensorboard. 
world_size: @@ -528,6 +594,8 @@ def train_one_epoch( """ model.train() + libri_tot_loss = MetricsTracker() + giga_tot_loss = MetricsTracker() tot_loss = MetricsTracker() def maybe_log_gradients(tag: str): @@ -569,10 +637,32 @@ def train_one_epoch( else: optimizer.step() - for batch_idx, batch in enumerate(train_dl): + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [1 - params.giga_prob, params.giga_prob] + + iter_libri = iter(train_dl) + iter_giga = iter(giga_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + + try: + batch = next(dl) + except StopIteration: + break + + batch_idx += 1 + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) + libri = is_libri(batch["supervisions"]["cut"][0]) + loss, loss_info = compute_loss( params=params, model=model, @@ -582,6 +672,16 @@ def train_one_epoch( ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + if libri: + libri_tot_loss = ( + libri_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "libri" # for logging only + else: + giga_tot_loss = ( + giga_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "giga" # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -597,18 +697,29 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: logging.info( f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" + f"batch {batch_idx}, {prefix}_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"libri_tot_loss[{libri_tot_loss}], " + f"giga_tot_loss[{giga_tot_loss}], " + f"batch size: {batch_size}" ) if batch_idx % params.log_interval == 0: if tb_writer is not None: loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train + tb_writer, + f"train/current_{prefix}_", + params.batch_idx_train, ) tot_loss.write_summary( tb_writer, "train/tot_", params.batch_idx_train ) + libri_tot_loss.write_summary( + tb_writer, "train/libri_tot_", params.batch_idx_train + ) + giga_tot_loss.write_summary( + tb_writer, "train/giga_tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -633,6 +744,25 @@ def train_one_epoch( params.best_train_loss = params.train_loss +def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + num_in_total = len(cuts) + cuts = cuts.filter(remove_short_and_long_utt) + + num_left = len(cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info(f"After removing short and long utterances: {num_left}") + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + + return cuts + + def run(rank, world_size, args): """ Args: @@ -652,6 +782,7 @@ def run(rank, world_size, args): params.warm_step = 30000 fix_random_seed(params.seed) + rng = random.Random(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -688,7 +819,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: 
logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) model.device = device optimizer = Noam( @@ -702,46 +833,74 @@ def run(rank, world_size, args): logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) - librispeech = LibriSpeechAsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - return 1.0 <= c.duration <= 20.0 + train_cuts = filter_short_and_long_utterances(train_cuts) - num_in_total = len(train_cuts) + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + # XL 10k hours + # L 2.5k hours + # M 1k hours + # S 250 hours + # XS 10 hours + # DEV 12 hours + # Test 40 hours + if params.full_libri: + logging.info("Using the L subset of GigaSpeech (2.5k hours)") + train_giga_cuts = gigaspeech.train_L_cuts() + else: + logging.info("Using the S subset of GigaSpeech (250 hours)") + train_giga_cuts = gigaspeech.train_S_cuts() - train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 + if args.enable_musan: + cuts_musan = load_manifest( + Path(args.manifest_dir) / "cuts_musan.json.gz" + ) + else: + cuts_musan = None - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + asr_datamodule = AsrDataModule(args) - train_dl = librispeech.train_dataloaders(train_cuts) + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + dynamic_bucketing=False, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + giga_train_dl = asr_datamodule.train_dataloaders( + train_giga_cuts, + dynamic_bucketing=True, + on_the_fly_feats=True, + cuts_musan=cuts_musan, + ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + # It's time consuming to include `giga_train_dl` here + # for dl in [train_dl, giga_train_dl]: + for dl in [train_dl]: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=dl, + optimizer=optimizer, + sp=sp, + params=params, + ) for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) + giga_train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate if tb_writer is not None: @@ -761,7 +920,9 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, train_dl=train_dl, + giga_train_dl=giga_train_dl, valid_dl=valid_dl, + rng=rng, tb_writer=tb_writer, world_size=world_size, ) @@ -821,10 +982,12 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) + assert 0 <= args.giga_prob < 1, args.giga_prob + 
world_size = args.world_size assert world_size >= 1 if world_size > 1: diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 6a57a9cce..334dc9f42 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -535,10 +535,12 @@ def train_one_epoch( The optimizer we are using. train_dl: Dataloader for the training dataset. + giga_train_dl: + Dataloader for the GigaSpeech training dataset. valid_dl: Dataloader for the validation dataset. rng: - For select which dataset to use. + For selecting which dataset to use. tb_writer: Writer to write log messages to tensorboard. world_size: diff --git a/icefall/env.py b/icefall/env.py index 0684c4bf1..97c63ccde 100644 --- a/icefall/env.py +++ b/icefall/env.py @@ -97,6 +97,7 @@ def get_env_info() -> Dict[str, Any]: "lhotse-version": lhotse.__version__, "torch-cuda-available": torch.cuda.is_available(), "torch-cuda-version": torch.version.cuda, + "torch-version": torch.__version__, "python-version": sys.version[:3], "icefall-git-branch": get_git_branch_name(), "icefall-git-sha1": get_git_sha1(),
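Note (not part of the patch): the core technique this change introduces in train_one_epoch() is a weighted random interleaving of the LibriSpeech and GigaSpeech dataloaders, controlled by --giga-prob and stopped as soon as either iterator is exhausted. The following is a minimal, self-contained sketch of that mixing loop for readers who want to see it in isolation; the helper name mix_two_dataloaders and the plain lists standing in for the real lhotse dataloaders are assumptions for illustration, not code from the patch.

import random


def mix_two_dataloaders(train_dl, giga_train_dl, giga_prob: float, seed: int = 42):
    """Yield (idx, batch) pairs drawn randomly from two dataloaders.

    idx == 0 means the batch came from LibriSpeech, idx == 1 from GigaSpeech,
    mirroring the dl_weights convention used in the patch.
    """
    rng = random.Random(seed)
    dl_weights = [1 - giga_prob, giga_prob]

    iter_libri = iter(train_dl)
    iter_giga = iter(giga_train_dl)

    while True:
        # Pick which dataset to read from on this step.
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        it = iter_libri if idx == 0 else iter_giga
        try:
            yield idx, next(it)
        except StopIteration:
            # The epoch ends as soon as one of the two datasets runs out,
            # just like the while-loop added to train_one_epoch().
            break


if __name__ == "__main__":
    # Toy usage: plain lists stand in for the real dataloaders.
    libri_batches = [f"libri-{i}" for i in range(8)]
    giga_batches = [f"giga-{i}" for i in range(8)]
    for idx, batch in mix_two_dataloaders(
        libri_batches, giga_batches, giga_prob=0.2
    ):
        print("giga" if idx == 1 else "libri", batch)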