From ac532220548f2a8746ecb9ebde9c660afcfc300f Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Mon, 15 Jan 2024 14:56:18 +0800
Subject: [PATCH] add model saving

---
 egs/aishell/ASR/whisper/decode.py            | 102 +++++++++++++++----
 egs/aishell/ASR/whisper/ds_config_zero1.json |  12 ++-
 egs/aishell/ASR/whisper/model.py             |   1 -
 egs/aishell/ASR/whisper/train.py             |  17 ++--
 4 files changed, 100 insertions(+), 32 deletions(-)

diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
index 2d8dbbfc3..34dae7a85 100644
--- a/egs/aishell/ASR/whisper/decode.py
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -29,9 +29,9 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
+from model import load_model
 
-#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
+from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
@@ -52,6 +52,56 @@ from zhconv import convert
 from tn.chinese.normalizer import Normalizer
 import re
 
+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
+    """Average a list of checkpoints.
+
+    Args:
+      filenames:
+        Filenames of the checkpoints to be averaged. We assume all
+        checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
+    Returns:
+      Return a dict (i.e., state_dict) which is the average of all
+      model state dicts contained in the checkpoints.
+    """
+    n = len(filenames)
+
+    if "model" in torch.load(filenames[0], map_location=device):
+        avg = torch.load(filenames[0], map_location=device)["model"]
+    else:
+        avg = torch.load(filenames[0], map_location=device)
+
+    # Identify shared parameters. Two parameters are said to be shared
+    # if they have the same data_ptr
+    uniqued: Dict[int, str] = dict()
+
+    for k, v in avg.items():
+        v_data_ptr = v.data_ptr()
+        if v_data_ptr in uniqued:
+            continue
+        uniqued[v_data_ptr] = k
+
+    uniqued_names = list(uniqued.values())
+
+    for i in range(1, n):
+        if "model" in torch.load(filenames[i], map_location=device):
+            state_dict = torch.load(filenames[i], map_location=device)["model"]
+        else:
+            state_dict = torch.load(filenames[i], map_location=device)
+        for k in uniqued_names:
+            avg[k] += state_dict[k]
+
+    for k in uniqued_names:
+        if avg[k].is_floating_point():
+            avg[k] /= n
+        else:
+            avg[k] //= n
+
+    return avg
+
 def remove_punctuation(text: str or List[str]):
     # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
     punctuation = '!,.;:?、!,。;:?'
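The locally defined average_checkpoints here differs from the icefall library version it replaces in one point: each checkpoint may be either a raw state_dict (as written by the DeepSpeed training path) or an icefall-style dict holding a "model" key, and both layouts are unwrapped before averaging; main() then writes the averaged weights to epoch-{epoch}-avg-{avg}.pt. Below is a minimal illustration of that averaging logic using two toy checkpoints; the file names and tensors are invented for the example and are not part of the recipe.

import torch

def toy_state_dict(scale: float) -> dict:
    # One float tensor (gets averaged) and one integer tensor (gets floor-divided).
    return {"w": torch.full((2, 2), scale), "steps": torch.tensor(int(scale))}

# One DeepSpeed-style raw state_dict and one icefall-style {"model": ...}
# checkpoint; the patched code accepts both layouts.
torch.save(toy_state_dict(1.0), "epoch-1.pt")
torch.save({"model": toy_state_dict(3.0)}, "epoch-2.pt")

filenames = ["epoch-1.pt", "epoch-2.pt"]
avg = None
for f in filenames:
    ckpt = torch.load(f, map_location="cpu")
    sd = ckpt["model"] if "model" in ckpt else ckpt
    if avg is None:
        avg = sd
    else:
        for k in avg:
            avg[k] += sd[k]
for k in avg:
    if avg[k].is_floating_point():
        avg[k] /= len(filenames)   # float parameters: true average
    else:
        avg[k] //= len(filenames)  # integer buffers (e.g. step counts): floor division
print(avg["w"])      # every entry is 2.0
print(avg["steps"])  # tensor(2)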
@@ -215,9 +265,9 @@ def decode_one_batch(
     assert feature.ndim == 3
     feature = feature.to(device, dtype=dtype).transpose(1, 2)
     # pad feature to T = 3000
-    T = 3000
-    if feature.shape[2] < T:
-        feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
+    #T = 3000
+    #if feature.shape[2] < T:
+    #    feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
     print(feature.shape,23333)
 
     # at entry, feature is (N, T, C)
@@ -379,29 +429,39 @@ def main():
 
     logging.info(f"device: {device}")
 
-    model = whisper.load_model(params.model_name)
+    model = load_model(params.model_name)
     if params.epoch > 0:
         if params.avg > 1:
             start = params.epoch - params.avg
             assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
+            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
+            if 'model' not in checkpoint:
+                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
+                model.load_state_dict(average_checkpoints(filenames))
+            else:
+                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+                logging.info(
+                    f"Calculating the averaged model over epoch range from "
+                    f"{start} (excluded) to {params.epoch}"
                 )
-            )
+                model.to(device)
+                model.load_state_dict(
+                    average_checkpoints_with_averaged_model(
+                        filename_start=filename_start,
+                        filename_end=filename_end,
+                        device=device,
+                    )
+                )
+            # save checkpoints
+            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save(model.state_dict(), filename)
         else:
             checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            model.load_state_dict(checkpoint, strict=True)
-            #load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+            if 'model' not in checkpoint:
+                model.load_state_dict(checkpoint, strict=True)
+            else:
+                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json
index 0d69f83b9..cd8cbac8e 100644
--- a/egs/aishell/ASR/whisper/ds_config_zero1.json
+++ b/egs/aishell/ASR/whisper/ds_config_zero1.json
@@ -16,12 +16,18 @@
     "reduce_bucket_size": 2e8,
     "contiguous_gradients": true
   },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 1e-5
+    }
+  },
   "scheduler": {
     "type": "WarmupLR",
     "params": {
-      "warmup_min_lr": 1e-6,
-      "warmup_max_lr": 5e-6,
-      "warmup_num_steps": 100
+      "warmup_min_lr": 0,
+      "warmup_max_lr": 1e-5,
+      "warmup_num_steps": 1000
     }
   },
   "gradient_accumulation_steps": 1,
diff --git a/egs/aishell/ASR/whisper/model.py b/egs/aishell/ASR/whisper/model.py
index 4e0ef28fa..2f8fea38c 100644
--- a/egs/aishell/ASR/whisper/model.py
+++ b/egs/aishell/ASR/whisper/model.py
@@ -276,7 +276,6 @@ class Whisper(nn.Module):
 
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
         return self.dims.n_vocab >= 51865
 
     @property
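The one-line model.py change drops a stale return statement so that is_multilingual is decided by `n_vocab >= 51865` alone. A quick illustrative check of what that comparison admits follows; the vocabulary sizes are the commonly cited Whisper values, not something stated in this patch.

def is_multilingual(n_vocab: int) -> bool:
    # Mirrors the surviving return statement in model.py.
    return n_vocab >= 51865

print(is_multilingual(51864))  # False: English-only Whisper models
print(is_multilingual(51865))  # True: multilingual v1/v2 vocabularies
print(is_multilingual(51866))  # True: large-v3, which a strict `== 51865` test would reject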
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index d2937c9ee..932242ddb 100644
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -126,7 +126,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=10,
+        default=5,
         help="Number of epochs to train.",
     )
 
@@ -649,7 +649,7 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-
+
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -732,7 +732,10 @@ def train_one_epoch(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )
        if batch_idx % params.log_interval == 0:
-            cur_lr = scheduler.get_last_lr()[0]
+            try:
+                cur_lr = scheduler.get_last_lr()[0]
+            except:
+                cur_lr = 0.0
             cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
 
             logging.info(
@@ -835,9 +838,8 @@ def run(rank, world_size, args):
     if world_size > 1:
         if params.deepspeed:
             logging.info("Using DeepSpeed")
-            model, optimizer, _, _ = deepspeed.initialize(
-                args=params, model=model, optimizer=optimizer,
-                model_parameters=model.parameters())
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=params, model=model, model_parameters=model.parameters())
         else:
             logging.info("Using DDP")
             setup_dist(use_ddp_launch=True)
@@ -877,7 +879,8 @@ def run(rank, world_size, args):
     logging.info(f"start training from epoch {params.start_epoch}")
 
     for epoch in range(params.start_epoch, params.num_epochs + 1):
-        scheduler.step_epoch(epoch - 1)
+        if not params.deepspeed:
+            scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
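Taken together with the new "optimizer" and "scheduler" sections in ds_config_zero1.json, the train.py changes let DeepSpeed construct Adam and the WarmupLR schedule itself: the script no longer passes its own optimizer to deepspeed.initialize, it captures the returned scheduler, and it skips scheduler.step_epoch under DeepSpeed since the engine steps its scheduler on its own, while the try/except around get_last_lr() simply falls back to 0.0 for logging when that method is unavailable. A condensed sketch of the resulting setup is below; it assumes the args object carries the --deepspeed_config path (as added by deepspeed.add_config_arguments), and it is an illustration rather than a verbatim excerpt of train.py.

import deepspeed
import torch

def init_deepspeed(model: torch.nn.Module, args):
    # "optimizer" (Adam, lr=1e-5) and "scheduler" (WarmupLR over 1000 steps)
    # now live in ds_config_zero1.json, so only the model and its parameters
    # are handed over; DeepSpeed builds the rest from the JSON config.
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=args,  # assumed to carry the deepspeed_config path
        model=model,
        model_parameters=model.parameters(),
    )
    return model_engine, optimizer, lr_scheduler

def current_lr(lr_scheduler) -> float:
    # Same guarded read as the patched train_one_epoch: fall back to 0.0
    # for logging if the scheduler object does not expose get_last_lr().
    try:
        return lr_scheduler.get_last_lr()[0]
    except AttributeError:
        return 0.0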