from local

dohe0342 2023-06-09 15:24:03 +09:00
parent 6109080f22
commit 4929de22dc
4 changed files with 246 additions and 51 deletions

Binary file not shown.


@@ -1274,6 +1274,252 @@ def run(rank, world_size, args, wb=None):
"""
params = get_params()
params.update(vars(args))
#params.warm_step *= params.accum_grads
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
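    # blank_id / vocab_size come from the BPE model and are used by
    # get_transducer_model() below to size the decoder/joiner outputs,
    # so the same sentencepiece model must be used at decoding time.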
    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)
    logging.info(model)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period

    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
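    # The branch below builds either a single ScaledAdam over the whole model
    # or, with multi_optim, two optimizers: parameters under the encoder
    # (plus ctc_output) go to optimizer_enc with peak_enc_lr and no gradient
    # clipping, while the remaining parameters (decoder, joiner, ...) go to
    # optimizer_dec with peak_dec_lr and clipping_scale=5.0. Parameters whose
    # names contain 'feature_extractor' are excluded from both groups.
    # Note: n.split('.')[1] below appears to rely on the 'module.' prefix that
    # DDP adds to parameter names, i.e. it assumes world_size > 1.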
    if params.multi_optim:
        logging.info("Using separate optimizers over encoder, decoder ...")
        enc_param = []
        enc_names = []
        dec_names = []
        dec_param = []
        for n, p in model.named_parameters():
            name = n.split('.')[1]
            if name == 'encoder' and 'feature_extractor' not in n:
                enc_names.append(n)
                enc_param.append(p)
            elif 'ctc_output' in n:
                enc_names.append(n)
                enc_param.append(p)
            elif 'feature_extractor' not in n:
                dec_names.append(n)
                dec_param.append(p)

        optimizer_enc = ScaledAdam(
            enc_param,
            lr=params.peak_enc_lr,
            clipping_scale=None,
            parameters_names=[enc_names],
        )
        optimizer_dec = ScaledAdam(
            dec_param,
            lr=params.peak_dec_lr,
            clipping_scale=5.0,
            parameters_names=[dec_names],
        )
        scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
        scheduler_dec = Eden(optimizer_dec, params.lr_batches, params.lr_epochs)

        optimizer = [optimizer_enc, optimizer_dec]
        scheduler = [scheduler_enc, scheduler_dec]
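        # NOTE: with multi_optim enabled, `optimizer` and `scheduler` are
        # lists, so everything downstream (checkpointing, train_one_epoch)
        # has to handle both the single-object and the list case.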
    else:
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in model.named_parameters()]
        )
        logging.info(f"len name = {len(parameters_names)}")
        logging.info(f"len param = {len(list(model.parameters()))}")
        optimizer = ScaledAdam(
            model.parameters(),
            lr=params.base_lr,
            clipping_scale=2.0,
            parameters_names=parameters_names,
        )
        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
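    # Restore optimizer/scheduler state from the checkpoint, matching
    # whichever layout (a single optimizer or the enc/dec pair) was saved.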
    if checkpoints and ("optimizer" in checkpoints or "optimizer_enc" in checkpoints):
        if params.multi_optim:
            logging.info("Loading optimizer state dict")
            optimizer_enc.load_state_dict(checkpoints["optimizer_enc"])
            optimizer_dec.load_state_dict(checkpoints["optimizer_dec"])
        else:
            logging.info("Loading optimizer state dict")
            optimizer.load_state_dict(checkpoints["optimizer"])

    if checkpoints:
        if (
            params.multi_optim
            and "scheduler_enc" in checkpoints
            and checkpoints["scheduler_enc"] is not None
        ):
            logging.info("Loading enc/dec scheduler state dict")
            scheduler_enc.load_state_dict(checkpoints["scheduler_enc"])
            scheduler_dec.load_state_dict(checkpoints["scheduler_dec"])
        else:
            logging.info("Loading scheduler state dict")
            scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)
    librispeech = LibriSpeechAsrDataModule(args)

    train_cuts = librispeech.train_clean_100_cuts()
    if params.full_libri:
        train_cuts += librispeech.train_clean_360_cuts()
        train_cuts += librispeech.train_other_500_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        #
        # Caution: There is a reason to select 20.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        return 1.0 <= c.duration <= 20.0

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = librispeech.dev_clean_cuts()
    valid_cuts += librispeech.dev_other_cuts()
    valid_dl = librispeech.valid_dataloaders(valid_cuts)

    '''
    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            sp=sp,
            params=params,
        )
    '''
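    # The pessimistic-batch OOM scan above is commented out in this recipe.
    # GradScaler(init_scale=1.0) starts automatic mixed-precision loss scaling
    # at 1 instead of PyTorch's default of 65536; it only has an effect when
    # use_fp16 is set.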
    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
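        # Advance the Eden schedule(s) by one finished epoch and re-seed so
        # that shuffling differs from epoch to epoch but stays reproducible
        # when training is resumed.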
        if params.multi_optim:
            scheduler_enc.step_epoch(epoch - 1)
            scheduler_dec.step_epoch(epoch - 1)
        else:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
            wb=wb,
        )
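        # With print_diagnostics the loop stops after a single epoch so the
        # collected tensor statistics can be printed and inspected.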
        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def run_adapter(rank, world_size, args, wb=None):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
@@ -1349,34 +1595,10 @@ def run(rank, world_size, args, wb=None):
        parameters_names=[adapter_names],
    )
    # for n, p in model.named_parameters():
    #     p.requires_grad = False
    # prompt = torch.randn((100, 512), requires_grad=True)
    # optimizer_adapter = ScaledAdam(
    #     [model.prompt],
    #     lr=params.adapter_lr,
    #     clipping_scale=5.0,
    #     parameters_names=['P'],
    # )
    scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  # instead of params.lr_batches, params.lr_epochs
    optimizer, scheduler = optimizer_adapter, scheduler_adapter
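    # optimizer_adapter appears to cover only the adapter parameters collected
    # in `adapter_names` above, leaving the rest of the pre-trained model
    # untouched by this optimizer. The Eden scheduler uses hard-coded values
    # (lr_batches=10000, lr_epochs=7) instead of the command-line options.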
    librispeech = LibriSpeechAsrDataModule(args)

    '''
    if params.hpo:
        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
    else:
        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
        if params.full_libri:
            train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
            train_cuts += librispeech.train_other_500_cuts(option=params.gender)
    '''
    # train_cuts = librispeech.train_clean_10_cuts(option='male')
    # train_cuts = librispeech.test_clean_user(option='big')
    train_cuts = librispeech.vox_cuts(option=params.spk_id)
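    # Adapter training draws its cuts from the speaker-specific vox_cuts()
    # manifest selected by params.spk_id; the commented-out LibriSpeech
    # subsets above are the previously used alternatives.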
    def remove_short_and_long_utt(c: Cut):
@@ -1389,19 +1611,6 @@ def run(rank, world_size, args, wb=None):
    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
    # train_dl = librispeech.test_dataloaders(
    #     train_cuts
    # )
    '''
    print('\n'*5)
    print('-'*30)
    for batch in train_dl:
        print(batch)
        print('-'*30)
        print('\n'*5)
        exit()
    '''

    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
@@ -1440,20 +1649,6 @@ def run(rank, world_size, args, wb=None):
            diagnostic.print_diagnostics()
            break

        '''
        if epoch % 10 == 0:
            save_checkpoint(
                params=params,
                model=model,
                model_avg=model_avg,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
        '''
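        # Periodic checkpoint saving every 10 epochs is currently disabled:
        # the save_checkpoint block above sits inside a triple-quoted string.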
logging.info("Done!")
if world_size > 1: