add model saving

Yuekai Zhang 2024-01-15 14:56:18 +08:00
parent 2ce09809cd
commit ac53222054
4 changed files with 100 additions and 32 deletions

View File

@@ -29,9 +29,9 @@ import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from model import load_model
#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import (
    get_lattice,
    nbest_decoding,
@@ -52,6 +52,56 @@ from zhconv import convert
from tn.chinese.normalizer import Normalizer
import re
def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
    """Average a list of checkpoints.
    Args:
      filenames:
        Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
      device:
        Move checkpoints to this device before averaging.
    Returns:
      Return a dict (i.e., state_dict) which is the average of all
      model state dicts contained in the checkpoints.
    """
    n = len(filenames)
    if "model" in torch.load(filenames[0], map_location=device):
        avg = torch.load(filenames[0], map_location=device)["model"]
    else:
        avg = torch.load(filenames[0], map_location=device)
    # Identify shared parameters. Two parameters are said to be shared
    # if they have the same data_ptr
    uniqued: Dict[int, str] = dict()
    for k, v in avg.items():
        v_data_ptr = v.data_ptr()
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k
    uniqued_names = list(uniqued.values())
    for i in range(1, n):
        if "model" in torch.load(filenames[i], map_location=device):
            state_dict = torch.load(filenames[i], map_location=device)["model"]
        else:
            state_dict = torch.load(filenames[i], map_location=device)
        for k in uniqued_names:
            avg[k] += state_dict[k]
    for k in uniqued_names:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n
    return avg
def remove_punctuation(text: str or List[str]):
    # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
    punctuation = '!,.;:?、!,。;:?'
@@ -215,9 +265,9 @@ def decode_one_batch(
    assert feature.ndim == 3
    feature = feature.to(device, dtype=dtype).transpose(1, 2)
    # pad feature to T = 3000
    T = 3000
    if feature.shape[2] < T:
        feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
    #T = 3000
    #if feature.shape[2] < T:
    #    feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
    print(feature.shape,23333)
    # at entry, feature is (N, T, C)
@@ -379,29 +429,39 @@ def main():
    logging.info(f"device: {device}")
    model = whisper.load_model(params.model_name)
    model = load_model(params.model_name)
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
            if 'model' not in checkpoint:
                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
                model.load_state_dict(average_checkpoints(filenames))
            else:
                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
                logging.info(
                    f"Calculating the averaged model over epoch range from "
                    f"{start} (excluded) to {params.epoch}"
                )
            )
                model.to(device)
                model.load_state_dict(
                    average_checkpoints_with_averaged_model(
                        filename_start=filename_start,
                        filename_end=filename_end,
                        device=device,
                    )
                )
            # save checkpoints
            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
            torch.save(model.state_dict(), filename)
        else:
            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
            model.load_state_dict(checkpoint, strict=True)
            #load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
            if 'model' not in checkpoint:
                model.load_state_dict(checkpoint, strict=True)
            else:
                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        model.to(device)
        model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
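
Note on the decode.py change above: when an epoch checkpoint is a bare state dict rather than a {"model": ...} dict, the newly added local average_checkpoints() is used instead of average_checkpoints_with_averaged_model(), and the averaged weights are written back to disk (the "model saving" of the commit title). A minimal sketch of that branch, assuming --epoch 10 --avg 3 and hypothetical bare-state-dict files epoch-7.pt ... epoch-10.pt in the experiment directory (exp_dir here stands in for params.exp_dir):

    # not part of the commit; illustrates the new branch in main()
    start = 10 - 3  # params.epoch - params.avg
    filenames = [f"{exp_dir}/epoch-{e}.pt" for e in range(start, 11)]  # epoch-7.pt .. epoch-10.pt
    model.load_state_dict(average_checkpoints(filenames))  # element-wise mean of the four state dicts
    torch.save(model.state_dict(), f"{exp_dir}/epoch-10-avg-3.pt")  # averaged checkpoint saved for reuse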

View File

@@ -16,12 +16,18 @@
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-5
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 1e-6,
            "warmup_max_lr": 5e-6,
            "warmup_num_steps": 100
            "warmup_min_lr": 0,
            "warmup_max_lr": 1e-5,
            "warmup_num_steps": 1000
        }
    },
    "gradient_accumulation_steps": 1,

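Note on the DeepSpeed config above: with an "optimizer" block now defined in the JSON next to the "scheduler" block, deepspeed.initialize can construct both from the config, which is why the train.py hunk further down stops passing optimizer= and takes the scheduler from the returned tuple instead. A minimal sketch under that assumption (params is assumed to carry the usual DeepSpeed config argument; model is the Whisper model):

    # not part of the commit; sketch of how the config is consumed in train.py
    import deepspeed
    model, optimizer, _, scheduler = deepspeed.initialize(
        args=params,                          # assumed to point at the JSON config above
        model=model,
        model_parameters=model.parameters(),  # Adam and WarmupLR are instantiated from the config
    )
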
View File

@@ -276,7 +276,6 @@ class Whisper(nn.Module):
    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865
        return self.dims.n_vocab >= 51865
    @property
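
Note on the is_multilingual change above: newer multilingual Whisper checkpoints can have a vocabulary larger than 51865 (large-v3 uses 51866 tokens), so the relaxed >= comparison keeps them classified as multilingual, while English-only models (n_vocab 51864) still return False.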

View File

@@ -126,7 +126,7 @@ def get_parser():
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        default=5,
        help="Number of epochs to train.",
    )
@@ -732,7 +732,10 @@ def train_one_epoch(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )
        if batch_idx % params.log_interval == 0:
            cur_lr = scheduler.get_last_lr()[0]
            try:
                cur_lr = scheduler.get_last_lr()[0]
            except:
                cur_lr = 0.0
            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
            logging.info(
@@ -835,9 +838,8 @@ def run(rank, world_size, args):
    if world_size > 1:
        if params.deepspeed:
            logging.info("Using DeepSpeed")
            model, optimizer, _, _ = deepspeed.initialize(
                args=params, model=model, optimizer=optimizer,
                model_parameters=model.parameters())
            model, optimizer, _, scheduler = deepspeed.initialize(
                args=params, model=model, model_parameters=model.parameters())
        else:
            logging.info("Using DDP")
            setup_dist(use_ddp_launch=True)
@@ -877,7 +879,8 @@ def run(rank, world_size, args):
    logging.info(f"start training from epoch {params.start_epoch}")
    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        if not params.deepspeed:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)
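
Note on the train.py changes above: under DeepSpeed the scheduler returned by deepspeed.initialize is the WarmupLR object built from the JSON config, which is presumably why the per-epoch scheduler.step_epoch() call is now skipped on that path and the learning-rate read in the logging block is wrapped in a try/except. A slightly narrower variant of that guard, not part of the commit and assuming scheduler is whatever object the training loop holds:

    # not part of the commit; catches only the missing-method case instead of a bare except
    if hasattr(scheduler, "get_last_lr"):
        cur_lr = scheduler.get_last_lr()[0]
    else:
        cur_lr = 0.0  # fallback for schedulers that do not expose get_last_lr()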