support finetune zipformer

marcoyang 2024-01-16 10:19:51 +08:00
parent 057238c27e
commit fa96660ac9


@@ -3,7 +3,8 @@
 # Wei Kang,
 # Mingshuang Luo,
 # Zengwei Yao,
-# Daniel Povey)
+# Daniel Povey,
+# Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -24,7 +25,7 @@ Usage:
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 # For non-streaming model training:
-./zipformer/train.py \
+./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -34,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --max-duration 1000
 # For streaming model training:
-./zipformer/train.py \
+./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -57,7 +58,7 @@ import logging
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import k2
 import optim
@@ -68,7 +69,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, CutSet
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -123,6 +124,46 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
             module.name = name


+def add_finetune_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--do-finetune",
+        type=str2bool,
+        default=True,
+        help="If true, finetune from a pre-trained checkpoint",
+    )
+
+    parser.add_argument(
+        "--use-mux",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to adapt. If true, we will mix 5% of the new data
+        with 95% of the original data to fine-tune. This is useful
+        if you want to maintain the performance on the original domain.
+        """,
+    )
+
+    parser.add_argument(
+        "--init-modules",
+        type=str,
+        default=None,
+        help="""
+        Modules to be initialized. It matches all parameters starting with
+        a specific key. The keys are given comma-separated. If None,
+        all modules will be initialized. For example, if you only want to
+        initialize all parameters starting with "encoder", use "encoder";
+        if you want to initialize parameters starting with encoder or decoder,
+        use "encoder,decoder".
+        """,
+    )
+
+    parser.add_argument(
+        "--finetune-ckpt",
+        type=str,
+        default=None,
+        help="Path to the checkpoint (a .pt file) to fine-tune from",
+    )
+
+
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
@@ -469,6 +510,7 @@ def get_parser():
     )

     add_model_arguments(parser)
+    add_finetune_arguments(parser)

     return parser

@@ -700,6 +742,54 @@ def load_checkpoint_if_available(
     return saved_params


+def load_model_params(
+    ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+    """Load model params from checkpoint
+
+    Args:
+        ckpt (str): Path to the checkpoint
+        model (nn.Module): model to be loaded
+        init_modules (list[str]): List of modules to be initialized
+    """
+    logging.info(f"Loading checkpoint from {ckpt}")
+    checkpoint = torch.load(ckpt, map_location="cpu")
+
+    # if module list is empty, load the whole model from ckpt
+    if not init_modules:
+        if next(iter(checkpoint["model"])).startswith("module."):
+            logging.info("Loading checkpoint saved by DDP")
+
+            dst_state_dict = model.state_dict()
+            src_state_dict = checkpoint["model"]
+            for key in dst_state_dict.keys():
+                src_key = "{}.{}".format("module", key)
+                dst_state_dict[key] = src_state_dict.pop(src_key)
+            assert len(src_state_dict) == 0
+            model.load_state_dict(dst_state_dict, strict=strict)
+        else:
+            model.load_state_dict(checkpoint["model"], strict=strict)
+    else:
+        src_state_dict = checkpoint["model"]
+        dst_state_dict = model.state_dict()
+        for module in init_modules:
+            logging.info(f"Loading parameters starting with prefix {module}")
+            src_keys = [
+                k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+            ]
+            dst_keys = [
+                k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+            ]
+            assert set(src_keys) == set(dst_keys)  # two sets should match exactly
+            for key in src_keys:
+                dst_state_dict[key] = src_state_dict.pop(key)
+
+        model.load_state_dict(dst_state_dict, strict=strict)
+
+    return None
+
+
 def save_checkpoint(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
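Note: the prefix matching above is easiest to see on a toy module. Below is a minimal, self-contained sketch (toy models, not the real AsrModel) of what happens when init_modules=["encoder"]: only state-dict keys starting with "encoder." are copied from the "checkpoint", everything else keeps its fresh initialization.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    # stand-in for the real transducer model: two named submodules
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 4)


pretrained = ToyModel()  # plays the role of the loaded checkpoint
model = ToyModel()       # freshly initialized model to be fine-tuned

init_modules = ["encoder"]
src_state_dict = pretrained.state_dict()
dst_state_dict = model.state_dict()
for module in init_modules:
    # copy only the keys that start with the requested prefix
    src_keys = [k for k in src_state_dict if k.startswith(module.strip() + ".")]
    for key in src_keys:
        dst_state_dict[key] = src_state_dict.pop(key)
model.load_state_dict(dst_state_dict)

print(torch.equal(model.encoder.weight, pretrained.encoder.weight))  # True
print(torch.equal(model.decoder.weight, pretrained.decoder.weight))  # False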
@@ -881,7 +971,8 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    valid_dl: torch.utils.data.DataLoader,
+    valid_dls: torch.utils.data.DataLoader,
+    valid_sets: List[str],
     scaler: GradScaler,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
@@ -1052,23 +1143,26 @@
                    )

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+            for valid_set, valid_dl in zip(valid_sets, valid_dls):
+                logging.info(f"Computing validation loss on {valid_set}")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                model.train()
+                logging.info(
+                    f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
+                )
+                logging.info(
+                    f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+                )
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
+                    )

     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -1133,10 +1227,23 @@
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model).to(torch.float64)

-    assert params.start_epoch > 0, params.start_epoch
-    checkpoints = load_checkpoint_if_available(
-        params=params, model=model, model_avg=model_avg
-    )
+    # load model parameters for model fine-tuning
+    if params.do_finetune:
+        assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+        modules = params.init_modules.split(",") if params.init_modules else None
+        checkpoints = load_model_params(
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules
+        )
+        # Need to update model_avg if the fine-tuning initialization is used
+        if rank == 0:
+            # model_avg is only used with rank 0
+            model_avg = copy.deepcopy(model).to(torch.float64)
+    else:
+        # resuming training
+        assert params.start_epoch > 1, params.start_epoch
+        checkpoints = load_checkpoint_if_available(
+            params=params, model=model, model_avg=model_avg
+        )

     model.to(device)
     if world_size > 1:
@@ -1174,10 +1281,18 @@

     librispeech = LibriSpeechAsrDataModule(args)

-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
+    if params.use_mux:
+        librispeech_cuts = librispeech.train_all_shuf_cuts()
+        train_cuts = CutSet.mux(
+            gigaspeech_cuts,  # num cuts = 688182
+            librispeech_cuts,  # num cuts = 843723
+            weights=[688182, 843723],
+            stop_early=True,
+        )
+    else:
+        train_cuts = gigaspeech_cuts
+    logging.info(train_cuts)

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
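Note on the weights above: lhotse's CutSet.mux samples lazily from its input cut sets with probabilities proportional to weights, and stop_early=True ends the epoch as soon as the first source is exhausted. A quick sketch (not part of the commit) of what the hard-coded weights normalize to, taking the cut counts from the comments in the hunk:

# Relative sampling probabilities implied by weights=[688182, 843723].
counts = {"gigaspeech_subset_small": 688_182, "librispeech_train_all_shuf": 843_723}
total = sum(counts.values())
for name, count in counts.items():
    print(f"{name}: {count / total:.1%}")
# gigaspeech_subset_small: 44.9%
# librispeech_train_all_shuf: 55.1%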
@@ -1231,7 +1346,13 @@

     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+
+    valid_sets = ["librispeech", "gigaspeech"]
+    valid_dls = [
+        librispeech.valid_dataloaders(valid_cuts),
+        librispeech.valid_dataloaders(gigaspeech_dev_cuts),
+    ]

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1265,7 +1386,8 @@
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            valid_dl=valid_dl,
+            valid_dls=valid_dls,
+            valid_sets=valid_sets,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,