Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Merge pull request #1 from luomingshuang/lossrecord-v2

style check with flake8 and black

Commit: b35eed961a
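The hunks below are mechanical restyling only: argument lists that fit on one line are collapsed, and single quotes become double quotes; no behavior changes. This is the kind of output produced by running `black .` and then checking with `flake8 .` at the repository root, though the exact formatter configuration used for this commit is an assumption.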
@@ -59,10 +59,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )

     parser.add_argument(
@@ -80,10 +77,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=35,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=35, help="Number of epochs to train.",
     )

     parser.add_argument(
@@ -230,10 +224,7 @@ def load_checkpoint_if_available(

     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )

     keys = [
@@ -335,9 +326,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(token_ids)

     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
+        nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
     )

     ctc_loss = k2.ctc_loss(
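For reference, the call being reflowed above is the standard k2 CTC pattern: log-softmax network output plus per-utterance supervision segments wrapped in a DenseFsaVec, intersected with a decoding graph. A minimal self-contained sketch (the shapes, `output_beam`, and the use of `k2.ctc_graph` as a stand-in for icefall's graph compiler are assumptions, not taken from this diff):

    import k2
    import torch

    # Network output after log-softmax: (N, T, C).
    log_probs = torch.randn(2, 50, 500).log_softmax(dim=-1)

    # One (sequence_index, start_frame, num_frames) row per utterance,
    # sorted by num_frames in decreasing order; must be int32.
    supervision_segments = torch.tensor(
        [[0, 0, 50], [1, 0, 40]], dtype=torch.int32
    )
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

    # icefall gets decoding_graph from graph_compiler.compile(token_ids);
    # a trivial CTC graph per utterance serves as a stand-in here.
    decoding_graph = k2.ctc_graph([[1, 2, 3], [4, 5]])

    loss = k2.ctc_loss(
        decoding_graph=decoding_graph,
        dense_fsa_vec=dense_fsa_vec,
        output_beam=10,
        reduction="sum",
    )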
@@ -374,12 +363,12 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
-        info['att_loss'] = att_loss.detach().cpu().item()
+        info["att_loss"] = att_loss.detach().cpu().item()

-    info['loss'] = loss.detach().cpu().item()
+    info["loss"] = loss.detach().cpu().item()

     return loss, info

@@ -410,7 +399,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)

-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
@@ -489,15 +478,9 @@ def train_one_epoch(

             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -509,17 +492,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )

-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -563,10 +542,7 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)

     graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        params.lang_dir, device=device, sos_token="<sos/eos>", eos_token="<sos/eos>",
     )

     logging.info("About to create model")
@@ -607,9 +583,7 @@ def run(rank, world_size, args):

         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:
@@ -629,10 +603,7 @@ def run(rank, world_size, args):
         )

         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, rank=rank,
         )

     logging.info("Done!")
@@ -58,10 +58,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )

     parser.add_argument(
@@ -79,10 +76,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=20,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=20, help="Number of epochs to train.",
     )

     parser.add_argument(
@@ -209,10 +203,7 @@ def load_checkpoint_if_available(

     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )

     keys = [
@@ -312,9 +303,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(texts)

     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
+        nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
     )

     loss = k2.ctc_loss(
@@ -328,8 +317,8 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['loss'] = loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()

     return loss, info

@@ -363,7 +352,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)

-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]

     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
@@ -439,15 +428,9 @@ def train_one_epoch(

             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
@@ -458,17 +441,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train,
+                    tb_writer, "train/valid_", params.batch_idx_train,
                 )

-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
@@ -526,9 +505,7 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank])

     optimizer = optim.AdamW(
-        model.parameters(),
-        lr=params.lr,
-        weight_decay=params.weight_decay,
+        model.parameters(), lr=params.lr, weight_decay=params.weight_decay,
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

@@ -548,9 +525,7 @@ def run(rank, world_size, args):

         if tb_writer is not None:
             tb_writer.add_scalar(
-                "train/lr",
-                scheduler.get_last_lr()[0],
-                params.batch_idx_train,
+                "train/lr", scheduler.get_last_lr()[0], params.batch_idx_train,
             )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

@@ -33,10 +33,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )

     parser.add_argument(
@@ -54,10 +51,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=15,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=15, help="Number of epochs to train.",
     )

     parser.add_argument(
@@ -187,10 +181,7 @@ def load_checkpoint_if_available(

     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )

     keys = [
@@ -287,16 +278,12 @@ def compute_loss(

     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
-        dtype=torch.int32,
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], dtype=torch.int32,
     )

     decoding_graph = graph_compiler.compile(texts)

-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-    )
+    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments,)

     loss = k2.ctc_loss(
         decoding_graph=decoding_graph,
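Aside: the `supervision_segments` tensor reflowed above gives k2 one (sequence_index, start_frame, num_frames) row per utterance; in this recipe every utterance starts at frame 0 and spans the full padded length. A tiny illustration with made-up sizes:

    import torch

    batch_size, num_frames = 3, 100  # hypothetical (N, T) of nnet_output
    supervision_segments = torch.tensor(
        [[i, 0, num_frames] for i in range(batch_size)], dtype=torch.int32
    )
    # tensor([[  0,   0, 100],
    #         [  1,   0, 100],
    #         [  2,   0, 100]], dtype=torch.int32)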
@@ -309,8 +296,8 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['loss'] = loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()

     return loss, info

@@ -344,7 +331,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)

-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]

     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
@@ -420,15 +407,9 @@ def train_one_epoch(

             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
@@ -439,17 +420,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train,
+                    tb_writer, "train/valid_", params.batch_idx_train,
                 )

-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
@@ -506,9 +483,7 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank])

     optimizer = optim.SGD(
-        model.parameters(),
-        lr=params.lr,
-        weight_decay=params.weight_decay,
+        model.parameters(), lr=params.lr, weight_decay=params.weight_decay,
     )

     if checkpoints:
@@ -542,11 +517,7 @@ def run(rank, world_size, args):
         )

         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            scheduler=None,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, scheduler=None, rank=rank,
         )

     logging.info("Done!")
@@ -107,9 +107,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"

     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -236,9 +234,7 @@ def get_texts(
         return aux_labels.tolist()


-def store_transcripts(
-    filename: Pathlike, texts: Iterable[Tuple[str, str]]
-) -> None:
+def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str]]) -> None:
     """Save predicted results and reference transcripts to a file.

     Args:
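A quick usage sketch of the reflowed `store_transcripts` signature; treating each tuple as a (reference, hypothesis) pair is an assumption based on the docstring, so check the Args section before relying on the order:

    from pathlib import Path

    results = [
        ("hello world", "hello word"),     # hypothetical utterance 1
        ("good morning", "good morning"),  # hypothetical utterance 2
    ]
    store_transcripts(filename=Path("exp/recogs-test.txt"), texts=results)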
@@ -369,19 +365,14 @@ def write_error_stats(
         ]
         ali = list(filter(lambda x: x != [[], []], ali))
         ali = [
-            [
-                ERR if x == [] else " ".join(x),
-                ERR if y == [] else " ".join(y),
-            ]
+            [ERR if x == [] else " ".join(x), ERR if y == [] else " ".join(y),]
             for x, y in ali
         ]

         print(
             " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -391,9 +382,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)

-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count} {ref} -> {hyp}", file=f)

     print("", file=f)
@@ -407,9 +396,7 @@ def write_error_stats(
         print(f"{count} {hyp}", file=f)

     print("", file=f)
-    print(
-        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -428,7 +415,7 @@ class LossRecord(collections.defaultdict):
         # makes undefined items default to int() which is zero.
         super(LossRecord, self).__init__(int)

-    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+    def __add__(self, other: "LossRecord") -> "LossRecord":
         ans = LossRecord()
         for k, v in self.items():
             ans[k] = v
@@ -436,19 +423,19 @@ class LossRecord(collections.defaultdict):
                 ans[k] = ans[k] + v
         return ans

-    def __mul__(self, alpha: float) -> 'LossRecord':
+    def __mul__(self, alpha: float) -> "LossRecord":
         ans = LossRecord()
         for k, v in self.items():
             ans[k] = v * alpha
         return ans

     def __str__(self) -> str:
-        ans = ''
+        ans = ""
         for k, v in self.norm_items():
-            norm_value = '%.4g' % v
-            ans += (str(k) + '=' + str(norm_value) + ', ')
-        frames = str(self['frames'])
-        ans += 'over ' + frames + ' frames.'
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        frames = str(self["frames"])
+        ans += "over " + frames + " frames."
         return ans

     def norm_items(self) -> List[Tuple[str, float]]:
@@ -456,10 +443,10 @@ class LossRecord(collections.defaultdict):
         Returns a list of pairs, like:
           [('ctc_loss', 0.1), ('att_loss', 0.07)]
         """
-        num_frames = self['frames'] if 'frames' in self else 1
+        num_frames = self["frames"] if "frames" in self else 1
         ans = []
         for k, v in self.items():
-            if k != 'frames':
+            if k != "frames":
                 norm_value = float(v) / num_frames
                 ans.append((k, norm_value))
         return ans
@@ -470,17 +457,13 @@ class LossRecord(collections.defaultdict):
         all processes get the total.
         """
         keys = sorted(self.keys())
-        s = torch.tensor([float(self[k]) for k in keys],
-                         device=device)
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
         dist.all_reduce(s, op=dist.ReduceOp.SUM)
         for k, v in zip(keys, s.cpu().tolist()):
             self[k] = v

     def write_summary(
-        self,
-        tb_writer: SummaryWriter,
-        prefix: str,
-        batch_idx: int,
+        self, tb_writer: SummaryWriter, prefix: str, batch_idx: int,
     ) -> None:
         """Add logging information to a TensorBoard writer.

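Since most quote-style fixes in this file land on LossRecord, here is a minimal usage sketch of the class as defined in this diff (the numbers are made up; `reduce` and `write_summary` additionally need an initialized process group and a SummaryWriter, so they are omitted):

    from icefall.utils import LossRecord

    # Missing keys default to 0 because LossRecord is a defaultdict(int).
    a = LossRecord()
    a["frames"] = 100
    a["ctc_loss"] = 25.0

    b = LossRecord()
    b["frames"] = 60
    b["ctc_loss"] = 18.0

    tot = a + b              # keys sum elementwise: frames=160, ctc_loss=43.0
    print(tot.norm_items())  # [('ctc_loss', 0.26875)] -- normalized per frame
    print(tot)               # ctc_loss=0.2688, over 160 frames.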