mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Add files via upload
This commit is contained in:
parent
279dc74b4e
commit
79fd09e3e5
@ -33,10 +33,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
"--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -54,10 +51,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of epochs to train.",
|
||||
"--num-epochs", type=int, default=15, help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -187,10 +181,7 @@ def load_checkpoint_if_available(
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
filename, model=model, optimizer=optimizer, scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
@ -287,16 +278,12 @@ def compute_loss(
|
||||
|
||||
batch_size = nnet_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||
dtype=torch.int32,
|
||||
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)], dtype=torch.int32,
|
||||
)
|
||||
|
||||
decoding_graph = graph_compiler.compile(texts)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
nnet_output,
|
||||
supervision_segments,
|
||||
)
|
||||
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments,)
|
||||
|
||||
loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
@ -309,8 +296,8 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = LossRecord()
|
||||
info['frames'] = supervision_segments[:, 2].sum().item()
|
||||
info['loss'] = loss.detach().cpu().item()
|
||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -344,7 +331,7 @@ def compute_validation_loss(
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(loss.device)
|
||||
|
||||
loss_value = tot_loss['loss'] / tot_loss['frames']
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
@ -420,15 +407,9 @@ def train_one_epoch(
|
||||
|
||||
if tb_writer is not None:
|
||||
loss_info.write_summary(
|
||||
tb_writer,
|
||||
"train/current_",
|
||||
params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer,
|
||||
"train/tot_",
|
||||
params.batch_idx_train
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
valid_info = compute_validation_loss(
|
||||
@ -439,17 +420,13 @@ def train_one_epoch(
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, validation {valid_info}"
|
||||
)
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer,
|
||||
"train/valid_",
|
||||
params.batch_idx_train,
|
||||
tb_writer, "train/valid_", params.batch_idx_train,
|
||||
)
|
||||
|
||||
loss_value = tot_loss['loss'] / tot_loss['frames']
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
|
||||
if params.train_loss < params.best_train_loss:
|
||||
@ -506,9 +483,7 @@ def run(rank, world_size, args):
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(),
|
||||
lr=params.lr,
|
||||
weight_decay=params.weight_decay,
|
||||
model.parameters(), lr=params.lr, weight_decay=params.weight_decay,
|
||||
)
|
||||
|
||||
if checkpoints:
|
||||
@ -542,11 +517,7 @@ def run(rank, world_size, args):
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
rank=rank,
|
||||
params=params, model=model, optimizer=optimizer, scheduler=None, rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
Loading…
x
Reference in New Issue
Block a user