Support DeepSpeed to fine-tune large models

This commit is contained in:
Yuekai Zhang 2024-01-12 16:14:10 +08:00
parent 92895f774f
commit b6418acda2
6 changed files with 162 additions and 215 deletions

View File

@@ -182,7 +182,7 @@ class AishellAsrDataModule:
         )

     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank=None, world_size=None
     ) -> DataLoader:
         """
         Args:
@@ -276,6 +276,8 @@ class AishellAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -300,7 +302,7 @@ class AishellAsrDataModule:
         return train_dl

-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+    def valid_dataloaders(self, cuts_valid: CutSet, rank=None, world_size=None) -> DataLoader:
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -325,6 +327,8 @@ class AishellAsrDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            rank=rank,
+            world_size=world_size,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
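Note on the sampler changes above: lhotse samplers can infer the process's rank and world size from torch.distributed, but passing them explicitly keeps data sharding deterministic regardless of how distributed state was initialized (DDP spawn, torchrun, or the DeepSpeed engine). A minimal hedged usage sketch; the names args, rank and world_size are assumed to come from the surrounding training script, as in the train.py changes further below:

    # Sketch: drive the updated data module with explicit sharding info.
    aishell = AishellAsrDataModule(args)
    train_dl = aishell.train_dataloaders(
        aishell.train_cuts(), rank=rank, world_size=world_size
    )
    valid_dl = aishell.valid_dataloaders(
        aishell.valid_cuts(), rank=rank, world_size=world_size
    )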

View File

@@ -109,20 +109,17 @@ def get_parser():
         default="beam-search",
         help="""Decoding method.
         Supported values are:
-          - (0) ctc-decoding. Use CTC decoding. It maps the tokens ids to
-            tokens using token symbol tabel directly.
-          - (1) 1best. Extract the best path from the decoding lattice as the
-            decoding result.
-          - (2) nbest. Extract n paths from the decoding lattice; the path
-            with the highest score is the decoding result.
-          - (3) attention-decoder. Extract n paths from the lattice,
-            the path with the highest score is the decoding result.
-          - (4) nbest-oracle. Its WER is the lower bound of any n-best
-            rescoring method can achieve. Useful for debugging n-best
-            rescoring method.
+          - beam-search
         """,
     )

+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=1,
+        help="beam size for beam search decoding",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -357,10 +354,9 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
-    #options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=10)
-    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=None)
+    setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
+    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
     params.decoding_options = options
     params.cleaner = BasicTextNormalizer()
     params.normalizer = Normalizer()
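The new --beam-size flag is forwarded straight into whisper's decoding options. A small hedged sketch of the resulting decode call; model and mel are assumed to be a loaded Whisper model and a batch of log-mel features, illustrative rather than the script's exact variables:

    import whisper

    # beam_size=1 behaves essentially like greedy decoding; larger values
    # trade decoding speed for accuracy.
    options = whisper.DecodingOptions(
        task="transcribe", language="zh", without_timestamps=True, beam_size=4
    )
    results = model.decode(mel, options)
    print(results[0].text)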

View File

@@ -0,0 +1,32 @@
+{
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 16,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 2e8,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 2e8,
+    "contiguous_gradients": true
+  },
+  "scheduler": {
+    "type": "WarmupLR",
+    "params": {
+      "warmup_min_lr": 5e-6,
+      "warmup_max_lr": 1e-5,
+      "warmup_num_steps": 100
+    }
+  },
+  "gradient_accumulation_steps": 1,
+  "gradient_clipping": 5,
+  "steps_per_print": 50,
+  "train_micro_batch_size_per_gpu": 1,
+  "wall_clock_breakdown": false
+}
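This new JSON file is a ZeRO stage-1 configuration: fp16 with dynamic loss scaling, optimizer-state partitioning across ranks, a WarmupLR schedule, gradient clipping at 5, and a micro-batch of 1 per GPU. Everything in it is enforced by the DeepSpeed engine rather than by the training loop. A hedged sketch of how such a file can be consumed; the toy model and path are placeholders, in the recipe the config instead arrives through the --deepspeed_config flag, and the process is expected to run under a distributed launcher so rank/world-size environment variables exist:

    import deepspeed
    import torch

    model = torch.nn.Linear(80, 80)  # stand-in for the real Whisper model
    # The engine reads the fp16, ZeRO, scheduler and clipping settings from the JSON.
    engine, optimizer, _, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config="ds_config_zero1.json",  # hypothetical local path to the file above
    )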

View File

@@ -372,7 +372,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:

 def load_model(
     name: str,
-    device: Optional[Union[str, torch.device]] = None,
+    device: Optional[Union[str, torch.device]] = 'cpu',
     download_root: str = None,
     in_memory: bool = False,
 ) -> Whisper:
@@ -397,8 +397,8 @@ def load_model(
         The Whisper ASR model instance
     """

-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    # if device is None:
+    #     device = "cuda" if torch.cuda.is_available() else "cpu"
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
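Defaulting device to 'cpu' (and disabling the CUDA auto-selection) means checkpoint weights are first materialized on the host; the training script then places the model explicitly, or lets the DeepSpeed engine do so. A hedged sketch of the intended flow; rank is assumed to come from the launcher, and load_model refers to the patched function above:

    import torch

    model = load_model("large-v2")  # weights land on CPU first
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    else:
        device = torch.device("cpu")
    model.to(device)  # explicit placement, mirroring the train.py changes below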

View File

@@ -8,3 +8,4 @@ librosa
 openai-whisper
 zhconv
 WeTextProcessing
+deepspeed

View File

@@ -42,6 +42,8 @@ import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union

+import deepspeed
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 import k2
 import optim
@@ -102,15 +104,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count


-def add_deepspeed_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--deepspeed-config",
-        type=str,
-        default=None,
-        help="Path to deepspeed json config file.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -251,7 +244,7 @@ def get_parser():
         help="Whether to use half precision training.",
     )

-    add_deepspeed_arguments(parser)
+    parser = deepspeed.add_config_arguments(parser)

     return parser
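deepspeed.add_config_arguments replaces the hand-rolled --deepspeed-config option removed above: it registers DeepSpeed's standard CLI surface, notably a boolean --deepspeed switch and --deepspeed_config, which is why the training code can later branch on params.deepspeed. A tiny hedged sketch; the exact set of extra flags depends on the installed DeepSpeed version:

    import argparse

    import deepspeed

    parser = argparse.ArgumentParser()
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args(["--deepspeed", "--deepspeed_config", "ds_config_zero1.json"])
    print(args.deepspeed, args.deepspeed_config)  # True ds_config_zero1.json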
@@ -495,7 +488,6 @@ def compute_loss(
     feature = feature.transpose(1, 2)  # (N, C, T)
     # pad feature from B,80,T to B,80,3000
     #feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
-    #print(feature.shape, 23333333)

     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
@@ -629,24 +621,25 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -661,13 +654,20 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
-            scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
-            scheduler.step_batch(params.batch_idx_train)
-
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
+            if params.deepspeed:
+                # With DeepSpeed, backward and the optimizer step are driven by
+                # the engine: model.backward(loss) applies loss scaling and
+                # gradient accumulation internally, and model.step() updates the
+                # parameters, so no GradScaler or optimizer calls are needed here.
+                model.backward(loss)
+                model.step()
+            else:
+                scaler.scale(loss).backward()
+                set_batch_count(model, params.batch_idx_train)
+                scheduler.step_batch(params.batch_idx_train)
+
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.zero_grad()
         except:  # noqa
             display_and_save_batch(batch, params=params)
             raise
@@ -679,6 +679,7 @@ def train_one_epoch(
             rank == 0
             and params.batch_idx_train > 0
             and params.batch_idx_train % params.average_period == 0
+            and not params.deepspeed
         ):
             update_averaged_model(
                 params=params,
@@ -686,29 +687,28 @@ def train_one_epoch(
                 model_avg=model_avg,
             )

-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
-            save_checkpoint_with_global_batch_idx(
-                out_dir=params.exp_dir,
-                global_batch_idx=params.batch_idx_train,
-                model=model,
-                model_avg=model_avg,
-                params=params,
-                optimizer=optimizer,
-                scheduler=scheduler,
-                sampler=train_dl.sampler,
-                scaler=scaler,
-                rank=rank,
-            )
-            remove_checkpoints(
-                out_dir=params.exp_dir,
-                topk=params.keep_last_k,
-                rank=rank,
-            )
+        # if (
+        #     params.batch_idx_train > 0
+        #     and params.batch_idx_train % params.save_every_n == 0
+        # ):
+        #     save_checkpoint_with_global_batch_idx(
+        #         out_dir=params.exp_dir,
+        #         global_batch_idx=params.batch_idx_train,
+        #         model=model,
+        #         model_avg=model_avg,
+        #         params=params,
+        #         optimizer=optimizer,
+        #         scheduler=scheduler,
+        #         sampler=train_dl.sampler,
+        #         scaler=scaler,
+        #         rank=rank,
+        #     )
+        #     remove_checkpoints(
+        #         out_dir=params.exp_dir,
+        #         topk=params.keep_last_k,
+        #         rank=rank,
+        #     )

-        if batch_idx % 100 == 0 and params.use_fp16:
+        if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
@@ -723,14 +723,14 @@ def train_one_epoch(
                 )

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
             )

         if tb_writer is not None:
@@ -774,37 +774,21 @@ def run(rank, world_size, args):
     fix_random_seed(params.seed)
-    setup_dist(use_ddp_launch=True)
     setup_logger(f"{params.exp_dir}/log/log-train")
-    logging.info("Training started")

-    if args.tensorboard and rank == 0:
-        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
-    else:
-        tb_writer = None
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
+    logging.info(params)

     logging.info("About to create model")
-    #model = whisper.load_model("medium")
     # TODO download model only on rank 0
     # TODO may change compute validation loss using multiple cards
-    model = load_model("medium")
+    # model = load_model("medium")
+    model = load_model("large-v2")
     del model.alignment_heads

-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
-
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual, language="zh", task="transcribe"
     )

-    logging.info(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")

     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
@@ -817,10 +801,12 @@ def run(rank, world_size, args):
         params=params, model=model, model_avg=model_avg
     )

+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    else:
+        device = torch.device("cpu")
+    logging.info(f"Device: {device}")
     model.to(device)
-    if world_size > 1:
-        logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

     optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -837,6 +823,17 @@ def run(rank, world_size, args):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])

+    if world_size > 1:
+        if params.deepspeed:
+            logging.info("Using DeepSpeed")
+            model, optimizer, _, _ = deepspeed.initialize(
+                args=params, model=model, optimizer=optimizer,
+                model_parameters=model.parameters())
+        else:
+            logging.info("Using DDP")
+            setup_dist(use_ddp_launch=True)
+            model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
             2**22
@@ -846,51 +843,8 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        if c.duration < 1.0 or c.duration > 12.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
-            return False
-
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        # T = ((c.num_frames - 7) // 2 + 1) // 2
-        # tokens = sp.encode(c.supervisions[0].text, out_type=str)
-        # if T < len(tokens):
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. "
-        #         f"Number of frames (before subsampling): {c.num_frames}. "
-        #         f"Number of frames (after subsampling): {T}. "
-        #         f"Text: {c.supervisions[0].text}. "
-        #         f"Tokens: {tokens}. "
-        #         f"Number of tokens: {len(tokens)}"
-        #     )
-        #     return False
-
-        return True
-
     aishell = AishellAsrDataModule(args)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
@@ -899,22 +853,19 @@ def run(rank, world_size, args):
         sampler_state_dict = None

-    train_dl = aishell.train_dataloaders(aishell.train_cuts())
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
-
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         graph_compiler=graph_compiler,
-    #         params=params,
-    #     )
+    train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size)
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size)

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])

+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
     logging.info(f"start training from epoch {params.start_epoch}")
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         scheduler.step_epoch(epoch - 1)
@@ -945,20 +896,28 @@ def run(rank, world_size, args):
             diagnostic.print_diagnostics()
             break

-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+        if params.deepspeed:
+            model.save_checkpoint(save_dir=params.exp_dir,
+                                  tag=f"epoch-{params.cur_epoch}",
+                                  client_state={})
+            convert_zero_checkpoint_to_fp32_state_dict(
+                params.exp_dir, f"epoch-{params.cur_epoch}.pt",
+                tag=f"epoch-{params.cur_epoch}")
+        else:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )

     logging.info("Done!")

-    if world_size > 1:
+    if world_size > 1 and not params.deepspeed:
         torch.distributed.barrier()
         cleanup_dist()
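After each epoch the DeepSpeed branch above saves a ZeRO-partitioned checkpoint under the experiment directory and immediately consolidates it into a single fp32 state dict named epoch-N.pt, so decoding does not need DeepSpeed at all. A hedged sketch of loading that consolidated file back into a plain Whisper model; the path is illustrative, and strict=False guards against minor key differences such as the deleted alignment_heads buffer:

    import torch
    import whisper

    model = whisper.load_model("large-v2")  # same architecture as used for fine-tuning
    state_dict = torch.load("exp_dir/epoch-1.pt", map_location="cpu")  # hypothetical path
    model.load_state_dict(state_dict, strict=False)
    model.eval()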
@@ -988,48 +947,6 @@ def display_and_save_batch(
     logging.info(f"features shape: {features.shape}")


-# def scan_pessimistic_batches_for_oom(
-#     model: Union[nn.Module, DDP],
-#     tokenizer: whisper.tokenizer.Tokenizer,
-#     train_dl: torch.utils.data.DataLoader,
-#     optimizer: torch.optim.Optimizer,
-#     params: AttributeDict,
-# ):
-#     from lhotse.dataset import find_pessimistic_batches
-#     logging.info(
-#         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
-#     )
-#     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-#     for criterion, cuts in batches.items():
-#         batch = train_dl.dataset[cuts]
-#         try:
-#             with torch.cuda.amp.autocast(enabled=params.use_fp16):
-#                 loss, _ = compute_loss(
-#                     params=params,
-#                     tokenizer=tokenizer,
-#                     model=model,
-#                     batch=batch,
-#                     is_training=True,
-#                 )
-#             loss.backward()
-#             optimizer.zero_grad()
-#         except Exception as e:
-#             if "CUDA out of memory" in str(e):
-#                 logging.error(
-#                     "Your GPU ran out of memory with the current "
-#                     "max_duration setting. We recommend decreasing "
-#                     "max_duration and trying again.\n"
-#                     f"Failing criterion: {criterion} "
-#                     f"(={crit_values[criterion]}) ..."
-#                 )
-#                 display_and_save_batch(batch, params=params)
-#                 raise
-#         logging.info(
-#             f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-#         )
-
 def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
@@ -1038,13 +955,10 @@ def main():
     world_size = get_world_size()
     rank = get_rank()

-    assert world_size >= 1
-
-    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
     run(rank=rank, world_size=world_size, args=args)


+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)

 if __name__ == "__main__":
     main()