mirror of https://github.com/k2-fsa/icefall.git

commit ac53222054 (parent 2ce09809cd)

    add model saving
@@ -29,9 +29,9 @@ import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from model import load_model

#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import (
    get_lattice,
    nbest_decoding,
@@ -52,6 +52,56 @@ from zhconv import convert
from tn.chinese.normalizer import Normalizer
import re

def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
    """Average a list of checkpoints.

    Args:
      filenames:
        Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
      device:
        Move checkpoints to this device before averaging.
    Returns:
      Return a dict (i.e., state_dict) which is the average of all
      model state dicts contained in the checkpoints.
    """
    n = len(filenames)

    if "model" in torch.load(filenames[0], map_location=device):
        avg = torch.load(filenames[0], map_location=device)["model"]
    else:
        avg = torch.load(filenames[0], map_location=device)

    # Identify shared parameters. Two parameters are said to be shared
    # if they have the same data_ptr
    uniqued: Dict[int, str] = dict()

    for k, v in avg.items():
        v_data_ptr = v.data_ptr()
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k

    uniqued_names = list(uniqued.values())

    for i in range(1, n):
        if "model" in torch.load(filenames[i], map_location=device):
            state_dict = torch.load(filenames[i], map_location=device)["model"]
        else:
            state_dict = torch.load(filenames[i], map_location=device)
        for k in uniqued_names:
            avg[k] += state_dict[k]

    for k in uniqued_names:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n

    return avg

def remove_punctuation(text: str or List[str]):
    # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
    punctuation = '!,.;:?、!,。;:?'
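Note: the helper above mirrors icefall.checkpoint.average_checkpoints, extended so that it also accepts checkpoints whose state dict is wrapped under a "model" key. A minimal usage sketch (the experiment directory and epoch range below are illustrative assumptions, not values from this diff):

    # Hypothetical example: average the last three epoch checkpoints on CPU
    # and keep the result, using the average_checkpoints helper defined above.
    from pathlib import Path

    import torch

    exp_dir = Path("whisper/exp")          # placeholder experiment directory
    filenames = [exp_dir / f"epoch-{e}.pt" for e in (3, 4, 5)]

    avg_state_dict = average_checkpoints(filenames, device=torch.device("cpu"))
    torch.save(avg_state_dict, exp_dir / "epoch-5-avg-3.pt")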
@@ -215,9 +265,9 @@ def decode_one_batch(
    assert feature.ndim == 3
    feature = feature.to(device, dtype=dtype).transpose(1, 2)
    # pad feature to T = 3000
    T = 3000
    if feature.shape[2] < T:
        feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
    #T = 3000
    #if feature.shape[2] < T:
    #    feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
    print(feature.shape,23333)
    # at entry, feature is (N, T, C)

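Note: the padding above targets T = 3000 mel frames, i.e. the 30-second window the Whisper encoder is trained on, so shorter utterances are zero-padded along the time axis. A standalone sketch of the same operation (batch size, mel-bin count, and input length are illustrative):

    import torch

    N, C, T_target = 2, 80, 3000                 # batch, mel bins, 30 s of frames
    feature = torch.randn(N, C, 1500)            # e.g. a 15-second utterance
    if feature.shape[2] < T_target:
        pad = torch.zeros(N, C, T_target - feature.shape[2], dtype=feature.dtype)
        feature = torch.cat([feature, pad], dim=2)
    assert feature.shape == (N, C, T_target)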
@@ -379,29 +429,39 @@ def main():

    logging.info(f"device: {device}")

    model = whisper.load_model(params.model_name)
    model = load_model(params.model_name)
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
            if 'model' not in checkpoint:
                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
                model.load_state_dict(average_checkpoints(filenames))
            else:
                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
                logging.info(
                    f"Calculating the averaged model over epoch range from "
                    f"{start} (excluded) to {params.epoch}"
                )
            )
                model.to(device)
                model.load_state_dict(
                    average_checkpoints_with_averaged_model(
                        filename_start=filename_start,
                        filename_end=filename_end,
                        device=device,
                    )
                )
            # save checkpoints
            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
            torch.save(model.state_dict(), filename)
        else:
            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
            model.load_state_dict(checkpoint, strict=True)
            #load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
            if 'model' not in checkpoint:
                model.load_state_dict(checkpoint, strict=True)
            else:
                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
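Note: the new epoch-{epoch}-avg-{avg}.pt file stores a plain state_dict, so a later run can load the averaged model directly instead of re-averaging. A minimal re-loading sketch (the path is a placeholder; the "model"-key check mirrors the logic above, and model/device are assumed to be set up as in main()):

    import torch

    ckpt_path = "whisper/exp/epoch-5-avg-3.pt"   # placeholder averaged checkpoint
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()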
@@ -16,12 +16,18 @@
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 1e-5
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 1e-6,
      "warmup_max_lr": 5e-6,
      "warmup_num_steps": 100
      "warmup_min_lr": 0,
      "warmup_max_lr": 1e-5,
      "warmup_num_steps": 1000
    }
  },
  "gradient_accumulation_steps": 1,
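Note: with these values, DeepSpeed's WarmupLR ramps the learning rate from warmup_min_lr (0) to warmup_max_lr (1e-5) over warmup_num_steps (1000) and then holds it. The sketch below uses a linear ramp purely for illustration; DeepSpeed's default warmup curve may differ:

    def warmup_lr(step: int, min_lr: float = 0.0, max_lr: float = 1e-5, num_steps: int = 1000) -> float:
        """Linear warmup with the values from the config above (illustration only)."""
        if step >= num_steps:
            return max_lr
        return min_lr + (max_lr - min_lr) * step / num_steps

    # warmup_lr(0) == 0.0, warmup_lr(500) == 5e-06, warmup_lr(1000) == 1e-05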
@@ -276,7 +276,6 @@ class Whisper(nn.Module):

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865
        return self.dims.n_vocab >= 51865

    @property
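Note: relaxing == to >= keeps is_multilingual True for vocabularies larger than 51865 (presumably to cover the 51866-token large-v3 vocabulary; that motivation is an assumption, not stated in the diff). A quick illustration:

    # Illustrative vocab sizes: 51864 (English-only), 51865 (multilingual v1/v2),
    # 51866 (large-v3). Only the first should report is_multilingual == False.
    for n_vocab in (51864, 51865, 51866):
        print(n_vocab, n_vocab >= 51865)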
@@ -126,7 +126,7 @@ def get_parser():
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        default=5,
        help="Number of epochs to train.",
    )

@@ -649,7 +649,7 @@ def train_one_epoch(
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )


        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
@@ -732,7 +732,10 @@ def train_one_epoch(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )
        if batch_idx % params.log_interval == 0:
            cur_lr = scheduler.get_last_lr()[0]
            try:
                cur_lr = scheduler.get_last_lr()[0]
            except:
                cur_lr = 0.0
            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0

            logging.info(
@@ -835,9 +838,8 @@ def run(rank, world_size, args):
    if world_size > 1:
        if params.deepspeed:
            logging.info("Using DeepSpeed")
            model, optimizer, _, _ = deepspeed.initialize(
                args=params, model=model, optimizer=optimizer,
                model_parameters=model.parameters())
            model, optimizer, _, scheduler = deepspeed.initialize(
                args=params, model=model, model_parameters=model.parameters())
        else:
            logging.info("Using DDP")
            setup_dist(use_ddp_launch=True)
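Note: since the optimizer and WarmupLR scheduler are now declared in the DeepSpeed JSON config, deepspeed.initialize builds them and returns the scheduler, and the engine advances both as part of each optimization step. A rough sketch of one training step under DeepSpeed (simplified; the loss function stands in for compute_loss above):

    # `engine` is the object returned by deepspeed.initialize(...) above.
    # engine.backward()/engine.step() replace loss.backward()/optimizer.step(),
    # and step() also advances the scheduler defined in the JSON config.
    def deepspeed_train_step(engine, batch, loss_fn):
        loss = loss_fn(engine, batch)
        engine.backward(loss)
        engine.step()
        return loss.item()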
@@ -877,7 +879,8 @@ def run(rank, world_size, args):

    logging.info(f"start training from epoch {params.start_epoch}")
    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        if not params.deepspeed:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)
