Add files via upload

Mingshuang Luo 2021-09-29 19:23:11 +08:00 committed by GitHub
parent 0dfe0e6680
commit bdd890bab9


@@ -59,10 +59,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -80,10 +77,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=35,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=35, help="Number of epochs to train.",
     )
 
     parser.add_argument(
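
Note: both hunks above collapse a multi-line parser.add_argument(...) call onto a single line while keeping the trailing comma. A runnable sketch of the resulting style (the option names are from the diff; the rest is standard argparse boilerplate):

    import argparse

    parser = argparse.ArgumentParser(description="Options shown in this diff.")
    parser.add_argument(
        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
    )
    parser.add_argument(
        "--num-epochs", type=int, default=35, help="Number of epochs to train.",
    )

    args = parser.parse_args(["--world-size", "4"])
    assert args.world_size == 4

One caveat about the layout itself: formatters such as black treat a trailing comma inside a call as the "magic trailing comma" and re-expand the call to one argument per line, so this one-liner style is unlikely to survive a black run over the file.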
@@ -230,10 +224,7 @@ def load_checkpoint_if_available(
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
 
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
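
load_checkpoint is an icefall helper whose body is not part of this diff. A hedged sketch of what a loader with this signature typically does (the checkpoint keys "model", "optimizer", and "scheduler" are assumptions, not taken from the repo):

    from pathlib import Path

    import torch

    def load_checkpoint(filename: Path, model=None, optimizer=None, scheduler=None) -> dict:
        # Load to CPU so the caller decides device placement afterwards.
        checkpoint = torch.load(filename, map_location="cpu")
        if model is not None:
            model.load_state_dict(checkpoint["model"])
        if optimizer is not None and checkpoint.get("optimizer") is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None and checkpoint.get("scheduler") is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])
        # Whatever else was saved (epoch number, best losses, ...) goes back to the caller.
        return checkpoint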
@@ -335,9 +326,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(token_ids)
 
     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
+        nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
     )
 
     ctc_loss = k2.ctc_loss(
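
For context on the k2 calls being reflowed here: nnet_output is an (N, T, C) tensor of log-probabilities, and supervision_segments is an int32 CPU tensor with one (sequence_index, start_frame, num_frames) row per utterance; allow_truncate=subsampling_factor - 1 tolerates segments whose stated length overshoots the subsampled output by up to that many frames. A self-contained toy example (shapes and values invented; the API usage follows k2's documented interface):

    import torch
    import k2

    N, T, C = 2, 50, 10  # batch size, frames, output tokens (0 = blank)
    nnet_output = torch.randn(N, T, C).log_softmax(dim=-1)

    # Rows sorted by decreasing duration: (sequence_index, start_frame, num_frames).
    supervision_segments = torch.tensor([[0, 0, 50], [1, 0, 40]], dtype=torch.int32)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments, allow_truncate=3)

    # A trivial decoding graph: the CTC topology for two toy token sequences.
    decoding_graph = k2.ctc_graph([[1, 2, 3], [4, 5]])

    loss = k2.ctc_loss(decoding_graph, dense_fsa_vec, reduction="sum")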
@@ -374,12 +363,12 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
-        info['att_loss'] = att_loss.detach().cpu().item()
-    info['loss'] = loss.detach().cpu().item()
+        info["att_loss"] = att_loss.detach().cpu().item()
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
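
LossRecord is defined in icefall's utilities, not in this file. Judging by its use in these hunks, it is a dict-like container of summed statistics keyed by strings, with "frames" acting as the normalizer. A hedged sketch of that contract (assumed behavior, not the repo's actual class):

    from collections import defaultdict

    class LossRecord(defaultdict):
        # Missing keys read as 0.0, so accumulation always works.
        def __init__(self):
            super().__init__(float)

        def __add__(self, other: "LossRecord") -> "LossRecord":
            out = LossRecord()
            for k in set(self) | set(other):
                out[k] = self[k] + other[k]
            return out

        def write_summary(self, tb_writer, prefix: str, step: int) -> None:
            frames = self["frames"] or 1  # guard against division by zero
            for k, v in self.items():
                if k != "frames":
                    tb_writer.add_scalar(prefix + k, v / frames, step)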
@@ -410,7 +399,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
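
tot_loss.reduce(loss.device) sums the statistics across DDP ranks before the division, so every rank computes the same validation loss. A hedged sketch of how such a reduce can be written with torch.distributed (a hypothetical free-function equivalent, not icefall's code):

    import torch
    import torch.distributed as dist

    def reduce_metrics(metrics: dict, device: torch.device) -> None:
        # Pack the sums in a fixed key order, all-reduce, then unpack in place.
        # Assumes dist.init_process_group(...) has already been called.
        keys = sorted(metrics.keys())
        buf = torch.tensor([metrics[k] for k in keys], dtype=torch.float64, device=device)
        dist.all_reduce(buf, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, buf.tolist()):
            metrics[k] = v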
@@ -489,15 +478,9 @@ def train_one_epoch(
             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -509,17 +492,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -563,10 +542,7 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
 
     graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        params.lang_dir, device=device, sos_token="<sos/eos>", eos_token="<sos/eos>",
     )
 
     logging.info("About to create model")
@@ -607,9 +583,7 @@ def run(rank, world_size, args):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
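
optimizer._rate suggests a Noam-style optimizer wrapper (the transformer learning-rate schedule), which keeps the current rate in a private attribute. Assuming the usual formula, the rate at a given step looks like this (the default values below are placeholders, not this recipe's actual settings):

    def noam_rate(step: int, model_size: int = 256, factor: float = 10.0, warmup: int = 25000) -> float:
        # Linear warm-up for `warmup` steps, then decay proportional to 1/sqrt(step).
        assert step >= 1
        return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)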
@@ -629,10 +603,7 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, rank=rank,
         )
 
     logging.info("Done!")