#!/usr/bin/env python3
# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                  Wei Kang
#                                                  Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:

export CUDA_VISIBLE_DEVICES="0,1"

./pruned_transducer_stateless2/train.py \
  --world-size 2 \
  --num-epochs 30 \
  --start-epoch 0 \
  --exp-dir pruned_transducer_stateless2/exp \
  --lang-dir data/lang_char \
  --max-duration 250 \
  --save-every-n 1000

# For mixed precision training:

./pruned_transducer_stateless2/train.py \
  --world-size 2 \
  --num-epochs 30 \
  --start-epoch 0 \
  --exp-dir pruned_transducer_stateless2/exp \
  --lang-dir data/lang_char \
  --max-duration 250 \
  --save-every-n 1000 \
  --use-fp16 True
"""

import argparse
import logging
import os
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import Aidatatang_200zhAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    setup_logger,
    str2bool,
    torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches
# synchronous. It is normally a debugging aid and tends to slow down
# training -- confirm it is intended to stay enabled.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def get_parser():
    """Build the command-line parser for this training script.

    Dataset-specific options are added separately by
    `Aidatatang_200zhAsrDataModule.add_arguments` in :func:`main`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12359,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=30,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=0,
        help="""Resume training from this epoch.
        If it is positive, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--start-batch",
        type=int,
        default=0,
        help="""If positive, --start-epoch is ignored and
        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="""The lang dir
        It contains language related input files such as
        "lexicon.txt"
        """,
    )

    parser.add_argument(
        "--initial-lr",
        type=float,
        default=0.003,
        help="The initial learning rate.  This value should not need to be changed.",
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=5000,
        help="""Number of steps that affects how rapidly the learning rate decreases.
        We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=6,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    # NOTE: adjacent string literals are concatenated without any separator,
    # so each fragment below ends with an explicit space.
    parser.add_argument(
        "--prune-range",
        type=int,
        default=5,
        help="The prune range for rnnt loss, it means how many symbols(context) "
        "we are using to compute the loss",
    )

    parser.add_argument(
        "--lm-scale",
        type=float,
        default=0.25,
        help="The scale to smooth the loss with lm "
        "(output of prediction network) part.",
    )

    parser.add_argument(
        "--am-scale",
        type=float,
        default=0.0,
        help="The scale to smooth the loss with am (output of encoder network) part.",
    )

    parser.add_argument(
        "--simple-loss-scale",
        type=float,
        default=0.5,
        help="To get pruning ranges, we will calculate a simple version "
        "loss(joiner is just addition), this simple loss also uses for "
        "training (as a regularization item). We will scale the simple loss "
        "with this parameter before adding to the final loss.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    # argparse %-formats help strings, so a literal percent sign must be
    # written as "%%".
    parser.add_argument(
        "--save-every-n",
        type=int,
        default=8000,
        help="""Save checkpoint after processing this number of batches
        periodically. We save checkpoint to exp-dir/ whenever
        params.batch_idx_train %% save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        """,
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=20,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed
    from the commandline are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during the training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during the training.

        - best_train_epoch: It is the epoch that has the best training loss.

        - best_valid_epoch: It is the epoch that has the best validation loss.

        - batch_idx_train: Used for writing statistics to tensorboard. It
                           contains number of batches trained so far across
                           epochs. It starts at 0 and is restored from the
                           checkpoint when training is resumed.

        - log_interval:  Print training loss if batch_idx % log_interval is 0

        - reset_interval: Controls the decay of the running training
                          statistics: before each batch is added, the
                          accumulated totals are multiplied by
                          (1 - 1 / reset_interval).

        - valid_interval:  Run validation if batch_idx % valid_interval is 0
                           (and batch_idx > 0)

        - feature_dim: The model input dim. It has to match the one used
                       in computing features.

        - subsampling_factor:  The subsampling factor for the model.

        - encoder_dim: The model dimension (d_model) of the Conformer encoder.

        - nhead, dim_feedforward, num_encoder_layers: Passed to the
                       Conformer encoder.

        - decoder_dim: Output dim of the decoder (prediction network).

        - joiner_dim: Hidden dim of the joiner.

        - model_warm_step: Number of batches used to warm up the model.
                           `batch_idx_train / model_warm_step` is passed as
                           `warmup` to :func:`compute_loss`.
    """
    params = AttributeDict(
        {
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            # Number of batches trained so far; a fresh run starts at 0.
            "batch_idx_train": 0,
            # NOTE(review): logging every batch is verbose; consider a
            # larger value for long runs.
            "log_interval": 1,
            "reset_interval": 200,
            "valid_interval": 400,
            # parameters for conformer
            "feature_dim": 80,
            "subsampling_factor": 4,
            "encoder_dim": 512,
            "nhead": 8,
            "dim_feedforward": 2048,
            "num_encoder_layers": 12,
            # parameters for decoder
            "decoder_dim": 512,
            # parameters for joiner
            "joiner_dim": 512,
            # model warmup (passed to the model, not used for the lr)
            "model_warm_step": 200,
            "env_info": get_env_info(),
        }
    )

    return params


def get_encoder_model(params: AttributeDict) -> nn.Module:
    """Build the Conformer encoder from `params`."""
    # TODO: We can add an option to switch between Conformer and Transformer
    encoder = Conformer(
        num_features=params.feature_dim,
        subsampling_factor=params.subsampling_factor,
        d_model=params.encoder_dim,
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
    )
    return encoder


def get_decoder_model(params: AttributeDict) -> nn.Module:
    """Build the stateless decoder (prediction network) from `params`."""
    decoder = Decoder(
        vocab_size=params.vocab_size,
        decoder_dim=params.decoder_dim,
        blank_id=params.blank_id,
        context_size=params.context_size,
    )
    return decoder


def get_joiner_model(params: AttributeDict) -> nn.Module:
    """Build the joiner that combines encoder and decoder outputs."""
    joiner = Joiner(
        encoder_dim=params.encoder_dim,
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
    )
    return joiner


def get_transducer_model(params: AttributeDict) -> nn.Module:
    """Assemble encoder, decoder and joiner into a Transducer model."""
    encoder = get_encoder_model(params)
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

    model = Transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        encoder_dim=params.encoder_dim,
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
    )
    return model


def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_batch is positive, it will load the checkpoint from
    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
    params.start_epoch is positive, it will load the checkpoint from
    `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    `best_valid_loss` and `batch_idx_train` in `params`. When loading a
    batch checkpoint that records `cur_epoch`, `params.start_epoch` is set
    to that epoch.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info, or None if
      neither start_batch nor start_epoch is positive.
    Raises:
      AssertionError: if the selected checkpoint file does not exist.
    """
    if params.start_batch > 0:
        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
    elif params.start_epoch > 0:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    if params.start_batch > 0:
        if "cur_epoch" in saved_params:
            params["start_epoch"] = saved_params["cur_epoch"]

    return saved_params


def save_checkpoint(
    params: AttributeDict,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    sampler: Optional[CutSampler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    The checkpoint is written to `params.exp_dir/epoch-{params.cur_epoch}.pt`
    and additionally copied to `best-train-loss.pt` / `best-valid-loss.pt`
    when the current epoch is the best one so far. Only rank 0 writes.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
      optimizer:
        The optimizer used in the training.
      scheduler:
        The learning rate scheduler used in the training.
      sampler:
       The sampler for the training dataset.
      scaler:
        The scaler used for mix precision training.
      rank:
        The rank of this process in DDP training; non-zero ranks return
        immediately without writing anything.
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        sampler=sampler,
        scaler=scaler,
        rank=rank,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


def compute_loss(
    params: AttributeDict,
    model: nn.Module,
    graph_compiler: CharCtcTrainingGraphCompiler,
    batch: dict,
    is_training: bool,
    warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the pruned transducer loss (a weighted sum of the simple loss
    and the pruned RNN-T loss) given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training. It is an instance of Transducer in our case
        (possibly wrapped in DDP).
      graph_compiler:
        Used to convert the transcripts in `batch` to token IDs.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    Returns:
      A tuple (loss, info) where `loss` is the scalar loss tensor and
      `info` holds the frame count and the detached loss components.
    """
    device = model.device
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    texts = batch["supervisions"]["text"]
    y = graph_compiler.texts_to_ids(texts)
    if isinstance(y, list):
        y = k2.RaggedTensor(y).to(device)
    else:
        y = y.to(device)

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
            warmup=warmup,
        )
        # after the main warmup step, we keep pruned_loss_scale small
        # for the same amount of time (model_warm_step), to avoid
        # overwhelming the simple_loss and causing it to diverge,
        # in case it had not fully learned the alignment yet.
        # warmup == 1.0 exactly (the default, used by validation) gets the
        # full scale, hence the strict lower bound on the middle branch.
        if warmup < 1.0:
            pruned_loss_scale = 0.0
        elif 1.0 < warmup < 2.0:
            pruned_loss_scale = 0.1
        else:
            pruned_loss_scale = 1.0
        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
    info["simple_loss"] = simple_loss.detach().cpu().item()
    info["pruned_loss"] = pruned_loss.detach().cpu().item()

    return loss, info


def compute_validation_loss(
    params: AttributeDict,
    model: nn.Module,
    graph_compiler: CharCtcTrainingGraphCompiler,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process.

    Sums the per-batch metrics over `valid_dl` (all-reduced across ranks
    when `world_size > 1`) and updates `params.best_valid_loss` /
    `params.best_valid_epoch` if the per-frame loss improved.
    Note: the model is left in eval mode; callers switch it back.
    """
    model.eval()

    tot_loss = MetricsTracker()

    for batch in valid_dl:
        loss, loss_info = compute_loss(
            params=params,
            model=model,
            graph_compiler=graph_compiler,
            batch=batch,
            is_training=False,
        )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss


def train_one_epoch(
    params: AttributeDict,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    graph_compiler: CharCtcTrainingGraphCompiler,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler, we call step() every step.
      graph_compiler:
        Used to convert transcripts to token IDs.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mix precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])

        with torch_autocast(enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                model=model,
                graph_compiler=graph_compiler,
                batch=batch,
                is_training=True,
                warmup=(params.batch_idx_train / params.model_warm_step),
            )
        # summary stats
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.
        scaler.scale(loss).backward()
        scheduler.step_batch(params.batch_idx_train)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                params=params,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )

        if batch_idx % params.log_interval == 0:
            cur_lr = scheduler.get_last_lr()[0]
            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}"
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )

                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    lexicon = Lexicon(params.lang_dir)
    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    # The blank symbol of the char-based token table is "<blk>"; an empty
    # string is not a token.
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank])
    # compute_loss reads `model.device`, so set it on the (possibly
    # DDP-wrapped) module.
    model.device = device

    optimizer = Eve(model.parameters(), lr=params.initial_lr)

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
    train_cuts = aidatatang_200zh.train_cuts()
    valid_cuts = aidatatang_200zh.valid_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 10.0 seconds
        #
        # Caution: There is a reason to select 10.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        return 1.0 <= c.duration <= 10.0

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    valid_dl = aidatatang_200zh.valid_dataloaders(valid_cuts)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = aidatatang_200zh.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    if not params.print_diagnostics and params.start_batch == 0:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            graph_compiler=graph_compiler,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs):
        scheduler.step_epoch(epoch)
        fix_random_seed(params.seed + epoch)
        train_dl.sampler.set_epoch(epoch)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        save_checkpoint(
            params=params,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def scan_pessimistic_batches_for_oom(
    model: nn.Module,
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    graph_compiler: CharCtcTrainingGraphCompiler,
    params: AttributeDict,
):
    """Run forward/backward on the largest batches of epoch 0 to detect OOM
    early; on a CUDA OOM, log the failing criterion and re-raise."""
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
    )
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
            # (i.e. are not remembered by the decaying-average in adam), because
            # we want to avoid these params being subject to shrinkage in adam.
            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,
                    graph_compiler=graph_compiler,
                    batch=batch,
                    is_training=True,
                    warmup=0.0,
                )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            raise


def main():
    """Parse arguments and launch training, spawning one process per GPU
    when --world-size > 1."""
    parser = get_parser()
    Aidatatang_200zhAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.lang_dir = Path(args.lang_dir)
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()