Mirror of https://github.com/k2-fsa/icefall.git

Commit bdd890bab9 ("Add files via upload"), parent 0dfe0e6680.

The commit reformats an icefall training script: multi-line argument lists are collapsed onto single lines, and single-quoted string keys become double-quoted. The hunks follow, with brief notes in between.
@@ -59,10 +59,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -80,10 +77,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=35,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=35, help="Number of epochs to train.",
     )
 
     parser.add_argument(
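Both hunks above only collapse multi-line parser.add_argument() calls onto single lines; types, defaults, and help strings are unchanged. As a self-contained sketch of the resulting style (just the two options touched here; the real get_parser() defines many more):

    import argparse

    def get_parser():
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
        )
        parser.add_argument(
            "--num-epochs", type=int, default=35, help="Number of epochs to train.",
        )
        return parser

    # Example: parse overrides from a fake argv.
    args = get_parser().parse_args(["--world-size", "2"])
    assert args.world_size == 2 and args.num_epochs == 35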
@@ -230,10 +224,7 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
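load_checkpoint() itself is not part of this diff. A hypothetical minimal equivalent, assuming the checkpoint file is a torch.save() dict holding state dicts plus bookkeeping keys (the key names below are illustrative):

    from pathlib import Path
    from typing import Any, Dict, Optional

    import torch

    def load_checkpoint(
        filename: Path,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
    ) -> Dict[str, Any]:
        # Hypothetical re-implementation; the real helper lives in icefall.
        checkpoint = torch.load(filename, map_location="cpu")
        model.load_state_dict(checkpoint.pop("model"))
        if optimizer is not None and checkpoint.get("optimizer") is not None:
            optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if scheduler is not None and checkpoint.get("scheduler") is not None:
            scheduler.load_state_dict(checkpoint.pop("scheduler"))
        return checkpoint  # leftover keys: epoch, best_train_loss, ... (assumed)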
@@ -335,9 +326,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(token_ids)
 
     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
+        nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
     )
 
     ctc_loss = k2.ctc_loss(
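k2.DenseFsaVec pairs the network's log-probabilities with supervision_segments, whose rows are [sequence_index, start_frame, num_frames]; allow_truncate=subsampling_factor - 1 tolerates the small length mismatches that subsampling introduces. A runnable toy version of the same CTC computation, substituting k2.ctc_graph for the BPE graph compiler the script uses:

    import k2
    import torch

    # Toy setup: one utterance, 10 output frames, a vocab of 5 tokens (0 = blank).
    nnet_output = torch.randn(1, 10, 5, requires_grad=True)
    log_probs = nnet_output.log_softmax(dim=-1)

    # Rows of [sequence_index, start_frame, num_frames]; dtype must be int32.
    supervision_segments = torch.tensor([[0, 0, 10]], dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

    # The script gets its graph from BpeCtcTrainingGraphCompiler.compile();
    # k2.ctc_graph builds a comparable CTC topology for the label sequence.
    decoding_graph = k2.ctc_graph([[1, 3, 2]])

    loss = k2.ctc_loss(
        decoding_graph=decoding_graph,
        dense_fsa_vec=dense_fsa_vec,
        reduction="sum",
    )
    loss.backward()  # gradients flow back into nnet_output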
@@ -374,12 +363,12 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
-        info['att_loss'] = att_loss.detach().cpu().item()
+        info["att_loss"] = att_loss.detach().cpu().item()
 
-    info['loss'] = loss.detach().cpu().item()
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
 
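The quote changes above are purely cosmetic. LossRecord is defined elsewhere in icefall; from its use in this file it behaves like a dict of running sums that can be added, reduced across ranks, and dumped to TensorBoard. A hypothetical minimal version covering only the operations visible in this diff:

    from collections import defaultdict

    class LossRecord(defaultdict):
        # Hypothetical sketch; the real class lives in icefall (later
        # versions call a similar class MetricsTracker).
        def __init__(self):
            super().__init__(float)  # missing keys read as 0.0

        def __add__(self, other):
            ans = LossRecord()
            for k in set(self) | set(other):
                ans[k] = self.get(k, 0) + other.get(k, 0)
            return ans

        def write_summary(self, tb_writer, prefix, batch_idx):
            for k, v in self.items():
                tb_writer.add_scalar(prefix + k, v, batch_idx)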
@@ -410,7 +399,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
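tot_loss.reduce(loss.device) sums the counters over all DDP ranks before the per-frame average tot_loss["loss"] / tot_loss["frames"] is formed, so every rank sees the same validation loss. Presumably it wraps torch.distributed.all_reduce; a sketch of that pattern:

    import torch
    import torch.distributed as dist

    def reduce_metrics(metrics: dict, device: torch.device) -> None:
        # Hypothetical equivalent of tot_loss.reduce(): sum every counter
        # (frames, ctc_loss, ...) across all DDP ranks, in place.
        keys = sorted(metrics)  # identical order on every rank
        packed = torch.tensor([metrics[k] for k in keys], device=device)
        dist.all_reduce(packed, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, packed.tolist()):
            metrics[k] = v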
@@ -489,15 +478,9 @@ def train_one_epoch(
 
             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -509,17 +492,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
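model.train() is restored immediately after validation, which implies compute_validation_loss switches the model to eval mode internally. A hypothetical outline of that contract, reusing compute_loss from the hunks above and the LossRecord sketch earlier:

    import torch

    def compute_validation_loss(params, model, graph_compiler, valid_dl, world_size=1):
        # Hypothetical outline; the real function also reduces tot_loss
        # across ranks when world_size > 1.
        model.eval()
        tot_loss = LossRecord()
        with torch.no_grad():
            for batch in valid_dl:
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
                    batch=batch,
                    graph_compiler=graph_compiler,
                    is_training=False,
                )
                tot_loss = tot_loss + loss_info
        return tot_loss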
@@ -563,10 +542,7 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
 
     graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        params.lang_dir, device=device, sos_token="<sos/eos>", eos_token="<sos/eos>",
     )
 
     logging.info("About to create model")
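params.lang_dir points at the directory with the BPE lang data; the compiler looks up token ids (including the shared "<sos/eos>" symbol used for the attention-decoder labels) and builds CTC training graphs on the given device. Purely for illustration, producing the BPE token ids that graph_compiler.compile() consumes might look like this (the model file name under lang_dir is a guess):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe/bpe.model")  # hypothetical path inside params.lang_dir
    token_ids = sp.encode(["HELLO WORLD"], out_type=int)  # List[List[int]]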
@@ -607,9 +583,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
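cur_lr = optimizer._rate reads the current learning rate from what appears to be a Noam-style optimizer wrapper (a stock torch optimizer would expose it via param_groups instead). For reference, the classic Noam schedule such a wrapper typically implements, with illustrative hyperparameters:

    def noam_rate(step: int, d_model: int = 256, factor: float = 1.0, warmup: int = 4000) -> float:
        # Transformer ("Noam") schedule: linear warmup, then inverse-sqrt decay.
        # Hypothetical parameter values; a wrapper like icefall's would store
        # the latest value in self._rate after each optimizer step.
        step = max(step, 1)
        return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)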
@@ -629,10 +603,7 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, rank=rank,
         )
 
     logging.info("Done!")
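save_checkpoint is the write-side counterpart of load_checkpoint earlier in the diff. A hypothetical minimal version matching the keyword arguments in this hunk, where only rank 0 writes:

    from pathlib import Path
    from typing import Any, Optional

    import torch

    def save_checkpoint(
        params: Any,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        rank: int = 0,
    ) -> None:
        # Hypothetical sketch; the real helper lives in icefall.
        if rank != 0:
            return  # only the primary DDP rank writes checkpoints
        filename = Path(params.exp_dir) / f"epoch-{params.cur_epoch}.pt"
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
        }
        torch.save(checkpoint, filename)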
|