add model saving

Yuekai Zhang 2024-01-15 14:56:18 +08:00
parent 2ce09809cd
commit ac53222054
4 changed files with 100 additions and 32 deletions

View File

@@ -29,9 +29,9 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
-#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
+from model import load_model
+from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
@@ -52,6 +52,56 @@ from zhconv import convert
 from tn.chinese.normalizer import Normalizer
 import re
+
+
+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
+    """Average a list of checkpoints.
+
+    Args:
+      filenames:
+        Filenames of the checkpoints to be averaged. We assume all
+        checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
+    Returns:
+      Return a dict (i.e., state_dict) which is the average of all
+      model state dicts contained in the checkpoints.
+    """
+    n = len(filenames)
+
+    if "model" in torch.load(filenames[0], map_location=device):
+        avg = torch.load(filenames[0], map_location=device)["model"]
+    else:
+        avg = torch.load(filenames[0], map_location=device)
+
+    # Identify shared parameters. Two parameters are said to be shared
+    # if they have the same data_ptr
+    uniqued: Dict[int, str] = dict()
+
+    for k, v in avg.items():
+        v_data_ptr = v.data_ptr()
+        if v_data_ptr in uniqued:
+            continue
+        uniqued[v_data_ptr] = k
+
+    uniqued_names = list(uniqued.values())
+
+    for i in range(1, n):
+        if "model" in torch.load(filenames[i], map_location=device):
+            state_dict = torch.load(filenames[i], map_location=device)["model"]
+        else:
+            state_dict = torch.load(filenames[i], map_location=device)
+        for k in uniqued_names:
+            avg[k] += state_dict[k]
+
+    for k in uniqued_names:
+        if avg[k].is_floating_point():
+            avg[k] /= n
+        else:
+            avg[k] //= n
+
+    return avg
+
+
 def remove_punctuation(text: str or List[str]):
     # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
     punctuation = '!,.;:?、!,。;:?'
@@ -215,9 +265,9 @@ def decode_one_batch(
     assert feature.ndim == 3
     feature = feature.to(device, dtype=dtype).transpose(1, 2)
     # pad feature to T = 3000
-    T = 3000
-    if feature.shape[2] < T:
-        feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
+    #T = 3000
+    #if feature.shape[2] < T:
+    # feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
     print(feature.shape,23333)
 
     # at entry, feature is (N, T, C)
@@ -379,29 +429,39 @@ def main():
     logging.info(f"device: {device}")
 
-    model = whisper.load_model(params.model_name)
+    model = load_model(params.model_name)
     if params.epoch > 0:
         if params.avg > 1:
             start = params.epoch - params.avg
             assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
+            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
+            if 'model' not in checkpoint:
+                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
+                model.load_state_dict(average_checkpoints(filenames))
+            else:
+                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+                logging.info(
+                    f"Calculating the averaged model over epoch range from "
+                    f"{start} (excluded) to {params.epoch}"
+                )
+                model.to(device)
+                model.load_state_dict(
+                    average_checkpoints_with_averaged_model(
+                        filename_start=filename_start,
+                        filename_end=filename_end,
+                        device=device,
+                    )
+                )
+            # save checkpoints
+            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save(model.state_dict(), filename)
         else:
             checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            model.load_state_dict(checkpoint, strict=True)
-            #load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+            if 'model' not in checkpoint:
+                model.load_state_dict(checkpoint, strict=True)
+            else:
+                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
 
     model.to(device)
     model.eval()
 
     num_param = sum([p.numel() for p in model.parameters()])
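
Note: the new average_checkpoints helper mirrors icefall's checkpoint averaging but also accepts plain state_dicts (checkpoints without a "model" key), which is presumably the format produced by the DeepSpeed fine-tuning run. A minimal sketch of the averaging-and-saving path added above; the experiment directory, epoch, and avg values are illustrative, not taken from the recipe:

from pathlib import Path

import torch

exp_dir = Path("whisper/exp")   # illustrative --exp-dir
epoch, avg = 10, 4
start = epoch - avg

# Mirror the new branch in main(): average epoch-6.pt .. epoch-10.pt and export
# the result so later runs can reuse it without re-averaging.
filenames = [exp_dir / f"epoch-{e}.pt" for e in range(start, epoch + 1)]
state_dict = average_checkpoints(filenames)  # helper added in this file
torch.save(state_dict, exp_dir / f"epoch-{epoch}-avg-{avg}.pt")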

View File

@@ -16,12 +16,18 @@
         "reduce_bucket_size": 2e8,
         "contiguous_gradients": true
     },
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-5
+        }
+    },
     "scheduler": {
         "type": "WarmupLR",
         "params": {
-            "warmup_min_lr": 1e-6,
-            "warmup_max_lr": 5e-6,
-            "warmup_num_steps": 100
+            "warmup_min_lr": 0,
+            "warmup_max_lr": 1e-5,
+            "warmup_num_steps": 1000
         }
     },
     "gradient_accumulation_steps": 1,

View File

@@ -276,7 +276,6 @@ class Whisper(nn.Module):
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
         return self.dims.n_vocab >= 51865
 
     @property
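
Note: the equality check on n_vocab is dropped because newer multilingual Whisper checkpoints are not limited to exactly 51865 tokens; large-v3, for example, ships a 51866-token vocabulary. An illustrative check, assuming the upstream openai-whisper package is installed:

import whisper

model = whisper.load_model("large-v3")
print(model.dims.n_vocab)      # 51866 for large-v3
print(model.is_multilingual)   # would be False under an "== 51865" check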

View File

@@ -126,7 +126,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=10,
+        default=5,
         help="Number of epochs to train.",
     )
 
@@ -649,7 +649,7 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
 
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -732,7 +732,10 @@ def train_one_epoch(
                     f"grad_scale is too small, exiting: {cur_grad_scale}"
                 )
 
         if batch_idx % params.log_interval == 0:
-            cur_lr = scheduler.get_last_lr()[0]
+            try:
+                cur_lr = scheduler.get_last_lr()[0]
+            except:
+                cur_lr = 0.0
             cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
             logging.info(
@@ -835,9 +838,8 @@ def run(rank, world_size, args):
     if world_size > 1:
         if params.deepspeed:
             logging.info("Using DeepSpeed")
-            model, optimizer, _, _ = deepspeed.initialize(
-                args=params, model=model, optimizer=optimizer,
-                model_parameters=model.parameters())
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=params, model=model, model_parameters=model.parameters())
         else:
             logging.info("Using DDP")
             setup_dist(use_ddp_launch=True)
@@ -877,7 +879,8 @@ def run(rank, world_size, args):
     logging.info(f"start training from epoch {params.start_epoch}")
 
     for epoch in range(params.start_epoch, params.num_epochs + 1):
-        scheduler.step_epoch(epoch - 1)
+        if not params.deepspeed:
+            scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
 
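Note: since the decode script above now exports the averaged weights as a plain state_dict, reloading them later is straightforward. A minimal sketch using the load_model helper imported in the decode script; the model name and checkpoint path are illustrative:

import torch
from model import load_model

model = load_model("large-v2")  # illustrative model name
state_dict = torch.load("whisper/exp/epoch-10-avg-4.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.eval()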