Support DeepSpeed to fine-tune large models

Yuekai Zhang 2024-01-12 16:14:10 +08:00
parent 92895f774f
commit b6418acda2
6 changed files with 162 additions and 215 deletions

View File

@@ -182,7 +182,7 @@ class AishellAsrDataModule:
)
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank: Optional[int] = None, world_size: Optional[int] = None
) -> DataLoader:
"""
Args:
@@ -276,6 +276,8 @@ class AishellAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@@ -300,7 +302,7 @@ class AishellAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def valid_dataloaders(self, cuts_valid: CutSet, rank: Optional[int] = None, world_size: Optional[int] = None) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@@ -325,6 +327,8 @@ class AishellAsrDataModule:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
rank=rank,
world_size=world_size,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
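Note on the sampler changes above: rank and world_size are now threaded through explicitly instead of letting lhotse infer them from torch.distributed, presumably because the process group is owned by the DeepSpeed launcher in the new setup. A minimal sketch of the resulting construction; the manifest path and the rank/world_size values are illustrative, not taken from this recipe:

# Sketch: explicit sharding of a lhotse sampler.
from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler

cuts = load_manifest_lazy("data/fbank/aishell_cuts_train.jsonl.gz")
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=200.0,  # seconds of audio per mini-batch
    shuffle=True,
    num_buckets=30,
    drop_last=True,
    world_size=4,  # total number of training processes
    rank=0,        # this process's index; each rank draws a disjoint shard
)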

View File

@@ -109,20 +109,17 @@ def get_parser():
default="beam-search",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It maps the token IDs to
tokens directly using the token symbol table.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) attention-decoder. Extract n paths from the lattice,
the path with the highest score is the decoding result.
- (4) nbest-oracle. Its WER is the lower bound of what any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring methods.
- beam-search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=1,
help="beam size for beam search decoding",
)
parser.add_argument(
"--exp-dir",
type=str,
@@ -357,10 +354,9 @@ def main():
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
#options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=10)
options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=None)
options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
params.normalizer = Normalizer()
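With this change the decoding beam becomes a CLI knob instead of a hard-coded value; in openai-whisper, beam_size=None means greedy decoding and an integer enables beam search. For reference, a sketch of the standard whisper decoding flow this option feeds, with an illustrative model name and audio path:

# Sketch of the openai-whisper decoding path behind --beam-size.
import whisper

model = whisper.load_model("large-v2")
options = whisper.DecodingOptions(
    task="transcribe",
    language="zh",
    without_timestamps=True,
    beam_size=5,  # None -> greedy; an integer -> beam search
)
audio = whisper.load_audio("test.wav")
audio = whisper.pad_or_trim(audio)  # whisper consumes fixed 30-second windows
mel = whisper.log_mel_spectrogram(audio).to(model.device)
result = whisper.decode(model, mel, options)
print(result.text)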

View File

@@ -0,0 +1,32 @@
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 1,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 5e-6,
"warmup_max_lr": 1e-5,
"warmup_num_steps": 100
}
},
"gradient_accumulation_steps": 1,
"gradient_clipping": 5,
"steps_per_print": 50,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
}
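This new config enables fp16 with dynamic loss scaling, ZeRO stage 1 (only optimizer states are partitioned across ranks; gradients and parameters stay replicated), a WarmupLR schedule from 5e-6 to 1e-5 over the first 100 steps, and gradient clipping at 5. The file reaches the engine via the --deepspeed_config flag registered in train.py below; a minimal sketch of the equivalent direct wiring, using a stand-in model:

# Sketch: building an engine from ds_config_zero1.json directly. Normally this
# runs under the launcher, e.g.
#   deepspeed train.py --deepspeed --deepspeed_config ds_config_zero1.json
import deepspeed
import torch

model = torch.nn.Linear(80, 10)  # stand-in for the Whisper model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,  # the "scheduler" section wraps this optimizer
    model_parameters=model.parameters(),
    config="ds_config_zero1.json",
)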

View File

@@ -372,7 +372,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
device: Optional[Union[str, torch.device]] = "cpu",
download_root: str = None,
in_memory: bool = False,
) -> Whisper:
@@ -397,8 +397,8 @@ def load_model(
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
# if device is None:
# device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
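Defaulting to CPU here means every rank first materializes the weights on the host and placement is decided later, presumably so DeepSpeed (which moves the module during initialize) or the DDP path can put it on the right GPU without each process eagerly claiming cuda. A sketch of the intended pattern; the loader import and the rank value are illustrative:

# Sketch: load on CPU, then place per-rank.
import torch
from whisper import load_model  # the patched loader shown above; import path illustrative

rank = 0  # normally comes from the distributed launcher
model = load_model("large-v2")  # lands on CPU by default after this change
device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)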

View File

@@ -8,3 +8,4 @@ librosa
openai-whisper
zhconv
WeTextProcessing
deepspeed

View File

@@ -42,6 +42,8 @@ import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import deepspeed
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
import k2
import optim
@@ -102,15 +104,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--deepspeed-config",
type=str,
default=None,
help="Path to deepspeed json config file.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -251,7 +244,7 @@ def get_parser():
help="Whether to use half precision training.",
)
add_deepspeed_arguments(parser)
parser = deepspeed.add_config_arguments(parser)
return parser
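The hand-rolled --deepspeed-config option is dropped in favor of DeepSpeed's own helper, which registers its standard flags on the existing parser. A quick sketch of what it contributes:

# Sketch: deepspeed.add_config_arguments registers the standard launcher flags.
import argparse

import deepspeed

parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(["--deepspeed", "--deepspeed_config", "ds_config_zero1.json"])
assert args.deepspeed is True
assert args.deepspeed_config == "ds_config_zero1.json"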
@@ -495,7 +488,6 @@ def compute_loss(
feature = feature.transpose(1, 2) # (N, C, T)
# pad feature from B,80,T to B,80,3000
#feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
#print(feature.shape, 23333333)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
@@ -629,24 +621,25 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
tokenizer=tokenizer,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
# if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
# logging.info("Computing validation loss")
# valid_info = compute_validation_loss(
# params=params,
# tokenizer=tokenizer,
# model=model,
# valid_dl=valid_dl,
# world_size=world_size,
# )
# model.train()
# logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
# logging.info(
# f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
# )
# if tb_writer is not None:
# valid_info.write_summary(
# tb_writer, "train/valid_", params.batch_idx_train
# )
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
@@ -661,13 +654,20 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)
if params.deepspeed:
# DeepSpeed's engine.backward(loss) replaces loss.backward():
# it applies loss scaling and gradient accumulation internally,
# and engine.step() also runs the optimizer step and zeroes the gradients.
model.backward(loss)
model.step()
else:
scaler.scale(loss).backward()
set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params)
raise
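The branch above is the heart of the port: the GradScaler/optimizer/scheduler triple collapses into two engine calls, since the engine owns loss scaling, clipping, the optimizer step, and any configured LR schedule. A self-contained toy step showing the same pattern, with a stand-in model and an inline config (run it under the deepspeed launcher so the distributed backend is initialized):

# Toy DeepSpeed training step mirroring the branch above.
import deepspeed
import torch
import torch.nn.functional as F

model = torch.nn.Linear(80, 10)  # stand-in model
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 5,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
}
engine, _, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)
x = torch.randn(1, 80, device=engine.device)
y = torch.randint(0, 10, (1,), device=engine.device)
loss = F.cross_entropy(engine(x), y)
engine.backward(loss)  # replaces scaler.scale(loss).backward()
engine.step()          # optimizer step + gradient zeroing (+ LR schedule if configured)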
@@ -679,6 +679,7 @@ def train_one_epoch(
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
and not params.deepspeed
):
update_averaged_model(
params=params,
@@ -686,29 +687,28 @@ def train_one_epoch(
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % 100 == 0 and params.use_fp16:
# if (
# params.batch_idx_train > 0
# and params.batch_idx_train % params.save_every_n == 0
# ):
# save_checkpoint_with_global_batch_idx(
# out_dir=params.exp_dir,
# global_batch_idx=params.batch_idx_train,
# model=model,
# model_avg=model_avg,
# params=params,
# optimizer=optimizer,
# scheduler=scheduler,
# sampler=train_dl.sampler,
# scaler=scaler,
# rank=rank,
# )
# remove_checkpoints(
# out_dir=params.exp_dir,
# topk=params.keep_last_k,
# rank=rank,
# )
if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
@@ -723,14 +723,14 @@ def train_one_epoch(
)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
)
if tb_writer is not None:
@@ -774,37 +774,21 @@ def run(rank, world_size, args):
fix_random_seed(params.seed)
setup_dist(use_ddp_launch=True)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info(params)
logging.info("About to create model")
#model = whisper.load_model("medium")
# TODO: download the model only on rank 0
# TODO: maybe compute the validation loss on multiple cards
model = load_model("medium")
# model = load_model("medium")
model = load_model("large-v2")
del model.alignment_heads
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual, language="zh", task="transcribe"
)
logging.info(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
@@ -817,10 +801,12 @@ def run(rank, world_size, args):
params=params, model=model, model_avg=model_avg
)
if torch.cuda.is_available():
device = torch.device("cuda", rank)
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -837,6 +823,17 @@ def run(rank, world_size, args):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if world_size > 1:
if params.deepspeed:
logging.info("Using DeepSpeed")
model, optimizer, _, _ = deepspeed.initialize(
args=params, model=model, optimizer=optimizer,
model_parameters=model.parameters())
else:
logging.info("Using DDP")
setup_dist(use_ddp_launch=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
@@ -846,51 +843,8 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 12.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
# T = ((c.num_frames - 7) // 2 + 1) // 2
# tokens = sp.encode(c.supervisions[0].text, out_type=str)
# if T < len(tokens):
# logging.warning(
# f"Exclude cut with ID {c.id} from training. "
# f"Number of frames (before subsampling): {c.num_frames}. "
# f"Number of frames (after subsampling): {T}. "
# f"Text: {c.supervisions[0].text}. "
# f"Tokens: {tokens}. "
# f"Number of tokens: {len(tokens)}"
# )
# return False
return True
aishell = AishellAsrDataModule(args)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
@@ -899,22 +853,19 @@ def run(rank, world_size, args):
sampler_state_dict = None
train_dl = aishell.train_dataloaders(aishell.train_cuts())
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# graph_compiler=graph_compiler,
# params=params,
# )
train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
@@ -945,20 +896,28 @@ def run(rank, world_size, args):
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if params.deepspeed:
model.save_checkpoint(save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={})
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir, f"epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}")
else:
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
if world_size > 1 and not params.deepspeed:
torch.distributed.barrier()
cleanup_dist()
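Under ZeRO, the engine's save_checkpoint writes per-rank shards below exp_dir under the given tag, and convert_zero_checkpoint_to_fp32_state_dict merges them into a single fp32 file that plain torch can read. A sketch of consuming the merged checkpoint; the paths are illustrative, and whether keys carry a "module." prefix depends on how the model was wrapped when it was saved:

# Sketch: loading the merged fp32 checkpoint produced above.
import torch
from whisper import load_model  # patched loader from this commit; import path illustrative

model = load_model("large-v2")
state_dict = torch.load("exp/epoch-1.pt", map_location="cpu")
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)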
@@ -988,48 +947,6 @@ def display_and_save_batch(
logging.info(f"features shape: {features.shape}")
# def scan_pessimistic_batches_for_oom(
# model: Union[nn.Module, DDP],
# tokenizer: whisper.tokenizer.Tokenizer,
# train_dl: torch.utils.data.DataLoader,
# optimizer: torch.optim.Optimizer,
# params: AttributeDict,
# ):
# from lhotse.dataset import find_pessimistic_batches
# logging.info(
# "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
# )
# batches, crit_values = find_pessimistic_batches(train_dl.sampler)
# for criterion, cuts in batches.items():
# batch = train_dl.dataset[cuts]
# try:
# with torch.cuda.amp.autocast(enabled=params.use_fp16):
# loss, _ = compute_loss(
# params=params,
# tokenizer=tokenizer,
# model=model,
# batch=batch,
# is_training=True,
# )
# loss.backward()
# optimizer.zero_grad()
# except Exception as e:
# if "CUDA out of memory" in str(e):
# logging.error(
# "Your GPU ran out of memory with the current "
# "max_duration setting. We recommend decreasing "
# "max_duration and trying again.\n"
# f"Failing criterion: {criterion} "
# f"(={crit_values[criterion]}) ..."
# )
# display_and_save_batch(batch, params=params)
# raise
# logging.info(
# f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
# )
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
@@ -1038,13 +955,10 @@ def main():
world_size = get_world_size()
rank = get_rank()
assert world_size >= 1
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
run(rank=rank, world_size=world_size, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()