From fa96660ac908da772579f5f745db41020aee4a76 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Tue, 16 Jan 2024 10:19:51 +0800
Subject: [PATCH] support finetune zipformer

---
 egs/librispeech/ASR/zipformer/finetune.py | 186 ++++++++++++++++++----
 1 file changed, 154 insertions(+), 32 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py
index 7009f3346..cf62b33c6 100755
--- a/egs/librispeech/ASR/zipformer/finetune.py
+++ b/egs/librispeech/ASR/zipformer/finetune.py
@@ -3,7 +3,8 @@
 #                                                       Wei Kang,
 #                                                       Mingshuang Luo,
 #                                                       Zengwei Yao,
-#                                                       Daniel Povey)
+#                                                       Daniel Povey,
+#                                                       Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -24,7 +25,7 @@ Usage:
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 # For non-streaming model training:
-./zipformer/train.py \
+./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -34,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --max-duration 1000
 
 # For streaming model training:
-./zipformer/train.py \
+./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -57,7 +58,7 @@ import logging
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import k2
 import optim
@@ -68,7 +69,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, CutSet
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -123,6 +124,46 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
             module.name = name
 
 
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--do-finetune",
+        type=str2bool,
+        default=True,
+        help="If true, fine-tune from a pre-trained checkpoint",
+    )
+    parser.add_argument(
+        "--use-mux",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to mix in the original training data. If true, the new data is
+        multiplexed with the original data, weighted by the number of cuts in
+        each set. This is useful to maintain performance on the original domain.
+        """,
+    )
+
+    parser.add_argument(
+        "--init-modules",
+        type=str,
+        default=None,
+        help="""
+        Modules to be initialized. It matches all parameters starting with
+        a specific key. The keys are given as a comma-separated list. If None,
+        all modules will be initialized. For example, if you only want to
+        initialize all parameters starting with "encoder", use "encoder";
+        if you want to initialize parameters starting with encoder or decoder,
+        use "encoder,decoder".
+        """,
+    )
+
+    parser.add_argument(
+        "--finetune-ckpt",
+        type=str,
+        default=None,
+        help="Path to the pre-trained checkpoint (a .pt file) to fine-tune from",
+    )
+
+
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
@@ -469,6 +510,7 @@ def get_parser():
     )
 
     add_model_arguments(parser)
+    add_finetune_arguments(parser)
 
     return parser
 
@@ -700,6 +742,54 @@ def load_checkpoint_if_available(
     return saved_params
 
 
+def load_model_params(
+    ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+    """Load model params from checkpoint
+
+    Args:
+        ckpt (str): Path to the checkpoint
+        model (nn.Module): model to be loaded
+        init_modules (list[str]): List of modules to be initialized
+
+    """
+    logging.info(f"Loading checkpoint from {ckpt}")
+    checkpoint = torch.load(ckpt, map_location="cpu")
+
+    # If the module list is empty, load the whole model from the checkpoint
+    if not init_modules:
+        if next(iter(checkpoint["model"])).startswith("module."):
+            logging.info("Loading checkpoint saved by DDP")
+
+            dst_state_dict = model.state_dict()
+            src_state_dict = checkpoint["model"]
+            for key in dst_state_dict.keys():
+                src_key = "{}.{}".format("module", key)
+                dst_state_dict[key] = src_state_dict.pop(src_key)
+            assert len(src_state_dict) == 0
+            model.load_state_dict(dst_state_dict, strict=strict)
+        else:
+            model.load_state_dict(checkpoint["model"], strict=strict)
+    else:
+        src_state_dict = checkpoint["model"]
+        dst_state_dict = model.state_dict()
+        for module in init_modules:
+            logging.info(f"Loading parameters starting with prefix {module}")
+            src_keys = [
+                k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+            ]
+            dst_keys = [
+                k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+            ]
+            assert set(src_keys) == set(dst_keys)  # two sets should match exactly
+            for key in src_keys:
+                dst_state_dict[key] = src_state_dict.pop(key)
+
+        model.load_state_dict(dst_state_dict, strict=strict)
+
+    return None
+
+
 def save_checkpoint(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -881,7 +971,8 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    valid_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+    valid_sets: List[str],
     scaler: GradScaler,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
@@ -1052,23 +1143,26 @@ def train_one_epoch(
                 )
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
+            for valid_set, valid_dl in zip(valid_sets, valid_dls):
+                logging.info(f"Computing validation loss on {valid_set}")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
                 )
+                model.train()
+                logging.info(
+                    f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}"
+                )
+                logging.info(
+                    f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+                )
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
+                    )
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -1133,10 +1227,23 @@ def run(rank, world_size, args):
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model).to(torch.float64)
 
-    assert params.start_epoch > 0, params.start_epoch
-    checkpoints = load_checkpoint_if_available(
-        params=params, model=model, model_avg=model_avg
-    )
+    # Load model parameters for fine-tuning
+    if params.do_finetune:
+        assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+        modules = params.init_modules.split(",") if params.init_modules else None
+        checkpoints = load_model_params(
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules
+        )
+        # Need to update model_avg when initializing from a checkpoint
+        if rank == 0:
+            # model_avg is only used with rank 0
+            model_avg = copy.deepcopy(model).to(torch.float64)
+    else:
+        # Resuming training
+        assert params.start_epoch > 1, params.start_epoch
+        checkpoints = load_checkpoint_if_available(
+            params=params, model=model, model_avg=model_avg
+        )
 
     model.to(device)
     if world_size > 1:
@@ -1174,10 +1281,18 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
+    if params.use_mux:
+        librispeech_cuts = librispeech.train_all_shuf_cuts()
+        train_cuts = CutSet.mux(
+            gigaspeech_cuts,  # num cuts = 688182
+            librispeech_cuts,  # num cuts = 843723
+            weights=[688182, 843723],
+            stop_early=True,
+        )
+    else:
+        train_cuts = gigaspeech_cuts
+    logging.info(train_cuts)
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1231,7 +1346,13 @@ def run(rank, world_size, args):
 
     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+
+    valid_sets = ["librispeech", "gigaspeech"]
+    valid_dls = [
+        librispeech.valid_dataloaders(valid_cuts),
+        librispeech.valid_dataloaders(gigaspeech_dev_cuts),
+    ]
 
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1265,7 +1386,8 @@ def run(rank, world_size, args):
             scheduler=scheduler,
            sp=sp,
             train_dl=train_dl,
-            valid_dl=valid_dl,
+            valid_dls=valid_dls,
+            valid_sets=valid_sets,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
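
A minimal sketch of the prefix-based partial loading that --init-modules / load_model_params performs. The toy model, the module names, and the checkpoint path below are hypothetical stand-ins for the real Zipformer sub-modules; this is an illustration, not icefall code.

# Illustration only: prefix-based partial loading in the spirit of
# load_model_params() in the patch. Toy model and checkpoint path are hypothetical.
from typing import List

import torch
import torch.nn as nn


class TinyTransducer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(80, 256)
        self.decoder = nn.Embedding(500, 256)
        self.joiner = nn.Linear(256, 500)


def load_by_prefix(ckpt_path: str, model: nn.Module, prefixes: List[str]) -> None:
    """Copy only parameters whose names start with one of the given prefixes."""
    src_state_dict = torch.load(ckpt_path, map_location="cpu")["model"]
    dst_state_dict = model.state_dict()
    for prefix in prefixes:
        keys = [k for k in dst_state_dict if k.startswith(prefix.strip() + ".")]
        for k in keys:
            # Overwrite only the selected sub-modules; everything else is untouched
            dst_state_dict[k] = src_state_dict[k]
    model.load_state_dict(dst_state_dict)


if __name__ == "__main__":
    pretrained = TinyTransducer()
    torch.save({"model": pretrained.state_dict()}, "tiny_ckpt.pt")  # hypothetical checkpoint

    model = TinyTransducer()
    # Initialize only the encoder and decoder from the checkpoint; the joiner
    # keeps its random initialization, mirroring --init-modules "encoder,decoder".
    load_by_prefix("tiny_ckpt.pt", model, prefixes=["encoder", "decoder"])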
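
A rough illustration of the CutSet.mux() weighting used with --use-mux: items are drawn in proportion to the weights, so with 688182 GigaSpeech cuts and 843723 LibriSpeech cuts roughly 45% of the mixed stream comes from the new data. The snippet below imitates that weighted draw with plain Python; it is not lhotse's implementation.

# Illustration only: weighted multiplexing of two corpora, mimicking the
# weights passed to CutSet.mux() in the patch. Not lhotse's implementation.
import random

num_cuts = {"gigaspeech": 688182, "librispeech": 843723}
total = sum(num_cuts.values())

random.seed(0)
draws = random.choices(list(num_cuts), weights=list(num_cuts.values()), k=100_000)
for corpus, count in num_cuts.items():
    print(
        f"{corpus}: sampled {draws.count(corpus) / len(draws):.1%} "
        f"(expected {count / total:.1%})"
    )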