mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00 (from local)
This commit is contained in:
parent 6109080f22
commit 4929de22dc
@@ -1274,6 +1274,252 @@ def run(rank, world_size, args, wb=None):
    """
    params = get_params()
    params.update(vars(args))
    # params.warm_step *= params.accum_grads

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None
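
A minimal sketch of what a setup_dist-style helper typically does is shown below; this assumes the standard NCCL torch.distributed initialization and is not the exact icefall.dist.setup_dist implementation.

# Hypothetical sketch of a setup_dist-style helper (assumes the NCCL backend);
# the real icefall helper may differ in details.
import os
import torch
import torch.distributed as dist

def setup_dist_sketch(rank: int, world_size: int, master_port: int = 12354) -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", str(master_port))
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
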
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()
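
As a quick illustration of how the blank id and vocabulary size come out of the BPE model, here is a small stand-alone sketch; the bpe.model path is a placeholder.

# Stand-alone sketch: querying a SentencePiece model for the blank id and
# vocabulary size, mirroring the lines above. The model path is a placeholder.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
blank_id = sp.piece_to_id("<blk>")   # id of the blank symbol
vocab_size = sp.get_piece_size()     # number of BPE pieces, including <blk>
print(blank_id, vocab_size)
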
    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)
    logging.info(model)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
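
The float64 copy above is kept on rank 0 for checkpoint averaging; a hedged sketch of the kind of running update such an averaged copy typically receives is shown below. The update rule and function name are illustrative, not the exact icefall code.

# Illustrative only: one way an averaged copy of the model can be updated
# in place; the exact rule used by icefall may differ.
import torch
import torch.nn as nn

def update_average_sketch(model: nn.Module, model_avg: nn.Module, weight: float) -> None:
    with torch.no_grad():
        for p, p_avg in zip(model.parameters(), model_avg.parameters()):
            # model_avg is kept in float64, so cast the current weights up
            p_avg.mul_(1.0 - weight).add_(p.detach().to(torch.float64), alpha=weight)
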
    if params.multi_optim:
        logging.info("Using separate optimizers over encoder, decoder ...")

        enc_param = []
        enc_names = []

        dec_names = []
        dec_param = []

        for n, p in model.named_parameters():
            name = n.split('.')[1]
            if name == 'encoder' and 'feature_extractor' not in n:
                enc_names.append(n)
                enc_param.append(p)
            elif 'ctc_output' in n:
                enc_names.append(n)
                enc_param.append(p)
            elif 'feature_extractor' not in n:
                dec_names.append(n)
                dec_param.append(p)

        optimizer_enc = ScaledAdam(
            enc_param,
            lr=params.peak_enc_lr,
            clipping_scale=None,
            parameters_names=[enc_names],
        )
        optimizer_dec = ScaledAdam(
            dec_param,
            lr=params.peak_dec_lr,
            clipping_scale=5.0,
            parameters_names=[dec_names],
        )

        scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
        scheduler_dec = Eden(optimizer_dec, params.lr_batches, params.lr_epochs)
        optimizer = [optimizer_enc, optimizer_dec]
        scheduler = [scheduler_enc, scheduler_dec]

    else:
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in model.named_parameters()]
        )

        logging.info(f"len name = {len(parameters_names)}")
        logging.info(f"len param = {len(list(model.parameters()))}")

        optimizer = ScaledAdam(
            model.parameters(),
            lr=params.base_lr,
            clipping_scale=2.0,
            parameters_names=parameters_names,
        )

        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
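
Because optimizer and scheduler can now be either single objects or [encoder, decoder] lists, the training loop has to handle both shapes. A hedged sketch of that per-batch handling follows; it is illustrative and not the recipe's actual train_one_epoch code.

# Illustrative sketch of a training step that accepts either one optimizer or
# the [encoder, decoder] pair built above.
def step_optimizers(optimizer, scheduler, batch_idx_train: int) -> None:
    optimizers = optimizer if isinstance(optimizer, list) else [optimizer]
    schedulers = scheduler if isinstance(scheduler, list) else [scheduler]
    for opt, sch in zip(optimizers, schedulers):
        sch.step_batch(batch_idx_train)  # Eden is also stepped per batch
        opt.step()
        opt.zero_grad()
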
    if checkpoints and ("optimizer" in checkpoints or "optimizer_enc" in checkpoints):
        if params.multi_optim:
            logging.info("Loading optimizer state dict")
            optimizer_enc.load_state_dict(checkpoints["optimizer_enc"])
            optimizer_dec.load_state_dict(checkpoints["optimizer_dec"])

        else:
            logging.info("Loading optimizer state dict")
            optimizer.load_state_dict(checkpoints["optimizer"])

    if checkpoints:
        if (
            params.multi_optim
            and "scheduler_enc" in checkpoints
            and checkpoints["scheduler_enc"] is not None
        ):
            logging.info("Loading enc/dec scheduler state dict")
            scheduler_enc.load_state_dict(checkpoints["scheduler_enc"])
            scheduler_dec.load_state_dict(checkpoints["scheduler_dec"])
        else:
            logging.info("Loading scheduler state dict")
            scheduler.load_state_dict(checkpoints["scheduler"])
    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)
    librispeech = LibriSpeechAsrDataModule(args)

    train_cuts = librispeech.train_clean_100_cuts()
    if params.full_libri:
        train_cuts += librispeech.train_clean_360_cuts()
        train_cuts += librispeech.train_other_500_cuts()
    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        #
        # Caution: There is a reason to select 20.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        return 1.0 <= c.duration <= 20.0

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
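
If you need to pick duration thresholds for a different dataset, a quick way to look at the distribution with lhotse is sketched below; the manifest path is a placeholder.

# Hedged sketch: inspecting the utterance-duration distribution of a lhotse
# CutSet before choosing filter thresholds. The manifest path is a placeholder.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
cuts.describe()  # prints counts, total duration, and duration percentiles
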
    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = librispeech.dev_clean_cuts()
    valid_cuts += librispeech.dev_other_cuts()
    valid_dl = librispeech.valid_dataloaders(valid_cuts)
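
The sampler state dict is what makes mid-epoch resumption possible. As context, a hedged sketch of saving and restoring a lhotse sampler follows; DynamicBucketingSampler is an assumption, since train_dataloaders may build a different sampler internally.

# Context sketch: lhotse samplers expose state_dict()/load_state_dict().
from lhotse.dataset import DynamicBucketingSampler

sampler = DynamicBucketingSampler(train_cuts, max_duration=200.0, shuffle=True)
state = sampler.state_dict()      # capture progress within the epoch
# ... later, on resume ...
sampler.load_state_dict(state)    # continue from the saved position
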
    '''
    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            sp=sp,
            params=params,
        )
    '''

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])
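
For context, the usual pattern for using a GradScaler with mixed precision inside a batch loop looks like the generic PyTorch sketch below; it is not the recipe's exact train_one_epoch code.

# Generic PyTorch AMP pattern with GradScaler, shown only as context for the
# scaler created above.
import torch

def amp_step(model, batch, optimizer, scaler, use_fp16: bool):
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = model(batch)          # assumes the model returns a scalar loss
    optimizer.zero_grad()
    scaler.scale(loss).backward()    # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)           # unscales gradients, then steps
    scaler.update()                  # adjusts the scale for the next step
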
    for epoch in range(params.start_epoch, params.num_epochs + 1):
        if params.multi_optim:
            scheduler_enc.step_epoch(epoch - 1)
            scheduler_dec.step_epoch(epoch - 1)
        else:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
            wb=wb,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
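
Both run() and the run_adapter() defined below take rank as their first argument because they are launched once per GPU. A hedged sketch of the usual mp.spawn launch follows; the argument handling is illustrative, not the recipe's exact main().

# Hedged sketch of how a per-GPU entry point such as run() is typically spawned.
import torch.multiprocessing as mp

def main_sketch(args):
    world_size = args.world_size
    if world_size > 1:
        # One process per GPU; each receives its rank as the first argument.
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)
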

def run_adapter(rank, world_size, args, wb=None):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:

@@ -1349,34 +1595,10 @@ def run(rank, world_size, args, wb=None):
        parameters_names=[adapter_names],
    )

    # for n, p in model.named_parameters():
    #     p.requires_grad = False

    # prompt = torch.randn((100, 512), requires_grad=True)
    # optimizer_adapter = ScaledAdam(
    #     [model.prompt],
    #     lr=params.adapter_lr,
    #     clipping_scale=5.0,
    #     parameters_names=['P'],
    # )

    scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  # params.lr_batches, params.lr_epochs
    optimizer, scheduler = optimizer_adapter, scheduler_adapter
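
Adapter-style fine-tuning normally trains only the adapter parameters while the base network stays frozen, as the commented-out requires_grad lines above suggest. A hedged sketch of that selection follows; the "adapter" naming convention is an assumption and may not match this recipe exactly.

# Hedged sketch: collect only adapter parameters for the optimizer and freeze
# everything else. Assumes adapter modules carry "adapter" in their names.
adapter_names, adapter_params = [], []
for name, p in model.named_parameters():
    if "adapter" in name:
        adapter_names.append(name)
        adapter_params.append(p)
    else:
        p.requires_grad = False  # keep the pretrained weights fixed
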
    librispeech = LibriSpeechAsrDataModule(args)

    '''
    if params.hpo:
        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
    else:
        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
    if params.full_libri:
        train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
        train_cuts += librispeech.train_other_500_cuts(option=params.gender)
    '''

    # train_cuts = librispeech.train_clean_10_cuts(option='male')
    # train_cuts = librispeech.test_clean_user(option='big')
    train_cuts = librispeech.vox_cuts(option=params.spk_id)

    def remove_short_and_long_utt(c: Cut):

@@ -1389,19 +1611,6 @@ def run(rank, world_size, args, wb=None):
    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
    # train_dl = librispeech.test_dataloaders(
    #     train_cuts
    # )

    '''
    print('\n'*5)
    print('-'*30)
    for batch in train_dl:
        print(batch)
        print('-'*30)
        print('\n'*5)
        exit()
    '''

    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
@@ -1440,20 +1649,6 @@ def run(rank, world_size, args, wb=None):
            diagnostic.print_diagnostics()
            break

        '''
        if epoch % 10 == 0:
            save_checkpoint(
                params=params,
                model=model,
                model_avg=model_avg,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
        '''

    logging.info("Done!")

    if world_size > 1: