mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
support finetune zipformer
This commit is contained in:
parent
057238c27e
commit
fa96660ac9
@ -3,7 +3,8 @@
|
||||
# Wei Kang,
|
||||
# Mingshuang Luo,
|
||||
# Zengwei Yao,
|
||||
# Daniel Povey)
|
||||
# Daniel Povey,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -24,7 +25,7 @@ Usage:
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
# For non-streaming model training:
|
||||
./zipformer/train.py \
|
||||
./zipformer/finetune.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
@ -34,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--max-duration 1000
|
||||
|
||||
# For streaming model training:
|
||||
./zipformer/train.py \
|
||||
./zipformer/finetune.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
@ -57,7 +58,7 @@ import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
@ -68,7 +69,7 @@ import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.cut import Cut, CutSet
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import AsrModel
|
||||
@ -123,6 +124,46 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||
module.name = name
|
||||
|
||||
|
||||
def add_finetune_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--do-finetune",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="If true, finetune from a pre-trained checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-mux",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""
|
||||
Whether to adapt. If true, we will mix 5% of the new data
|
||||
with 95% of the original data to fine-tune. This is useful
|
||||
if you want to maintain the performance on the original domain
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--init-modules",
|
||||
type=str,
|
||||
default=None,
|
||||
help="""
|
||||
Modules to be initialized. It matches all parameters starting with
|
||||
a specific key. The keys are given with Comma seperated. If None,
|
||||
all modules will be initialised. For example, if you only want to
|
||||
initialise all parameters staring with "encoder", use "encoder";
|
||||
if you want to initialise parameters starting with encoder or decoder,
|
||||
use "encoder,joiner".
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--finetune-ckpt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Fine-tuning from which checkpoint (path to a .pt file)",
|
||||
)
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-encoder-layers",
|
||||
@ -469,6 +510,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
add_finetune_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@ -700,6 +742,54 @@ def load_checkpoint_if_available(
|
||||
return saved_params
|
||||
|
||||
|
||||
def load_model_params(
|
||||
ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
|
||||
):
|
||||
"""Load model params from checkpoint
|
||||
|
||||
Args:
|
||||
ckpt (str): Path to the checkpoint
|
||||
model (nn.Module): model to be loaded
|
||||
init_modules (list[str]): List of modules to be initialized
|
||||
|
||||
"""
|
||||
logging.info(f"Loading checkpoint from {ckpt}")
|
||||
checkpoint = torch.load(ckpt, map_location="cpu")
|
||||
|
||||
# if module list is empty, load the whole model from ckpt
|
||||
if not init_modules:
|
||||
if next(iter(checkpoint["model"])).startswith("module."):
|
||||
logging.info("Loading checkpoint saved by DDP")
|
||||
|
||||
dst_state_dict = model.state_dict()
|
||||
src_state_dict = checkpoint["model"]
|
||||
for key in dst_state_dict.keys():
|
||||
src_key = "{}.{}".format("module", key)
|
||||
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||
assert len(src_state_dict) == 0
|
||||
model.load_state_dict(dst_state_dict, strict=strict)
|
||||
else:
|
||||
model.load_state_dict(checkpoint["model"], strict=strict)
|
||||
else:
|
||||
src_state_dict = checkpoint["model"]
|
||||
dst_state_dict = model.state_dict()
|
||||
for module in init_modules:
|
||||
logging.info(f"Loading parameters starting with prefix {module}")
|
||||
src_keys = [
|
||||
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
|
||||
]
|
||||
dst_keys = [
|
||||
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
|
||||
]
|
||||
assert set(src_keys) == set(dst_keys) # two sets should match exactly
|
||||
for key in src_keys:
|
||||
dst_state_dict[key] = src_state_dict.pop(key)
|
||||
|
||||
model.load_state_dict(dst_state_dict, strict=strict)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
@ -881,7 +971,8 @@ def train_one_epoch(
|
||||
scheduler: LRSchedulerType,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
valid_dls: torch.utils.data.DataLoader,
|
||||
valid_sets: List[str],
|
||||
scaler: GradScaler,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
@ -1052,23 +1143,26 @@ def train_one_epoch(
|
||||
)
|
||||
|
||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
for valid_set, valid_dl in zip(valid_sets, valid_dls):
|
||||
logging.info(f"Computing validation loss on {valid_set}")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
|
||||
)
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
@ -1133,10 +1227,23 @@ def run(rank, world_size, args):
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
# load model parameters for model fine-tuning
|
||||
if params.do_finetune:
|
||||
assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
|
||||
modules = params.init_modules.split(",") if params.init_modules else None
|
||||
checkpoints = load_model_params(
|
||||
ckpt=params.finetune_ckpt, model=model, init_modules=modules
|
||||
)
|
||||
# Need to update the model_avg if use initialisation
|
||||
if rank == 0:
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
else:
|
||||
# resuming training
|
||||
assert params.start_epoch > 1, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
@ -1174,10 +1281,18 @@ def run(rank, world_size, args):
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
|
||||
if params.use_mux:
|
||||
librispeech_cuts = librispeech.train_all_shuf_cuts()
|
||||
train_cuts = CutSet.mux(
|
||||
gigaspeech_cuts, # num cuts = 688182
|
||||
librispeech_cuts, # num cuts = 843723
|
||||
weights=[688182, 843723],
|
||||
stop_early=True,
|
||||
)
|
||||
else:
|
||||
train_cuts = gigaspeech_cuts
|
||||
logging.info(train_cuts)
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
@ -1231,7 +1346,13 @@ def run(rank, world_size, args):
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
|
||||
|
||||
valid_sets = ["librispeech", "gigaspeech"]
|
||||
valid_dls = [
|
||||
librispeech.valid_dataloaders(valid_cuts),
|
||||
librispeech.valid_dataloaders(gigaspeech_dev_cuts),
|
||||
]
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
@ -1265,7 +1386,8 @@ def run(rank, world_size, args):
|
||||
scheduler=scheduler,
|
||||
sp=sp,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
valid_dls=valid_dls,
|
||||
valid_sets=valid_sets,
|
||||
scaler=scaler,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
|
Loading…
x
Reference in New Issue
Block a user