add model saving

mirror of https://github.com/k2-fsa/icefall.git
commit ac53222054 (parent 2ce09809cd)

Summary (from the diff below): adds a local average_checkpoints helper to the Whisper decode script and saves the averaged model as a plain state_dict, loads the model via the recipe's own load_model, moves the optimizer into the DeepSpeed JSON config, and guards icefall-specific scheduler calls when DeepSpeed is in use.
@@ -29,9 +29,9 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
+from model import load_model
 
-#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
+from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
@@ -52,6 +52,56 @@ from zhconv import convert
 from tn.chinese.normalizer import Normalizer
 import re
 
+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
+    """Average a list of checkpoints.
+
+    Args:
+      filenames:
+        Filenames of the checkpoints to be averaged. We assume all
+        checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
+    Returns:
+      Return a dict (i.e., state_dict) which is the average of all
+      model state dicts contained in the checkpoints.
+    """
+    n = len(filenames)
+
+    if "model" in torch.load(filenames[0], map_location=device):
+        avg = torch.load(filenames[0], map_location=device)["model"]
+    else:
+        avg = torch.load(filenames[0], map_location=device)
+
+    # Identify shared parameters. Two parameters are said to be shared
+    # if they have the same data_ptr
+    uniqued: Dict[int, str] = dict()
+
+    for k, v in avg.items():
+        v_data_ptr = v.data_ptr()
+        if v_data_ptr in uniqued:
+            continue
+        uniqued[v_data_ptr] = k
+
+    uniqued_names = list(uniqued.values())
+
+    for i in range(1, n):
+        if "model" in torch.load(filenames[i], map_location=device):
+            state_dict = torch.load(filenames[i], map_location=device)["model"]
+        else:
+            state_dict = torch.load(filenames[i], map_location=device)
+        for k in uniqued_names:
+            avg[k] += state_dict[k]
+
+    for k in uniqued_names:
+        if avg[k].is_floating_point():
+            avg[k] /= n
+        else:
+            avg[k] //= n
+
+    return avg
+
 def remove_punctuation(text: str or List[str]):
     # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
     punctuation = '!,.;:?、!,。;:?'
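Note on the new helper: it loads each checkpoint (accepting both raw state_dicts and icefall-style {"model": ...} wrappers), deduplicates tied parameters by data_ptr so shared tensors are summed once, accumulates the sums, and divides by the number of checkpoints; it assumes List, Path, and Dict are already imported in this file. A minimal, self-contained sketch of the arithmetic it performs, on toy tensors rather than real checkpoints:

# Toy reproduction of what average_checkpoints computes (illustration only).
import torch

ckpts = [{"w": torch.tensor([1.0, 2.0])}, {"w": torch.tensor([3.0, 4.0])}]
avg = {k: v.clone() for k, v in ckpts[0].items()}  # start from the first checkpoint
for sd in ckpts[1:]:
    for k in avg:
        avg[k] += sd[k]          # accumulate parameter sums
for k in avg:
    avg[k] /= len(ckpts)         # element-wise mean; integer tensors use //= instead
print(avg["w"])                  # tensor([2., 3.])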
@@ -215,9 +265,9 @@ def decode_one_batch(
     assert feature.ndim == 3
     feature = feature.to(device, dtype=dtype).transpose(1, 2)
     # pad feature to T = 3000
-    T = 3000
-    if feature.shape[2] < T:
-        feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
+    #T = 3000
+    #if feature.shape[2] < T:
+    #    feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
     print(feature.shape,23333)
     # at entry, feature is (N, T, C)
 
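This hunk disables the zero-padding of mel features to Whisper's fixed 3000-frame (30 s) window and leaves a debug print behind, so the model is now fed variable-length features. If fixed-length inputs are ever needed again, the same padding can be written in one F.pad call; a sketch with a toy input, not part of the commit:

import torch
import torch.nn.functional as F

feature = torch.randn(2, 80, 1234)   # (N, n_mels, T), toy mel features
T = 3000                             # Whisper's fixed 30 s window
if feature.shape[2] < T:
    feature = F.pad(feature, (0, T - feature.shape[2]))  # zero-pad on the right of the last dim
print(feature.shape)                 # torch.Size([2, 80, 3000])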
@@ -379,29 +429,39 @@ def main():
 
     logging.info(f"device: {device}")
 
-    model = whisper.load_model(params.model_name)
+    model = load_model(params.model_name)
     if params.epoch > 0:
         if params.avg > 1:
             start = params.epoch - params.avg
             assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
+            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
+            if 'model' not in checkpoint:
+                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
+                model.load_state_dict(average_checkpoints(filenames))
+            else:
+                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+                logging.info(
+                    f"Calculating the averaged model over epoch range from "
+                    f"{start} (excluded) to {params.epoch}"
+                )
+                model.to(device)
+                model.load_state_dict(
+                    average_checkpoints_with_averaged_model(
+                        filename_start=filename_start,
+                        filename_end=filename_end,
+                        device=device,
+                    )
+                )
+            # save checkpoints
+            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save(model.state_dict(), filename)
         else:
             checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            model.load_state_dict(checkpoint, strict=True)
-            #load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+            if 'model' not in checkpoint:
+                model.load_state_dict(checkpoint, strict=True)
+            else:
+                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
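This is the "model saving" of the commit title: after averaging, the weights are written back with torch.save(model.state_dict(), ...), i.e. as a raw state_dict without the usual {"model": ...} wrapper, which is why both loading branches now probe for the "model" key. A sketch of handling the two layouts (the path and the surrounding model object are placeholders):

# Handling both checkpoint layouts this commit can produce (sketch).
import torch

ckpt = torch.load("exp/epoch-10-avg-5.pt", map_location="cpu")      # illustrative path
state_dict = ckpt["model"] if "model" in ckpt else ckpt             # wrapped vs. raw
# model.load_state_dict(state_dict, strict=True)                   # `model` built elsewhere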
@@ -16,12 +16,18 @@
     "reduce_bucket_size": 2e8,
     "contiguous_gradients": true
   },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 1e-5
+    }
+  },
   "scheduler": {
     "type": "WarmupLR",
     "params": {
-      "warmup_min_lr": 1e-6,
-      "warmup_max_lr": 5e-6,
-      "warmup_num_steps": 100
+      "warmup_min_lr": 0,
+      "warmup_max_lr": 1e-5,
+      "warmup_num_steps": 1000
     }
   },
   "gradient_accumulation_steps": 1,
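The DeepSpeed config now declares the optimizer itself (Adam at 1e-5) and retunes WarmupLR to ramp from 0 to 1e-5 over 1000 steps instead of 1e-6 to 5e-6 over 100. A back-of-the-envelope view of the new ramp, assuming linear warmup (DeepSpeed's WarmupLR also supports other warmup curves, so treat this as a sketch):

warmup_min_lr, warmup_max_lr, warmup_num_steps = 0.0, 1e-5, 1000

def lr_at(step: int) -> float:
    if step >= warmup_num_steps:
        return warmup_max_lr                              # held flat after warmup
    frac = step / warmup_num_steps
    return warmup_min_lr + frac * (warmup_max_lr - warmup_min_lr)

print(lr_at(0), lr_at(500), lr_at(2000))                  # 0.0 5e-06 1e-05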
@@ -276,7 +276,6 @@ class Whisper(nn.Module):
 
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
         return self.dims.n_vocab >= 51865
 
     @property
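The is_multilingual tweak drops the now-dead exact match in favor of >=, presumably so checkpoints with a slightly larger vocabulary (e.g. large-v3's extra token on top of the 51865-entry multilingual vocab) are still treated as multilingual, while English-only models at 51864 remain excluded. A quick check of the predicate (vocab sizes are from openai/whisper as I understand them; verify against your checkpoints):

for n_vocab in (51864, 51865, 51866):   # English-only, multilingual, large-v3
    print(n_vocab, n_vocab >= 51865)    # False / True / True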
@@ -126,7 +126,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=10,
+        default=5,
         help="Number of epochs to train.",
     )
 
@@ -649,7 +649,7 @@ def train_one_epoch(
             valid_info.write_summary(
                 tb_writer, "train/valid_", params.batch_idx_train
             )
 
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -732,7 +732,10 @@ def train_one_epoch(
                     f"grad_scale is too small, exiting: {cur_grad_scale}"
                 )
         if batch_idx % params.log_interval == 0:
-            cur_lr = scheduler.get_last_lr()[0]
+            try:
+                cur_lr = scheduler.get_last_lr()[0]
+            except:
+                cur_lr = 0.0
             cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
 
             logging.info(
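Wrapping get_last_lr() in try/except keeps the logging alive when DeepSpeed owns the schedule and the recipe's scheduler object doesn't provide that method. A bare except: also swallows KeyboardInterrupt, so a narrower variant might be safer; a sketch, not what the commit does:

scheduler = None  # stand-in for the case where no PyTorch scheduler exists
try:
    cur_lr = scheduler.get_last_lr()[0]
except (AttributeError, NotImplementedError):
    cur_lr = 0.0  # logged as a placeholder LR
print(cur_lr)     # 0.0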
@@ -835,9 +838,8 @@ def run(rank, world_size, args):
     if world_size > 1:
         if params.deepspeed:
             logging.info("Using DeepSpeed")
-            model, optimizer, _, _ = deepspeed.initialize(
-                args=params, model=model, optimizer=optimizer,
-                model_parameters=model.parameters())
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=params, model=model, model_parameters=model.parameters())
         else:
             logging.info("Using DDP")
             setup_dist(use_ddp_launch=True)
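deepspeed.initialize returns an (engine, optimizer, dataloader, lr_scheduler) tuple; by dropping the optimizer= argument, the commit lets DeepSpeed build Adam and WarmupLR from the JSON config above and now captures the returned scheduler instead of discarding it. Roughly (illustrative; params and model are the surrounding run()'s objects):

import deepspeed  # requires the deepspeed package

# Return order: (engine, optimizer, dataloader, lr_scheduler). The engine
# replaces `model`, and the config-built WarmupLR comes back as `scheduler`.
model, optimizer, _, scheduler = deepspeed.initialize(
    args=params, model=model, model_parameters=model.parameters())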
@@ -877,7 +879,8 @@ def run(rank, world_size, args):
 
     logging.info(f"start training from epoch {params.start_epoch}")
     for epoch in range(params.start_epoch, params.num_epochs + 1):
-        scheduler.step_epoch(epoch - 1)
+        if not params.deepspeed:
+            scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
 
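step_epoch() is an icefall scheduler hook; the DeepSpeed-built WarmupLR steps per batch and has no such method, hence the guard. A duck-typed alternative would key on the scheduler itself rather than on params.deepspeed; a sketch, not the commit's code:

# Call the epoch-level hook only if the scheduler actually provides it.
if hasattr(scheduler, "step_epoch"):
    scheduler.step_epoch(epoch - 1)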