from local

dohe0342 2023-04-11 15:21:08 +09:00
parent 75e857b5d5
commit 03d1d62967
2 changed files with 143 additions and 62 deletions


@@ -153,6 +153,13 @@ def add_adapter_arguments(parser: argparse.ArgumentParser):
         help="adapter learning rate"
     )
+    parser.add_argument(
+        "--gender",
+        type=str,
+        default='male',
+        help="select gender"
+    )
 
 
 def add_rep_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
@@ -161,6 +168,13 @@ def add_rep_arguments(parser: argparse.ArgumentParser):
         default=True,
         help="Use wandb for MLOps",
     )
+    parser.add_argument(
+        "--hpo",
+        type=str2bool,
+        default=False,
+        help="Use small db for HPO",
+    )
     parser.add_argument(
         "--accum-grads",
         type=int,
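The two new flags plug into the existing argparse setup. A minimal, self-contained sketch of how they parse; the str2bool stand-in below only mirrors the helper the script already imports and is an assumption, not code from this commit:

import argparse

def str2bool(v: str) -> bool:
    # Stand-in for the script's str2bool helper (assumed semantics).
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--gender", type=str, default="male", help="select gender")
parser.add_argument("--hpo", type=str2bool, default=False, help="Use small db for HPO")

args = parser.parse_args(["--gender", "female", "--hpo", "true"])
print(args.gender, args.hpo)  # -> female True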
@@ -286,14 +300,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--decoder-dim",
         type=int,
-        default=512,
+        default=768,
         help="Embedding dimension in the decoder model.",
     )
     parser.add_argument(
         "--joiner-dim",
         type=int,
-        default=512,
+        default=768,
         help="""Dimension used in the joiner model.
         Outputs from the encoder and decoder model are projected
         to this dimension before adding.
@@ -333,6 +347,13 @@ def get_parser():
         default=30,
         help="Number of epochs to train.",
     )
+    parser.add_argument(
+        "--num-updates",
+        type=int,
+        default=5000,
+        help="Number of updates (optimizer steps) to train.",
+    )
 
     parser.add_argument(
         "--start-epoch",
@@ -461,7 +482,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=2000,
+        default=200,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -485,7 +506,7 @@ def get_parser():
     parser.add_argument(
         "--average-period",
         type=int,
-        default=200,
+        default=10,
         help="""Update the averaged model, namely `model_avg`, after processing
         this number of batches. `model_avg` is a separate version of model,
         in which each floating-point parameter is the average of all the
@@ -561,7 +582,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 20,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
@@ -570,7 +591,8 @@ def get_params() -> AttributeDict:
             # parameters for ctc loss
             "beam_size": 10,
             "use_double_scores": True,
-            "warm_step": 4000,
+            "warm_step": 0,
+            #"warm_step": 4000,
             #"warm_step": 3000,
             "env_info": get_env_info(),
         }
@@ -685,7 +707,7 @@ def load_checkpoint_if_available(
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     elif params.add_adapter:
-        filename = params.exp_dir / f"d2v-base-T.pt"
+        filename = params.exp_dir / f"../d2v-base-T.pt"
     else:
         return None
@@ -717,6 +739,8 @@ def load_checkpoint_if_available(
     if "cur_batch_idx" in saved_params:
         params["cur_batch_idx"] = saved_params["cur_batch_idx"]
 
+    params.batch_idx_train = 0
+
     return saved_params
@@ -818,6 +842,7 @@ def compute_loss(
     warm_step = params.warm_step
 
     texts = batch["supervisions"]["text"]
     token_ids = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(token_ids).to(device)
@@ -888,12 +913,21 @@ def compute_loss(
     if decode:
         model.eval()
         with torch.no_grad():
-            hypos = model.module.decode(
-                x=feature,
-                x_lens=feature_lens,
-                y=y,
-                sp=sp
-            )
+            try:
+                hypos = model.module.decode(
+                    x=feature,
+                    x_lens=feature_lens,
+                    y=y,
+                    sp=sp
+                )
+            except:
+                hypos = model.decode(
+                    x=feature,
+                    x_lens=feature_lens,
+                    y=y,
+                    sp=sp
+                )
         logging.info(f'ref: {batch["supervisions"]["text"][0]}')
         logging.info(f'hyp: {" ".join(hypos[0])}')
         model.train()
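The try/except above falls back to the bare model when the script runs without DistributedDataParallel. A narrower way to express the same fallback, sketched under the assumption that decode takes exactly the keyword arguments shown in the hunk:

import torch

def run_decode(model, feature, feature_lens, y, sp):
    # DistributedDataParallel exposes the wrapped network as model.module;
    # a plain nn.Module has no such attribute, so getattr falls back to model itself.
    net = getattr(model, "module", model)
    with torch.no_grad():
        return net.decode(x=feature, x_lens=feature_lens, y=y, sp=sp)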
@@ -1002,6 +1036,8 @@ def train_one_epoch(
         scheduler_enc, scheduler_dec = scheduler[0], scheduler[1]
 
     for batch_idx, batch in enumerate(train_dl):
+        if params.batch_idx_train > params.num_updates:
+            break
         if batch_idx < cur_batch_idx:
             continue
         cur_batch_idx = batch_idx
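The two added lines cap training by optimizer updates rather than epochs: once params.batch_idx_train exceeds --num-updates, the batch loop stops. A toy illustration of the pattern (all names besides num_updates and batch_idx_train are illustrative):

num_updates = 5000                      # value of --num-updates
batch_idx_train = 0

for epoch in range(1, 31):              # stand-in for the epoch loop
    for batch in range(10_000):         # stand-in for train_dl
        if batch_idx_train > num_updates:
            break                       # stop this epoch's batch loop
        batch_idx_train += 1            # incremented once per processed batch

print(batch_idx_train)  # 5001: later epochs exit immediately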
@@ -1019,7 +1055,9 @@ def train_one_epoch(
                 is_training=True,
                 decode = True if batch_idx % params.decode_interval == 0 else False,
             )
-            loss_info.reduce(loss.device)
+            try: loss_info.reduce(loss.device)
+            except: pass
 
             numel = params.world_size / (params.accum_grads * loss_info["utterances"])
             loss *= numel  ## normalize loss over utts(batch size)
@@ -1053,7 +1091,8 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
             return
 
+        '''
         if (
             rank == 0
             and params.batch_idx_train > 0
@@ -1064,6 +1103,7 @@ def train_one_epoch(
                 model_cur=model,
                 model_avg=model_avg,
             )
+        '''
 
         if (
             params.batch_idx_train > 0
@@ -1083,11 +1123,13 @@ def train_one_epoch(
                 rank=rank,
             )
             del params.cur_batch_idx
+            '''
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
                 rank=rank,
             )
+            '''
 
         if batch_idx % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
@@ -1106,13 +1148,16 @@ def train_one_epoch(
                     f"grad_scale is too small, exiting: {cur_grad_scale}"
                 )
 
-        if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
-            wb.log({"valid/loss": 10000})
-            raise RunteimError(
-                f"divergence... exiting: loss={loss}"
-            )
+        #if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
+        #    wb.log({"valid/loss": 10000})
+        #    raise RuntimeError(
+        #        f"divergence... exiting: loss={loss}"
+        #    )
 
         if batch_idx % (params.log_interval*params.accum_grads) == 0:
+            #for n, p in model.named_parameters():
+            #    if 'adapter' in n:
+            #        print(p)
             if params.multi_optim:
                 cur_enc_lr = scheduler_enc.get_last_lr()[0]
                 cur_dec_lr = scheduler_dec.get_last_lr()[0]
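The divergence guard is disabled here; note that the old version also misspelled RuntimeError, so it would have aborted with a NameError rather than the intended message. If the guard is ever re-enabled, a corrected sketch (thresholds taken from the commented-out code, wb being the wandb run object) could look like this:

# Corrected sketch of the disabled divergence guard.
if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
    wb.log({"valid/loss": 10000})
    raise RuntimeError(f"divergence... exiting: loss={loss}")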
@@ -1169,7 +1214,8 @@ def train_one_epoch(
                 wb.log({"train/simple_loss": loss_info["simple_loss"]*numel})
                 wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel})
                 wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel})
+        '''
 
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -1190,11 +1236,15 @@ def train_one_epoch(
             if wb is not None and rank == 0:
                 numel = 1 / (params.accum_grads * valid_info["utterances"])
-                wb.log({"valid/loss": valid_info["loss"]*numel})
+                #wb.log({"valid/loss": valid_info["loss"]*numel})
+                wb.log({"valid/loss": numel*(valid_info["simple_loss"]
+                                             +valid_info["pruned_loss"]
+                                             +valid_info["ctc_loss"]
+                                             )})
                 wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel})
                 wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel})
                 wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel})
+            '''
 
         loss_value = tot_loss["loss"] / tot_loss["utterances"]
         params.train_loss = loss_value
         if params.train_loss < params.best_train_loss:
@@ -1247,7 +1297,6 @@ def run(rank, world_size, args, wb=None):
     logging.info("About to create model")
     model = get_transducer_model(params)
     logging.info(model)
-    exit()
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -1441,17 +1490,20 @@ def run(rank, world_size, args, wb=None):
         if params.print_diagnostics:
             diagnostic.print_diagnostics()
             break
 
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+        '''
+        if epoch % 50 == 0:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+        '''
 
     logging.info("Done!")
@@ -1526,46 +1578,72 @@ def run_adapter(rank, world_size, args, wb=None):
     adapter_names = []
     adapter_param = []
     for n, p in model.named_parameters():
-        if 'adapters' in n:
+        if 'adapters' in n:  # or 'joiner' in n or 'simple' in n or 'ctc' in n:
             adapter_names.append(n)
             adapter_param.append(p)
-        #else:
-        #    p.requires_grad = False
+        elif 'joiner' in n or 'simple' in n or 'ctc' in n:
+            p.requires_grad = True
+        else:
+            p.requires_grad = False
 
     optimizer_adapter = ScaledAdam(
         adapter_param,
         lr=params.adapter_lr,
         clipping_scale=5.0,
         parameters_names=[adapter_names],
     )
-    scheduler_adapter = Eden(optimizer_adapter, 5000, 3.5)  #params.lr_batche, params.lr_epochs)
+    scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  #params.lr_batche, params.lr_epochs)
     optimizer, scheduler = optimizer_adapter, scheduler_adapter
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    '''
+    if params.hpo:
+        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
+    else:
+        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
+        train_cuts += librispeech.train_other_500_cuts(option=params.gender)
+    '''
+    #train_cuts = librispeech.train_clean_10_cuts(option='male')
+    #train_cuts = librispeech.test_clean_user(option='big')
+    train_cuts = librispeech.vox_cuts(option=params.spk_id)
 
     def remove_short_and_long_utt(c: Cut):
         return 1.0 <= c.duration <= 20.0
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     sampler_state_dict = None
 
     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
+    #train_dl = librispeech.test_dataloaders(
+    #    train_cuts
+    #)
+
+    '''
+    print('\n'*5)
+    print('-'*30)
+    for batch in train_dl:
+        print(batch)
+        print('-'*30)
+        print('\n'*5)
+    exit()
+    '''
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
+    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 
     for epoch in range(params.start_epoch, params.num_epochs + 1):
+        logging.info(f"update num : {params.batch_idx_train}")
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
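One thing worth noting about the parameter selection at the top of this hunk: the joiner/simple/ctc parameters keep requires_grad=True, but only adapter_param is handed to ScaledAdam, and parameters outside an optimizer's param groups accumulate gradients without ever being stepped, so unless those heads are added to an optimizer elsewhere they are not actually updated. A tiny self-contained illustration of that behaviour:

import torch

w_in_opt = torch.nn.Parameter(torch.ones(1))
w_outside = torch.nn.Parameter(torch.ones(1))   # requires_grad=True, but not given to the optimizer
opt = torch.optim.SGD([w_in_opt], lr=0.1)

loss = (w_in_opt + w_outside).sum()
loss.backward()
opt.step()

print(w_in_opt.item())   # 0.9  -> updated
print(w_outside.item())  # 1.0  -> unchanged, despite having a gradient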
@@ -1594,17 +1672,20 @@ def run_adapter(rank, world_size, args, wb=None):
         if params.print_diagnostics:
             diagnostic.print_diagnostics()
             break
 
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+        '''
+        if epoch % 10 == 0:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+        '''
 
     logging.info("Done!")
@@ -1691,13 +1772,13 @@ def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
-    #args.exp_dir = args.exp_dir + str(random.randint(0,400))
+    if args.wandb: args.exp_dir = args.exp_dir + str(random.randint(0,400))
    args.exp_dir = Path(args.exp_dir)
 
    logging.info("save arguments to config.yaml...")
    save_args(args)
 
-    if args.wandb: wb = wandb.init(project="d2v-T", entity="dohe0342", config=vars(args))
+    if args.wandb: wb = wandb.init(project="d2v-adapter", entity="dohe0342", config=vars(args))
    else: wb = None
 
    world_size = args.world_size
@@ -1709,7 +1790,7 @@ def main():
            join=True
        )
    else:
-        if not args.add_adapter: run(rank=0, world_size=1, args=args, wb=wb)
+        if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb)
        else: run(rank=0, world_size=1, args=args, wb=wb)
 
 torch.set_num_threads(1)