This commit is contained in: commit 4cb463247e (parent 774c318931)
Binary file not shown.
@@ -1663,6 +1663,186 @@ def run_adapter(rank, world_size, args, wb=None):
        cleanup_dist()


def run_adapter_uda(rank, world_size, args, wb=None):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args().
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    num_param = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    # Collect the adapter parameters for the adapter-only optimizer; keep the
    # joiner, simple joiner, and CTC heads trainable and freeze everything else.
    adapter_names = []
    adapter_param = []
    for n, p in model.named_parameters():
        if "adapters" in n:  # or 'joiner' in n or 'simple' in n or 'ctc' in n:
            adapter_names.append(n)
            adapter_param.append(p)
        elif "joiner" in n or "simple" in n or "ctc" in n:
            p.requires_grad = True
        else:
            p.requires_grad = False

    optimizer_adapter = ScaledAdam(
        adapter_param,
        lr=params.adapter_lr,
        clipping_scale=5.0,
        parameters_names=[adapter_names],
    )
    scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  # hard-coded instead of params.lr_batches, params.lr_epochs

    optimizer, scheduler = optimizer_adapter, scheduler_adapter

    librispeech = LibriSpeechAsrDataModule(args)

    '''
    if params.hpo:
        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
    else:
        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
        if params.full_libri:
            train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
            train_cuts += librispeech.train_other_500_cuts(option=params.gender)
    '''

    # train_cuts = librispeech.train_clean_10_cuts(option='male')
    # train_cuts = librispeech.test_clean_user(option='big')
    train_cuts = librispeech.vox_cuts(option=params.spk_id)

    def remove_short_and_long_utt(c: Cut):
        return 1.0 <= c.duration <= 20.0

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    sampler_state_dict = None

    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
    # train_dl = librispeech.test_dataloaders(train_cuts)

    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
    valid_dl = librispeech.valid_dataloaders(valid_cuts)

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
            wb=wb,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        if epoch % 10 == 0:
            save_checkpoint(
                params=params,
                model=model,
                model_avg=model_avg,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
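
# --- Hypothetical launcher sketch, not part of this commit ---
# It only illustrates how an entry point like run_adapter_uda is usually
# driven from a recipe's main(): one process per GPU via mp.spawn(), which
# passes the rank (0 .. world_size-1) as the first positional argument, as
# the docstring above describes. The function name and the use of
# args.world_size below are assumptions, not code from this diff.
import torch.multiprocessing as mp


def launch_adapter_uda_sketch():
    args = get_parser().parse_args()
    world_size = args.world_size
    if world_size > 1:
        # Spawn one training process per GPU; each receives its rank first.
        mp.spawn(run_adapter_uda, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run_adapter_uda(rank=0, world_size=1, args=args)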


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,