from local

dohe0342 2023-03-23 23:00:56 +09:00
parent 774c318931
commit 4cb463247e
2 changed files with 180 additions and 0 deletions


@@ -1663,6 +1663,186 @@ def run_adapter(rank, world_size, args, wb=None):
        cleanup_dist()


def run_adapter_uda(rank, world_size, args, wb=None):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

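    # Only parameters with requires_grad=True contribute to this count.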
    num_param = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

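    # Parameter selection for adapter tuning: collect the adapter parameters for a
    # dedicated optimizer, leave the joiner/simple/ctc branches with
    # requires_grad=True, and freeze all remaining backbone parameters.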
    adapter_names = []
    adapter_param = []
    for n, p in model.named_parameters():
        if 'adapters' in n:  # or 'joiner' in n or 'simple' in n or 'ctc' in n:
            adapter_names.append(n)
            adapter_param.append(p)
        elif 'joiner' in n or 'simple' in n or 'ctc' in n:
            p.requires_grad = True
        else:
            p.requires_grad = False

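    # Only the collected adapter parameters are handed to the optimizer below, so
    # the joiner/simple/ctc branches keep requires_grad=True but are not updated here.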
    optimizer_adapter = ScaledAdam(
        adapter_param,
        lr=params.adapter_lr,
        clipping_scale=5.0,
        parameters_names=[adapter_names],
    )
    scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  # hard-coded instead of params.lr_batches, params.lr_epochs
    optimizer, scheduler = optimizer_adapter, scheduler_adapter

    librispeech = LibriSpeechAsrDataModule(args)

    '''
    if params.hpo:
        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
    else:
        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)

    if params.full_libri:
        train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
        train_cuts += librispeech.train_other_500_cuts(option=params.gender)
    '''
    # train_cuts = librispeech.train_clean_10_cuts(option='male')
    # train_cuts = librispeech.test_clean_user(option='big')
    train_cuts = librispeech.vox_cuts(option=params.spk_id)
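
    # Keep only utterances between 1.0 and 20.0 seconds for adaptation.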
    def remove_short_and_long_utt(c: Cut):
        return 1.0 <= c.duration <= 20.0

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    sampler_state_dict = None
    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
    # train_dl = librispeech.test_dataloaders(
    #     train_cuts
    # )

    '''
    print('\n'*5)
    print('-'*30)
    for batch in train_dl:
        print(batch)
    print('-'*30)
    print('\n'*5)
    exit()
    '''
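
    # Validation cuts come from dev-clean and dev-other, selected with
    # option=params.gender.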
    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
    valid_dl = librispeech.valid_dataloaders(valid_cuts)
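
    # The gradient scaler only takes effect when params.use_fp16 is enabled.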
    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)

    for epoch in range(params.start_epoch, params.num_epochs + 1):
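        # Advance the epoch-level LR schedule and re-seed so shuffling differs per epoch.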
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
            wb=wb,
        )

        if params.print_diagnostics:
            # Note: `diagnostic` is expected to be attached elsewhere; it is not
            # created in this function.
            diagnostic.print_diagnostics()
            break

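        # Write a checkpoint every 10 epochs.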
        if epoch % 10 == 0:
            save_checkpoint(
                params=params,
                model=model,
                model_avg=model_avg,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,