#!/usr/bin/env python3
# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Wei Kang,
#                                                       Mingshuang Luo,
#                                                       Zengwei Yao,
#                                                       Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage:

export CUDA_VISIBLE_DEVICES="0,1,2,3"

# For non-streaming model training:
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --max-duration 1000

# For streaming model training:
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --causal 1 \
  --max-duration 1000

It supports training with:
  - transducer loss (default), with `--use-transducer True --use-ctc False`
  - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
  - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
"""

import argparse
import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from train import (
    add_model_arguments,
    add_training_arguments,
    display_and_save_batch,
    get_adjusted_batch_count,
    get_model,
    get_params,
    load_checkpoint_if_available,
    save_checkpoint,
    scan_pessimistic_batches_for_oom,
    set_batch_count,
)

from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
    text_to_pinyin,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_partial_tone",
        help="Path to the pinyin lang directory",
    )

    parser.add_argument(
        "--pinyin-type",
        type=str,
        default="partial_with_tone",
        help="""
        The style of the output pinyin, should be:
          full_with_tone : zhong1 guo2
          full_no_tone : zhong guo
          partial_with_tone : zh ong1 g uo2
          partial_no_tone : zh ong g uo
        """,
    )
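
    # Illustrative only (assuming `text_to_pinyin` from icefall/utils.py
    # behaves as its help text above describes), for the text "中国":
    #   full_with_tone    -> ["zhong1", "guo2"]
    #   partial_with_tone -> ["zh", "ong1", "g", "uo2"]
    # i.e. the default style splits each syllable into initial and final,
    # keeping the tone digit on the final.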

    parser.add_argument(
        "--pinyin-errors",
        default="split",
        type=str,
        help="""How to handle characters that have no pinyin,
        see `text_to_pinyin` in icefall/utils.py for details
        """,
    )

    add_training_arguments(parser)
    add_model_arguments(parser)

    return parser


def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    graph_compiler: CharCtcTrainingGraphCompiler,
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the transducer loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training. It is an instance of Zipformer in our case.
      graph_compiler:
        It is used to convert transcripts to token IDs.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    batch_idx_train = params.batch_idx_train
    warm_step = params.warm_step

    texts = batch["supervisions"]["text"]
    y = graph_compiler.texts_to_ids(texts, sep="/")
    y = k2.RaggedTensor(y).to(device)

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss, _ = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
        )

        s = params.simple_loss_scale
        # take down the scale on the simple loss from 1.0 at the start
        # to params.simple_loss_scale by warm_step.
        simple_loss_scale = (
            s
            if batch_idx_train >= warm_step
            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        )
        pruned_loss_scale = (
            1.0
            if batch_idx_train >= warm_step
            else 0.1 + 0.9 * (batch_idx_train / warm_step)
        )

        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
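    # Sanity check on the bookkeeping below (illustrative numbers): with
    # reduction=sum, a batch of two cuts of 500 and 300 subsampled frames
    # contributes info["frames"] = 800, and the per-frame loss that is
    # eventually reported is tot_loss["loss"] / tot_loss["frames"].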
info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() return loss, info def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], graph_compiler: CharCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: """Run the validation process.""" model.eval() tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): loss, loss_info = compute_loss( params=params, model=model, graph_compiler=graph_compiler, batch=batch, is_training=False, ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info if world_size > 1: tot_loss.reduce(loss.device) loss_value = tot_loss["loss"] / tot_loss["frames"] if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch params.best_valid_loss = loss_value return tot_loss def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, graph_compiler: CharCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, ) -> None: """Train the model for one epoch. The training loss from the mean of all frames is saved in `params.train_loss`. It runs the validation process every `params.valid_interval` batches. Args: params: It is returned by :func:`get_params`. model: The model for training. optimizer: The optimizer we are using. scheduler: The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. scaler: The scaler used for mix precision training. model_avg: The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: Number of nodes in DDP training. If it is 1, DDP is disabled. rank: The rank of the node in DDP training. If no DDP is used, it should be set to 0. """ model.train() tot_loss = MetricsTracker() cur_batch_idx = params.get("cur_batch_idx", 0) saved_bad_model = False def save_bad_model(suffix: str = ""): save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, scaler=scaler, rank=0, ) for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) if batch_idx < cur_batch_idx: continue cur_batch_idx = batch_idx params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) try: with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, graph_compiler=graph_compiler, batch=batch, is_training=True, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. 
            scaler.scale(loss).backward()
            scheduler.step_batch(params.batch_idx_train)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        except:  # noqa
            save_bad_model()
            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                model_avg=model_avg,
                params=params,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )

        if batch_idx % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The
            # _growth_interval of the grad scaler is configurable, but we
            # can't configure it to have different behavior depending on
            # the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    save_bad_model()
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )
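
        # Rough numbers for the watchdog above (illustrative, assuming the
        # init_scale=1.0 used in run() below): a grad scale below 8.0 is
        # doubled every 100 batches, and one between 8.0 and 32.0 every 400
        # batches, so a healthy run whose scaler has backed off recovers
        # quickly. A scale stuck below 0.01 usually means the model is
        # producing inf/NaN gradients, hence the bad-model checkpoint and
        # the hard failure below 1e-5.
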
        if batch_idx % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )

                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )

        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is "
                f"{torch.cuda.max_memory_allocated() // 1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is passed
        automatically by `mp.spawn()` in :func:`main`. The node with rank 0
        is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    lexicon = Lexicon(params.lang_dir)
    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    if not params.use_transducer:
        params.ctc_loss_scale = 1.0

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr,  # should have no effect
        clipping_scale=2.0,
    )

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    wenetspeech = WenetSpeechAsrDataModule(args)

    train_cuts = wenetspeech.train_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 15 seconds.
        #
        # Caution: There is a reason to select 15.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold.
        if c.duration < 1.0 or c.duration > 15.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False

        return True
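
    # Illustrative example (with the default --pinyin-type=partial_with_tone
    # and assuming `text_to_pinyin` behaves as its help text above describes):
    # encode_text below maps a supervision text such as "中国" to
    # "zh/ong1/g/uo2"; compute_loss() later splits it again via
    # graph_compiler.texts_to_ids(texts, sep="/").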
    def encode_text(c: Cut):
        # Text normalization for each sample
        text = c.supervisions[0].text
        text = "/".join(
            text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
        )
        c.supervisions[0].text = text
        return c

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_cuts = train_cuts.map(encode_text)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = wenetspeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = wenetspeech.valid_cuts()
    valid_cuts = valid_cuts.map(encode_text)
    valid_dl = wenetspeech.valid_dataloaders(valid_cuts)

    if not params.print_diagnostics:
        # scan_pessimistic_batches_for_oom(
        #     model=model,
        #     train_dl=train_dl,
        #     optimizer=optimizer,
        #     graph_compiler=graph_compiler,
        #     params=params,
        # )
        pass

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        save_checkpoint(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    parser = get_parser()
    WenetSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.lang_dir = Path(args.lang_dir)
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


if __name__ == "__main__":
    main()