mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00, from local)

commit 03d1d62967 (parent: 75e857b5d5)
Binary file not shown.
@@ -153,6 +153,13 @@ def add_adapter_arguments(parser: argparse.ArgumentParser):
         help="adapter learning rate"
     )
 
+    parser.add_argument(
+        "--gender",
+        type=str,
+        default='male',
+        help="select gender"
+    )
+
 
 def add_rep_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
@@ -161,6 +168,13 @@ def add_rep_arguments(parser: argparse.ArgumentParser):
         default=True,
         help="Use wandb for MLOps",
     )
+    parser.add_argument(
+        "--hpo",
+        type=str2bool,
+        default=False,
+        help="Use small db for HPO",
+    )
+
     parser.add_argument(
         "--accum-grads",
         type=int,
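The new `--hpo` flag uses `str2bool`, the small argparse helper that these recipes take from `icefall.utils` so boolean options can be passed as `--hpo true` / `--hpo false`. A minimal sketch of such a helper (illustrative; the shipped implementation may differ in detail):

```python
import argparse


def str2bool(v):
    """Map common textual booleans to bool for argparse (sketch)."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--hpo", type=str2bool, default=False, help="Use small db for HPO")
print(parser.parse_args(["--hpo", "true"]).hpo)  # True
```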
@@ -286,14 +300,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--decoder-dim",
         type=int,
-        default=512,
+        default=768,
         help="Embedding dimension in the decoder model.",
     )
 
     parser.add_argument(
         "--joiner-dim",
         type=int,
-        default=512,
+        default=768,
         help="""Dimension used in the joiner model.
         Outputs from the encoder and decoder model are projected
         to this dimension before adding.
@@ -333,6 +347,13 @@ def get_parser():
         default=30,
         help="Number of epochs to train.",
     )
 
+    parser.add_argument(
+        "--num-updates",
+        type=int,
+        default=5000,
+        help="Number of epochs to train.",
+    )
+
     parser.add_argument(
         "--start-epoch",
@@ -461,7 +482,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=2000,
+        default=200,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -485,7 +506,7 @@ def get_parser():
     parser.add_argument(
         "--average-period",
         type=int,
-        default=200,
+        default=10,
         help="""Update the averaged model, namely `model_avg`, after processing
         this number of batches. `model_avg` is a separate version of model,
         in which each floating-point parameter is the average of all the
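`--average-period` controls how often the separately kept `model_avg` is refreshed; per the help text, each floating-point parameter of `model_avg` is an average of the corresponding trained parameter over the run. A rough sketch of that idea (illustrative only; icefall's own `update_averaged_model` does the exact weighting and bookkeeping):

```python
import torch


@torch.no_grad()
def update_running_average(model_cur: torch.nn.Module,
                           model_avg: torch.nn.Module,
                           num_updates_seen: int) -> None:
    """Running average avg <- avg + (cur - avg) / n (sketch, not the icefall helper)."""
    n = max(num_updates_seen, 1)
    for p_avg, p_cur in zip(model_avg.parameters(), model_cur.parameters()):
        p_avg.add_(p_cur.to(p_avg.dtype) - p_avg, alpha=1.0 / n)
```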
@@ -561,7 +582,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 20,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
@@ -570,7 +591,8 @@ def get_params() -> AttributeDict:
             # parameters for ctc loss
             "beam_size": 10,
             "use_double_scores": True,
-            "warm_step": 4000,
+            "warm_step": 0,
+            #"warm_step": 4000,
             #"warm_step": 3000,
             "env_info": get_env_info(),
         }
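Setting `warm_step` to 0 here effectively disables loss warm-up when fine-tuning from the pretrained `d2v-base-T` checkpoint loaded further below. In the pruned-transducer recipes the warm-up is, roughly, a factor that ramps from 0 to 1 over `warm_step` updates and is used to down-weight parts of the loss early in training; a hedged sketch of that ramp (the real recipes combine it with loss-specific scaling):

```python
def warmup_factor(batch_idx_train: int, warm_step: int) -> float:
    """Linear ramp from 0 to 1 over warm_step updates; always 1.0 when warm_step <= 0."""
    if warm_step <= 0:
        return 1.0
    return min(1.0, batch_idx_train / warm_step)


print(warmup_factor(1000, 0))     # 1.0 -> warm-up disabled, as in this commit
print(warmup_factor(2000, 4000))  # 0.5
```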
@@ -685,7 +707,7 @@ def load_checkpoint_if_available(
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     elif params.add_adapter:
-        filename = params.exp_dir / f"d2v-base-T.pt"
+        filename = params.exp_dir / f"../d2v-base-T.pt"
     else:
         return None
 
@@ -717,6 +739,8 @@ def load_checkpoint_if_available(
     if "cur_batch_idx" in saved_params:
         params["cur_batch_idx"] = saved_params["cur_batch_idx"]
 
+    params.batch_idx_train = 0
+
     return saved_params
 
 
@@ -818,6 +842,7 @@ def compute_loss(
     warm_step = params.warm_step
 
     texts = batch["supervisions"]["text"]
+
     token_ids = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(token_ids).to(device)
 
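`compute_loss` encodes the supervision texts with SentencePiece and wraps the token ids in a `k2.RaggedTensor`. A standalone sketch of that step using the public `sentencepiece` and `k2` APIs (the BPE model path is hypothetical, purely for illustration):

```python
import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

texts = ["HELLO WORLD", "ANOTHER UTTERANCE"]
token_ids = sp.encode(texts, out_type=int)  # list of lists of ints
y = k2.RaggedTensor(token_ids)              # ragged tensor of per-utterance token ids
print(y)
```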
@@ -888,12 +913,21 @@ def compute_loss(
     if decode:
         model.eval()
         with torch.no_grad():
-            hypos = model.module.decode(
-                x=feature,
-                x_lens=feature_lens,
-                y=y,
-                sp=sp
-            )
+            try:
+                hypos = model.module.decode(
+                    x=feature,
+                    x_lens=feature_lens,
+                    y=y,
+                    sp=sp
+                )
+            except:
+                hypos = model.decode(
+                    x=feature,
+                    x_lens=feature_lens,
+                    y=y,
+                    sp=sp
+                )
+
         logging.info(f'ref: {batch["supervisions"]["text"][0]}')
         logging.info(f'hyp: {" ".join(hypos[0])}')
         model.train()
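The `try/except` added above falls back from `model.module.decode` (the DDP-wrapped case) to `model.decode` (a bare model), but a bare `except:` also hides unrelated failures. An equivalent, narrower pattern is to unwrap the model explicitly; a sketch (assuming, as the diff does, that the underlying model exposes a `decode` method):

```python
import torch


def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return the underlying module whether or not it is wrapped in (Distributed)DataParallel."""
    return getattr(model, "module", model)


# Usage sketch inside compute_loss:
#   hypos = unwrap_model(model).decode(x=feature, x_lens=feature_lens, y=y, sp=sp)
```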
@@ -1002,6 +1036,8 @@ def train_one_epoch(
     scheduler_enc, scheduler_dec = scheduler[0], scheduler[1]
 
     for batch_idx, batch in enumerate(train_dl):
+        if params.batch_idx_train > params.num_updates:
+            break
         if batch_idx < cur_batch_idx:
             continue
         cur_batch_idx = batch_idx
@@ -1019,7 +1055,9 @@ def train_one_epoch(
                 is_training=True,
                 decode = True if batch_idx % params.decode_interval == 0 else False,
             )
-            loss_info.reduce(loss.device)
+
+            try: loss_info.reduce(loss.device)
+            except: pass
 
             numel = params.world_size / (params.accum_grads * loss_info["utterances"])
             loss *= numel ## normalize loss over utts(batch size)
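`loss_info.reduce(...)` performs a cross-rank reduction of the metrics, which fails when no distributed process group has been initialized (e.g. a single-GPU run); the `try/except: pass` above silences that case. A narrower guard is to check for an initialized process group; a sketch:

```python
import torch
import torch.distributed as dist


def reduce_if_distributed(metrics, device: torch.device) -> None:
    """Call metrics.reduce(device) only when a process group actually exists (sketch)."""
    if dist.is_available() and dist.is_initialized():
        metrics.reduce(device)
```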
@@ -1053,7 +1091,8 @@ def train_one_epoch(
 
         if params.print_diagnostics and batch_idx == 5:
             return
-
+
+        '''
         if (
             rank == 0
             and params.batch_idx_train > 0
@@ -1064,6 +1103,7 @@ def train_one_epoch(
                 model_cur=model,
                 model_avg=model_avg,
             )
+        '''
 
         if (
             params.batch_idx_train > 0
@@ -1083,11 +1123,13 @@ def train_one_epoch(
                 rank=rank,
             )
             del params.cur_batch_idx
+            '''
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
                 rank=rank,
             )
+            '''
 
         if batch_idx % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
@@ -1106,13 +1148,16 @@ def train_one_epoch(
                     f"grad_scale is too small, exiting: {cur_grad_scale}"
                 )
 
-        if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
-            wb.log({"valid/loss": 10000})
-            raise RunteimError(
-                f"divergence... exiting: loss={loss}"
-            )
+        #if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
+        #    wb.log({"valid/loss": 10000})
+        #    raise RuntimeError(
+        #        f"divergence... exiting: loss={loss}"
+        #    )
 
         if batch_idx % (params.log_interval*params.accum_grads) == 0:
+            #for n, p in model.named_parameters():
+            #    if 'adapter' in n:
+            #        print(p)
             if params.multi_optim:
                 cur_enc_lr = scheduler_enc.get_last_lr()[0]
                 cur_dec_lr = scheduler_dec.get_last_lr()[0]
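The divergence guard is commented out in this commit; note that the removed live version also misspelled the exception (`RunteimError`), so it would have crashed with a `NameError` if it ever triggered. If such a guard is wanted, a corrected sketch looks like this (the 4000-update and loss > 300 thresholds are simply the values from the code above, not a recommendation):

```python
def check_divergence(batch_idx_train: int, loss: float, wb=None) -> None:
    """Abort training once the loss has clearly diverged (illustrative sketch)."""
    if batch_idx_train > 4000 and loss > 300:
        if wb is not None:
            wb.log({"valid/loss": 10000})
        raise RuntimeError(f"divergence... exiting: loss={loss}")
```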
@@ -1169,7 +1214,8 @@ def train_one_epoch(
                 wb.log({"train/simple_loss": loss_info["simple_loss"]*numel})
                 wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel})
                 wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel})
-
+
+        '''
         logging.info("Computing validation loss")
         valid_info = compute_validation_loss(
             params=params,
@@ -1190,11 +1236,15 @@ def train_one_epoch(
 
         if wb is not None and rank == 0:
             numel = 1 / (params.accum_grads * valid_info["utterances"])
-            wb.log({"valid/loss": valid_info["loss"]*numel})
+            #wb.log({"valid/loss": valid_info["loss"]*numel})
+            wb.log({"valid/loss": numel*(valid_info["simple_loss"]
+                                         +valid_info["pruned_loss"]
+                                         +valid_info["ctc_loss"]
+                                         )})
             wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel})
             wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel})
             wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel})
-
+        '''
     loss_value = tot_loss["loss"] / tot_loss["utterances"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
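With this change the logged `valid/loss` is no longer the tracker's aggregate `loss` but the sum of the simple, pruned and CTC components, normalized by `accum_grads` times the utterance count. The same arithmetic as a small helper (names follow the diff; illustrative only):

```python
def composite_valid_loss(valid_info: dict, accum_grads: int) -> float:
    """Sum of loss components per utterance, matching the wandb value logged above (sketch)."""
    numel = 1.0 / (accum_grads * valid_info["utterances"])
    return numel * (
        valid_info["simple_loss"] + valid_info["pruned_loss"] + valid_info["ctc_loss"]
    )


print(composite_valid_loss(
    {"utterances": 10, "simple_loss": 5.0, "pruned_loss": 20.0, "ctc_loss": 15.0},
    accum_grads=4,
))  # 1.0
```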
@@ -1247,7 +1297,6 @@ def run(rank, world_size, args, wb=None):
     logging.info("About to create model")
     model = get_transducer_model(params)
     logging.info(model)
-    exit()
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -1441,17 +1490,20 @@ def run(rank, world_size, args, wb=None):
         if params.print_diagnostics:
             diagnostic.print_diagnostics()
             break
 
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
-
+        '''
+        if epoch % 50 == 0:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+        '''
+
     logging.info("Done!")
 
@@ -1526,46 +1578,72 @@ def run_adapter(rank, world_size, args, wb=None):
     adapter_names = []
     adapter_param = []
     for n, p in model.named_parameters():
-        if 'adapters' in n:
+        if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n:
             adapter_names.append(n)
             adapter_param.append(p)
-        #else:
-        #    p.requires_grad = False
+        elif 'joiner' in n or 'simple' in n or 'ctc' in n:
+            p.requires_grad = True
+        else:
+            p.requires_grad = False
 
     optimizer_adapter = ScaledAdam(
         adapter_param,
         lr=params.adapter_lr,
         clipping_scale=5.0,
         parameters_names=[adapter_names],
     )
-    scheduler_adapter = Eden(optimizer_adapter, 5000, 3.5) #params.lr_batche, params.lr_epochs)
+    scheduler_adapter = Eden(optimizer_adapter, 10000, 7) #params.lr_batche, params.lr_epochs)
 
     optimizer, scheduler = optimizer_adapter, scheduler_adapter
 
     librispeech = LibriSpeechAsrDataModule(args)
 
+    '''
     if params.hpo:
         train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
     else:
         train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
         if params.full_libri:
             train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
             train_cuts += librispeech.train_other_500_cuts(option=params.gender)
+    '''
 
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
 
+    #train_cuts = librispeech.train_clean_10_cuts(option='male')
+    #train_cuts = librispeech.test_clean_user(option='big')
+    train_cuts = librispeech.vox_cuts(option=params.spk_id)
 
     def remove_short_and_long_utt(c: Cut):
         return 1.0 <= c.duration <= 20.0
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
 
     sampler_state_dict = None
 
     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
+    #train_dl = librispeech.test_dataloaders(
+    #    train_cuts
+    #)
 
+    '''
+    print('\n'*5)
+    print('-'*30)
+    for batch in train_dl:
+        print(batch)
+        print('-'*30)
+        print('\n'*5)
+        exit()
+    '''
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
+    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 
     for epoch in range(params.start_epoch, params.num_epochs + 1):
+        logging.info(f"update num : {params.batch_idx_train}")
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
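In `run_adapter` the loop above keeps the adapters as the optimizer's parameter group while also unfreezing the joiner / simple / ctc heads; note that, as written, those heads are not appended to `adapter_param`, so unless they are added elsewhere only the adapters actually receive updates from `ScaledAdam`. A compact sketch of a variant that freezes by name and then optimizes exactly the tensors left trainable (`model`, `ScaledAdam`, `Eden` and `params` are the objects from the surrounding script):

```python
# Variant sketch: freeze by substring match, then hand every still-trainable
# tensor (adapters *and* heads) to the optimizer.
trainable_names, trainable_params = [], []
for name, p in model.named_parameters():
    keep = any(key in name for key in ("adapters", "joiner", "simple", "ctc"))
    p.requires_grad = keep
    if keep:
        trainable_names.append(name)
        trainable_params.append(p)

optimizer = ScaledAdam(
    trainable_params,
    lr=params.adapter_lr,
    clipping_scale=5.0,
    parameters_names=[trainable_names],
)
scheduler = Eden(optimizer, 10000, 7)  # same lr_batches / lr_epochs as above
```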
@@ -1594,17 +1672,20 @@ def run_adapter(rank, world_size, args, wb=None):
         if params.print_diagnostics:
             diagnostic.print_diagnostics()
             break
 
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
-
+        '''
+        if epoch % 10 == 0:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+        '''
+
     logging.info("Done!")
 
@@ -1691,13 +1772,13 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
-    #args.exp_dir = args.exp_dir + str(random.randint(0,400))
+    if args.wandb: args.exp_dir = args.exp_dir + str(random.randint(0,400))
     args.exp_dir = Path(args.exp_dir)
 
     logging.info("save arguments to config.yaml...")
     save_args(args)
 
-    if args.wandb: wb = wandb.init(project="d2v-T", entity="dohe0342", config=vars(args))
-
+    if args.wandb: wb = wandb.init(project="d2v-adapter", entity="dohe0342", config=vars(args))
+
     else: wb = None
 
     world_size = args.world_size
|
||||
join=True
|
||||
)
|
||||
else:
|
||||
if not args.add_adapter: run(rank=0, world_size=1, args=args, wb=wb)
|
||||
if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb)
|
||||
else: run(rank=0, world_size=1, args=args, wb=wb)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||