#!/usr/bin/env python3 # Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, # Mingshuang Luo, # Zengwei Yao, # Yifan Yang, # Daniel Povey) # # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" # For non-streaming model training: ./zipformer/finetune.py \ --world-size 8 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ --max-duration 1000 # For streaming model training: ./zipformer/fintune.py \ --world-size 8 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ --causal 1 \ --max-duration 1000 It supports training with: - transducer loss (default), with `--use-transducer True --use-ctc False` - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` """ import argparse import copy import logging import warnings from pathlib import Path from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut, CutSet from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( add_model_arguments, add_training_arguments, compute_loss, compute_validation_loss, display_and_save_batch, get_adjusted_batch_count, get_model, get_params, load_checkpoint_if_available, save_checkpoint, scan_pessimistic_batches_for_oom, set_batch_count, ) from icefall import diagnostics from icefall.checkpoint import remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, setup_logger, str2bool, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_finetune_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--use-mux", type=str2bool, default=False, help=""" Whether to adapt. If true, we will mix 5% of the new data with 95% of the original data to fine-tune. """, ) parser.add_argument( "--init-modules", type=str, default=None, help=""" Modules to be initialized. It matches all parameters starting with a specific key. The keys are given with Comma seperated. If None, all modules will be initialised. For example, if you only want to initialise all parameters staring with "encoder", use "encoder"; if you want to initialise parameters starting with encoder or decoder, use "encoder,joiner". """, ) parser.add_argument( "--finetune-ckpt", type=str, default=None, help="Fine-tuning from which checkpoint (a path to a .pt file)", ) parser.add_argument( "--continue-finetune", type=str2bool, default=False, help="Continue finetuning or finetune from pre-trained model", ) def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( "--bpe-model", type=str, default="data/lang_bpe_500/bpe.model", help="Path to the BPE model", ) add_training_arguments(parser) add_model_arguments(parser) add_finetune_arguments(parser) return parser def load_model_params( ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True ): """Load model params from checkpoint Args: ckpt (str): Path to the checkpoint model (nn.Module): model to be loaded """ logging.info(f"Loading checkpoint from {ckpt}") checkpoint = torch.load(ckpt, map_location="cpu") # if module list is empty, load the whole model from ckpt if not init_modules: if next(iter(checkpoint["model"])).startswith("module."): logging.info("Loading checkpoint saved by DDP") dst_state_dict = model.state_dict() src_state_dict = checkpoint["model"] for key in dst_state_dict.keys(): src_key = "{}.{}".format("module", key) dst_state_dict[key] = src_state_dict.pop(src_key) assert len(src_state_dict) == 0 model.load_state_dict(dst_state_dict, strict=strict) else: model.load_state_dict(checkpoint["model"], strict=strict) else: src_state_dict = checkpoint["model"] dst_state_dict = model.state_dict() for module in init_modules: logging.info(f"Loading parameters starting with prefix {module}") src_keys = [ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") ] dst_keys = [ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") ] assert set(src_keys) == set(dst_keys) # two sets should match exactly for key in src_keys: dst_state_dict[key] = src_state_dict.pop(key) model.load_state_dict(dst_state_dict, strict=strict) return None def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, ) -> None: """Train the model for one epoch. The training loss from the mean of all frames is saved in `params.train_loss`. It runs the validation process every `params.valid_interval` batches. Args: params: It is returned by :func:`get_params`. model: The model for training. optimizer: The optimizer we are using. scheduler: The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. scaler: The scaler used for mix precision training. model_avg: The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: Number of nodes in DDP training. If it is 1, DDP is disabled. rank: The rank of the node in DDP training. If no DDP is used, it should be set to 0. """ model.train() tot_loss = MetricsTracker() saved_bad_model = False def save_bad_model(suffix: str = ""): save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, scaler=scaler, rank=0, ) for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params) + 100000) params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) try: with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, sp=sp, batch=batch, is_training=True, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() # if params.continue_finetune: # set_batch_count(model, params.batch_idx_train) # else: # set_batch_count(model, params.batch_idx_train + 100000) scheduler.step_batch(params.batch_idx_train) scaler.step(optimizer) scaler.update() optimizer.zero_grad() except: # noqa save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise if params.print_diagnostics and batch_idx == 5: return if ( rank == 0 and params.batch_idx_train > 0 and params.batch_idx_train % params.average_period == 0 ): update_averaged_model( params=params, model_cur=model, model_avg=model_avg, ) if ( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, scaler=scaler, rank=rank, ) remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank, ) if batch_idx % 100 == 0 and params.use_fp16: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: save_bad_model(suffix="-first-warning") saved_bad_model = True logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: tb_writer.add_scalar( "train/learning_rate", cur_lr, params.batch_idx_train ) loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, model=model, sp=sp, valid_dl=valid_dl, world_size=world_size, ) model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch params.best_train_loss = params.train_loss def run(rank, world_size, args): """ Args: rank: It is a value between 0 and `world_size-1`, which is passed automatically by `mp.spawn()` in :func:`main`. The node with rank 0 is responsible for saving checkpoint. world_size: Number of GPUs for DDP training. args: The return value of get_parser().parse_args() """ params = get_params() params.update(vars(args)) fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") if args.tensorboard and rank == 0: tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") else: tb_writer = None device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", rank) logging.info(f"Device: {device}") sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() if not params.use_transducer: params.ctc_loss_scale = 1.0 logging.info(params) logging.info("About to create model") model = get_model(params) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if params.continue_finetune: assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg ) else: modules = params.init_modules.split(",") if params.init_modules else None checkpoints = load_model_params( ckpt=params.finetune_ckpt, model=model, init_modules=modules ) if rank == 0: # model_avg is only used with rank 0 model_avg = copy.deepcopy(model).to(torch.float64) model.to(device) if world_size > 1: logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = ScaledAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, # should have no effect clipping_scale=2.0, ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) if ( checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None ): logging.info("Loading scheduler state dict") scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) if params.inf_check: register_inf_check_hooks(model) def remove_short_utt(c: Cut): # In ./zipformer.py, the conv module uses the following expression # for subsampling T = ((c.num_frames - 7) // 2 + 1) // 2 return T > 0 gigaspeech = GigaSpeechAsrDataModule(args) if params.use_mux: train_cuts = CutSet.mux( gigaspeech.train_cuts(), gigaspeech.fsc_train_cuts(), weights=[0.9, 0.1], ) else: train_cuts = gigaspeech.fsc_train_cuts() train_cuts = train_cuts.filter(remove_short_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint # saved in the middle of an epoch sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None train_dl = gigaspeech.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) valid_cuts = gigaspeech.fsc_valid_cuts() valid_cuts = valid_cuts.filter(remove_short_utt) valid_dl = gigaspeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics and params.scan_for_oom_batches: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, optimizer=optimizer, sp=sp, params=params, ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) params.cur_epoch = epoch train_one_epoch( params=params, model=model, model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sp=sp, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, tb_writer=tb_writer, world_size=world_size, rank=rank, ) if params.print_diagnostics: diagnostic.print_diagnostics() break save_checkpoint( params=params, model=model, model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, scaler=scaler, rank=rank, ) logging.info("Done!") if world_size > 1: torch.distributed.barrier() cleanup_dist() def main(): parser = get_parser() GigaSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) world_size = args.world_size assert world_size >= 1 if world_size > 1: mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) else: run(rank=0, world_size=1, args=args) if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) main()