from local

dohe0342 2023-04-11 15:21:08 +09:00
parent 75e857b5d5
commit 03d1d62967
2 changed files with 143 additions and 62 deletions

@@ -153,6 +153,13 @@ def add_adapter_arguments(parser: argparse.ArgumentParser):
help="adapter learning rate"
)
parser.add_argument(
"--gender",
type=str,
default='male',
help="select gender"
)
def add_rep_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
@@ -161,6 +168,13 @@ def add_rep_arguments(parser: argparse.ArgumentParser):
default=True,
help="Use wandb for MLOps",
)
parser.add_argument(
"--hpo",
type=str2bool,
default=False,
help="Use small db for HPO",
)
parser.add_argument(
"--accum-grads",
type=int,
@@ -286,14 +300,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--decoder-dim",
type=int,
default=512,
default=768,
help="Embedding dimension in the decoder model.",
)
parser.add_argument(
"--joiner-dim",
type=int,
default=512,
default=768,
help="""Dimension used in the joiner model.
Outputs from the encoder and decoder model are projected
to this dimension before adding.
@@ -333,6 +347,13 @@ def get_parser():
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--num-updates",
type=int,
default=5000,
help="Number of epochs to train.",
)
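# note: the epoch loop still runs for --num-epochs; train_one_epoch additionally breaks out
# once batch_idx_train exceeds --num-updates, so whichever limit is reached first ends training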
parser.add_argument(
"--start-epoch",
@@ -461,7 +482,7 @@ def get_parser():
parser.add_argument(
"--save-every-n",
type=int,
default=2000,
default=200,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -485,7 +506,7 @@ def get_parser():
parser.add_argument(
"--average-period",
type=int,
default=200,
default=10,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
@@ -561,7 +582,7 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"log_interval": 20,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for zipformer
@@ -570,7 +591,8 @@ def get_params() -> AttributeDict:
# parameters for ctc loss
"beam_size": 10,
"use_double_scores": True,
"warm_step": 4000,
"warm_step": 0,
#"warm_step": 4000,
#"warm_step": 3000,
"env_info": get_env_info(),
}
@@ -685,7 +707,7 @@ def load_checkpoint_if_available(
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
elif params.add_adapter:
filename = params.exp_dir / f"d2v-base-T.pt"
filename = params.exp_dir / f"../d2v-base-T.pt"
else:
return None
@@ -717,6 +739,8 @@ def load_checkpoint_if_available(
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
params.batch_idx_train = 0
return saved_params
@@ -818,6 +842,7 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
token_ids = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(token_ids).to(device)
@@ -888,12 +913,21 @@ def compute_loss(
if decode:
model.eval()
with torch.no_grad():
hypos = model.module.decode(
x=feature,
x_lens=feature_lens,
y=y,
sp=sp
)
# under DDP the underlying model is reached via model.module; fall back to the bare model otherwise
try:
hypos = model.module.decode(
x=feature,
x_lens=feature_lens,
y=y,
sp=sp
)
except AttributeError:
# model is not wrapped in DistributedDataParallel
hypos = model.decode(
x=feature,
x_lens=feature_lens,
y=y,
sp=sp
)
logging.info(f'ref: {batch["supervisions"]["text"][0]}')
logging.info(f'hyp: {" ".join(hypos[0])}')
model.train()
@@ -1002,6 +1036,8 @@ def train_one_epoch(
scheduler_enc, scheduler_dec = scheduler[0], scheduler[1]
for batch_idx, batch in enumerate(train_dl):
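# leave the epoch early once the global update budget (--num-updates) has been spent;
# subsequent epochs will break immediately as well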
if params.batch_idx_train > params.num_updates:
break
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
@@ -1019,7 +1055,9 @@ def train_one_epoch(
is_training=True,
decode = True if batch_idx % params.decode_interval == 0 else False,
)
loss_info.reduce(loss.device)
# reduce() aggregates the loss metrics across ranks; skip it if that is not possible (e.g. single-process runs)
try: loss_info.reduce(loss.device)
except Exception: pass
numel = params.world_size / (params.accum_grads * loss_info["utterances"])
loss *= numel  ## normalize loss over utts (batch size)
@@ -1053,7 +1091,8 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5:
return
'''
if (
rank == 0
and params.batch_idx_train > 0
@@ -1064,6 +1103,7 @@ def train_one_epoch(
model_cur=model,
model_avg=model_avg,
)
'''
if (
params.batch_idx_train > 0
@@ -1083,11 +1123,13 @@ def train_one_epoch(
rank=rank,
)
del params.cur_batch_idx
'''
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
'''
if batch_idx % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
@@ -1106,13 +1148,16 @@ def train_one_epoch(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
wb.log({"valid/loss": 10000})
raise RuntimeError(
f"divergence... exiting: loss={loss}"
)
#if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
# wb.log({"valid/loss": 10000})
# raise RuntimeError(
# f"divergence... exiting: loss={loss}"
# )
if batch_idx % (params.log_interval*params.accum_grads) == 0:
#for n, p in model.named_parameters():
# if 'adapter' in n:
# print(p)
if params.multi_optim:
cur_enc_lr = scheduler_enc.get_last_lr()[0]
cur_dec_lr = scheduler_dec.get_last_lr()[0]
@@ -1169,7 +1214,8 @@ def train_one_epoch(
wb.log({"train/simple_loss": loss_info["simple_loss"]*numel})
wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel})
wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel})
'''
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@@ -1190,11 +1236,15 @@ def train_one_epoch(
if wb is not None and rank == 0:
numel = 1 / (params.accum_grads * valid_info["utterances"])
wb.log({"valid/loss": valid_info["loss"]*numel})
#wb.log({"valid/loss": valid_info["loss"]*numel})
wb.log({"valid/loss": numel*(valid_info["simple_loss"]
+valid_info["pruned_loss"]
+valid_info["ctc_loss"]
)})
wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel})
wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel})
wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel})
'''
loss_value = tot_loss["loss"] / tot_loss["utterances"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
@@ -1247,7 +1297,6 @@ def run(rank, world_size, args, wb=None):
logging.info("About to create model")
model = get_transducer_model(params)
logging.info(model)
exit()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -1441,17 +1490,20 @@ def run(rank, world_size, args, wb=None):
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
'''
if epoch % 50 == 0:
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
'''
logging.info("Done!")
@@ -1526,46 +1578,72 @@ def run_adapter(rank, world_size, args, wb=None):
adapter_names = []
adapter_param = []
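# collect only the adapter weights for the dedicated optimizer; the joiner/simple/ctc heads
# stay trainable while the rest of the backbone is frozen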
for n, p in model.named_parameters():
if 'adapters' in n:
if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n:
adapter_names.append(n)
adapter_param.append(p)
#else:
# p.requires_grad = False
elif 'joiner' in n or 'simple' in n or 'ctc' in n:
p.requires_grad = True
else:
p.requires_grad = False
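# the new optimizer only receives the collected adapter parameters; parameters_names is the
# matching list of names (one sub-list per parameter group) that ScaledAdam uses for its
# per-parameter diagnostics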
optimizer_adapter = ScaledAdam(
adapter_param,
lr=params.adapter_lr,
clipping_scale=5.0,
parameters_names=[adapter_names],
)
scheduler_adapter = Eden(optimizer_adapter, 5000, 3.5)  # was: params.lr_batches, params.lr_epochs
scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  # was: params.lr_batches, params.lr_epochs
optimizer, scheduler = optimizer_adapter, scheduler_adapter
librispeech = LibriSpeechAsrDataModule(args)
'''
if params.hpo:
train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
else:
train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
train_cuts += librispeech.train_other_500_cuts(option=params.gender)
'''
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
#train_cuts = librispeech.train_clean_10_cuts(option='male')
#train_cuts = librispeech.test_clean_user(option='big')
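# the cuts above are overridden: adapter training uses cuts selected by --spk-id, presumably
# per-speaker adaptation data; vox_cuts looks like a custom hook in this fork's LibriSpeechAsrDataModule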
train_cuts = librispeech.vox_cuts(option=params.spk_id)
def remove_short_and_long_utt(c: Cut):
return 1.0 <= c.duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
#train_dl = librispeech.test_dataloaders(
# train_cuts
#)
'''
print('\n'*5)
print('-'*30)
for batch in train_dl:
print(batch)
print('-'*30)
print('\n'*5)
exit()
'''
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
valid_cuts += librispeech.dev_other_cuts(option=params.gender)
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"update num : {params.batch_idx_train}")
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
@@ -1594,17 +1672,20 @@ def run_adapter(rank, world_size, args, wb=None):
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
'''
if epoch % 10 == 0:
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
'''
logging.info("Done!")
@@ -1691,13 +1772,13 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
#args.exp_dir = args.exp_dir + str(random.randint(0,400))
if args.wandb: args.exp_dir = args.exp_dir + str(random.randint(0,400))
args.exp_dir = Path(args.exp_dir)
logging.info("save arguments to config.yaml...")
save_args(args)
if args.wandb: wb = wandb.init(project="d2v-T", entity="dohe0342", config=vars(args))
if args.wandb: wb = wandb.init(project="d2v-adapter", entity="dohe0342", config=vars(args))
else: wb = None
world_size = args.world_size
@@ -1709,7 +1790,7 @@ def main():
join=True
)
else:
if not args.add_adapter: run(rank=0, world_size=1, args=args, wb=wb)
if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb)
else: run(rank=0, world_size=1, args=args, wb=wb)
torch.set_num_threads(1)