#!/usr/bin/env python3
# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Wei Kang,
#                                                       Mingshuang Luo,
#                                                       Zengwei Yao,
#                                                       Yifan Yang,
#                                                       Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage:

export CUDA_VISIBLE_DEVICES="0,1,2,3"

# For non-streaming model finetuning:
./zipformer/finetune.py \
  --world-size 4 \
  --num-epochs 10 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --max-duration 1000

# For non-streaming model finetuning with mux (original dataset):
./zipformer/finetune.py \
  --world-size 4 \
  --num-epochs 10 \
  --start-epoch 1 \
  --use-mux 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --max-duration 1000

# For streaming model finetuning:
./zipformer/finetune.py \
  --world-size 4 \
  --num-epochs 10 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --causal 1 \
  --max-duration 1000

# For streaming model finetuning with mux (original dataset):
./zipformer/finetune.py \
  --world-size 4 \
  --num-epochs 10 \
  --start-epoch 1 \
  --use-mux 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --causal 1 \
  --max-duration 1000
"""

import argparse
import copy
import logging
import warnings
from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple, Union

import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut, CutSet
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from train import (
    add_model_arguments,
    add_training_arguments,
    compute_validation_loss,
    display_and_save_batch,
    encode_text,
    get_adjusted_batch_count,
    get_model,
    get_params,
    load_checkpoint_if_available,
    save_checkpoint,
    scan_pessimistic_batches_for_oom,
    set_batch_count,
)

from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    get_parameter_groups_with_lrs,
    num_tokens,
    setup_logger,
    str2bool,
    text_to_pinyin,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def add_finetune_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--use-mux",
        type=str2bool,
        default=False,
        help="""
        Whether to adapt. If true, we mix the new data with the original
        training data (with weights 0.1 and 0.9, see ``CutSet.mux`` in
        :func:`run`) so that the fine-tuned model still performs well on
        the original distribution.
        """,
    )

    parser.add_argument(
        "--init-modules",
        type=str,
        default=None,
        help="""
        Modules to be initialized. It matches all parameters starting with
        a specific key. The keys are given as a comma-separated list.
        If None, all modules will be initialised. For example, if you only
        want to initialise all parameters starting with "encoder", use
        "encoder"; if you want to initialise parameters starting with
        encoder or decoder, use "encoder,decoder".
        """,
    )

    parser.add_argument(
        "--finetune-ckpt",
        type=str,
        default=None,
        help="The checkpoint to fine-tune from (a path to a .pt file)",
    )

    parser.add_argument(
        "--continue-finetune",
        type=str2bool,
        default=False,
        help="""If true, continue fine-tuning from a previous fine-tuning
        checkpoint in --exp-dir; otherwise initialize from the pre-trained
        model given by --finetune-ckpt.""",
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_partial_tone",
        help="Path to the pinyin lang directory",
    )

    parser.add_argument(
        "--pinyin-type",
        type=str,
        default="partial_with_tone",
        help="""
        The style of the output pinyin, should be:
          full_with_tone : zhōng guó
          full_no_tone : zhong guo
          partial_with_tone : zh ōng g uó
          partial_no_tone : zh ong g uo
        """,
    )

    parser.add_argument(
        "--pinyin-errors",
        default="split",
        type=str,
        help="""How to handle characters that have no pinyin,
        see `text_to_pinyin` in icefall/utils.py for details
        """,
    )

    add_training_arguments(parser)
    add_model_arguments(parser)
    add_finetune_arguments(parser)

    return parser


def load_model_params(
    ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
):
    """Load model params from checkpoint

    Args:
        ckpt (str): Path to the checkpoint
        model (nn.Module): model to be loaded
        init_modules (List[str]): prefixes of the modules to load; if None or
            empty, the whole model is loaded
        strict (bool): passed to ``model.load_state_dict``
    """
    logging.info(f"Loading checkpoint from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")

    # if module list is empty, load the whole model from ckpt
    if not init_modules:
        if next(iter(checkpoint["model"])).startswith("module."):
            logging.info("Loading checkpoint saved by DDP")

            dst_state_dict = model.state_dict()
            src_state_dict = checkpoint["model"]
            for key in dst_state_dict.keys():
                src_key = "{}.{}".format("module", key)
                dst_state_dict[key] = src_state_dict.pop(src_key)
            assert len(src_state_dict) == 0
            model.load_state_dict(dst_state_dict, strict=strict)
        else:
            model.load_state_dict(checkpoint["model"], strict=strict)
    else:
        src_state_dict = checkpoint["model"]
        dst_state_dict = model.state_dict()
        for module in init_modules:
            logging.info(f"Loading parameters starting with prefix {module}")
            src_keys = [
                k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
            ]
            dst_keys = [
                k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
            ]
            assert set(src_keys) == set(dst_keys)  # two sets should match exactly
            for key in src_keys:
                dst_state_dict[key] = src_state_dict.pop(key)

        model.load_state_dict(dst_state_dict, strict=strict)

    return None


def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training. It is an instance of Zipformer in our case.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
""" device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) batch_idx_train = params.batch_idx_train warm_step = params.warm_step texts = batch["supervisions"]["text"] y = [c.supervisions[0].tokens for c in supervisions["cut"]] y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): losses = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, ) simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 if params.use_transducer: s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() if params.use_transducer: info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() return loss, info def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, ) -> None: """Train the model for one epoch. The training loss from the mean of all frames is saved in `params.train_loss`. It runs the validation process every `params.valid_interval` batches. Args: params: It is returned by :func:`get_params`. model: The model for training. optimizer: The optimizer we are using. scheduler: The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. scaler: The scaler used for mix precision training. model_avg: The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: Number of nodes in DDP training. If it is 1, DDP is disabled. rank: The rank of the node in DDP training. If no DDP is used, it should be set to 0. 
""" model.train() tot_loss = MetricsTracker() saved_bad_model = False def save_bad_model(suffix: str = ""): save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, scaler=scaler, rank=0, ) for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params) + 100000) params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) try: with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, batch=batch, is_training=True, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() scheduler.step_batch(params.batch_idx_train) scaler.step(optimizer) scaler.update() optimizer.zero_grad() except: # noqa save_bad_model() display_and_save_batch(batch, params=params) raise if params.print_diagnostics and batch_idx == 5: return if ( rank == 0 and params.batch_idx_train > 0 and params.batch_idx_train % params.average_period == 0 ): update_averaged_model( params=params, model_cur=model, model_avg=model_avg, ) if ( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, scaler=scaler, rank=rank, ) remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank, ) if batch_idx % 100 == 0 and params.use_fp16: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise_grad_scale_is_too_small_error(cur_grad_scale)

        if batch_idx % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )

        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is "
                f"{torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is passed
        automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")
    params.blank_id = token_table["<blk>"]
    params.vocab_size = num_tokens(token_table) + 1

    if not params.use_transducer:
        params.ctc_loss_scale = 1.0

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None

    if params.continue_finetune:
        assert params.start_epoch > 0, params.start_epoch
        checkpoints = load_checkpoint_if_available(
            params=params, model=model, model_avg=model_avg
        )
    else:
        modules = params.init_modules.split(",") if params.init_modules else None
        checkpoints = load_model_params(
            ckpt=params.finetune_ckpt, model=model, init_modules=modules
        )

    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr,  # should have no effect
        clipping_scale=2.0,
    )

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    def remove_short_utt(c: Cut):
        # besides too-short utterances, also drop cuts longer than 15 seconds
        if c.duration > 15:
            return False
        # In ./zipformer.py, the conv module uses the following expression
        # for subsampling
        T = ((c.num_frames - 7) // 2 + 1) // 2
        return T > 0

    wenetspeech = WenetSpeechAsrDataModule(args)

    if params.use_mux:
        train_cuts = CutSet.mux(
            wenetspeech.train_cuts(),
            wenetspeech.nihaowenwen_train_cuts(),
            weights=[0.9, 0.1],
        )
    else:
        train_cuts = wenetspeech.nihaowenwen_train_cuts()

    _encode_text = partial(encode_text, token_table=token_table, params=params)

    train_cuts = train_cuts.filter(remove_short_utt)
    train_cuts = train_cuts.map(_encode_text)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = wenetspeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = wenetspeech.nihaowenwen_dev_cuts()
    valid_cuts = valid_cuts.filter(remove_short_utt)
    valid_cuts = valid_cuts.map(_encode_text)
    valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
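
    # When params.scan_for_oom_batches is set, run the most memory-demanding
    # batches once up front so that out-of-memory errors surface immediately
    # rather than after hours of training.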
    if not params.print_diagnostics and params.scan_for_oom_batches:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        save_checkpoint(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    parser = get_parser()
    WenetSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    args.return_cuts = True

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()