support finetune zipformer

marcoyang 2024-01-16 10:19:51 +08:00
parent 057238c27e
commit fa96660ac9


@@ -3,7 +3,8 @@
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Daniel Povey)
# Daniel Povey,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -24,7 +25,7 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@@ -34,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--max-duration 1000
# For streaming model training:
./zipformer/train.py \
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@@ -57,7 +58,7 @@ import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import optim
@@ -68,7 +69,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
@@ -123,6 +124,46 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
module.name = name
def add_finetune_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--do-finetune",
type=str2bool,
default=True,
help="If true, finetune from a pre-trained checkpoint",
)
parser.add_argument(
"--use-mux",
type=str2bool,
default=False,
help="""
Whether to mux the fine-tuning data with the original training data,
weighted by the number of cuts in each source. This is useful
if you want to maintain the performance on the original domain
""",
)
parser.add_argument(
"--init-modules",
type=str,
default=None,
help="""
Modules to be initialised. It matches all parameters starting with
one of the given keys, supplied as a comma-separated list. If None,
all modules will be initialised. For example, to initialise only the
parameters starting with "encoder", use "encoder"; to initialise
parameters starting with "encoder" or "joiner", use "encoder,joiner".
""",
)
parser.add_argument(
"--finetune-ckpt",
type=str,
default=None,
help="Path to the checkpoint (a .pt file) to fine-tune from",
)
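For reference, a minimal sketch (not part of this commit) of how these flags are consumed downstream: the --init-modules value is split on commas into module-name prefixes, and leaving it unset means the whole checkpoint is loaded. The import of add_finetune_arguments is an assumption made only to keep the snippet standalone.

import argparse

from finetune import add_finetune_arguments  # assumes zipformer/ is on PYTHONPATH

parser = argparse.ArgumentParser()
add_finetune_arguments(parser)
args = parser.parse_args(
    ["--init-modules", "encoder,decoder", "--finetune-ckpt", "exp/pretrained.pt"]
)

# Mirrors the split performed in run() below: None means "load everything".
modules = args.init_modules.split(",") if args.init_modules else None
assert modules == ["encoder", "decoder"]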
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
@@ -469,6 +510,7 @@ def get_parser():
)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
@@ -700,6 +742,54 @@ def load_checkpoint_if_available(
return saved_params
def load_model_params(
ckpt: str, model: nn.Module, init_modules: Optional[List[str]] = None, strict: bool = True
):
"""Load model params from checkpoint
Args:
ckpt (str): Path to the checkpoint
model (nn.Module): model to be loaded
init_modules (list[str]): List of module name prefixes to be initialised
strict (bool): Whether load_state_dict should strictly enforce that the keys
of the checkpoint match the keys of the model
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
# if module list is empty, load the whole model from ckpt
if not init_modules:
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=strict)
else:
src_state_dict = checkpoint["model"]
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
model.load_state_dict(dst_state_dict, strict=strict)
return None
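A hedged usage sketch of load_model_params: initialising only the encoder and joiner of a freshly built model from a pre-trained checkpoint, as run() does below when --do-finetune is set. get_model and the checkpoint path are placeholders borrowed from the surrounding recipe, not guaranteed by this diff.

# Sketch only: get_model(params) is the recipe's model builder and the
# checkpoint path is a placeholder.
model = get_model(params)
load_model_params(
    ckpt="zipformer/exp/pretrained.pt",
    model=model,
    init_modules=["encoder", "joiner"],  # copy only parameters with these prefixes
    strict=True,
)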
def save_checkpoint(
params: AttributeDict,
model: Union[nn.Module, DDP],
@@ -881,7 +971,8 @@ def train_one_epoch(
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
valid_dls: List[torch.utils.data.DataLoader],
valid_sets: List[str],
scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
@@ -1052,23 +1143,26 @@
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
for valid_set, valid_dl in zip(valid_sets, valid_dls):
logging.info(f"Computing validation loss on {valid_set}")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}"
)
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
)
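As a small illustration (assuming the two validation set names configured later in run()), the loop above writes each validation summary under a per-set TensorBoard prefix:

# Sketch: prefixes produced by the loop above for the two validation sets;
# write_summary appends metric names, e.g. train/librispeech_valid_loss.
for valid_set in ["librispeech", "gigaspeech"]:
    print(f"train/{valid_set}_valid_")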
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@@ -1133,10 +1227,23 @@ def run(rank, world_size, args):
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
# load model parameters for model fine-tuning
if params.do_finetune:
assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
modules = params.init_modules.split(",") if params.init_modules else None
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
)
# Need to update model_avg when initialising from a checkpoint
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
else:
# resuming training
assert params.start_epoch > 1, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
model.to(device)
if world_size > 1:
@@ -1174,10 +1281,18 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
if params.use_mux:
librispeech_cuts = librispeech.train_all_shuf_cuts()
train_cuts = CutSet.mux(
gigaspeech_cuts, # num cuts = 688182
librispeech_cuts, # num cuts = 843723
weights=[688182, 843723],
stop_early=True,
)
else:
train_cuts = gigaspeech_cuts
logging.info(train_cuts)
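A standalone sketch, using only the cut counts quoted in the comments above, of what those mux weights imply: CutSet.mux samples from the two sources in proportion to the weights, so the combined stream is roughly 45% GigaSpeech and 55% LibriSpeech.

# Sketch: the mux weights above are the cut counts of each source, so the
# expected share of each corpus in the muxed stream is proportional to its size.
giga_cuts, libri_cuts = 688_182, 843_723
total = giga_cuts + libri_cuts
print(f"GigaSpeech share:  {giga_cuts / total:.1%}")   # ~44.9%
print(f"LibriSpeech share: {libri_cuts / total:.1%}")  # ~55.1%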
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@@ -1231,7 +1346,13 @@ def run(rank, world_size, args):
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
valid_sets = ["librispeech", "gigaspeech"]
valid_dls = [
librispeech.valid_dataloaders(valid_cuts),
librispeech.valid_dataloaders(gigaspeech_dev_cuts),
]
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
@@ -1265,7 +1386,8 @@ def run(rank, world_size, args):
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
valid_dls=valid_dls,
valid_sets=valid_sets,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,