This commit is contained in:
Piotr Żelasko 2022-01-21 17:22:41 -05:00
parent f28951f2b6
commit f0f35e6671
2 changed files with 15 additions and 5 deletions

View File

@ -359,7 +359,9 @@ def compute_loss(
# Works with a phone lexicon # Works with a phone lexicon
decoding_graph = graph_compiler.compile(texts) decoding_graph = graph_compiler.compile(texts)
else: else:
raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") raise ValueError(
f"Unsupported type of graph compiler: {type(graph_compiler)}"
)
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(
nnet_output, nnet_output,
@ -386,7 +388,9 @@ def compute_loss(
# #
# See https://github.com/k2-fsa/icefall/issues/97 # See https://github.com/k2-fsa/icefall/issues/97
# for more details # for more details
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) unsorted_token_ids = graph_compiler.texts_to_ids(
supervisions["text"]
)
att_loss = mmodel.decoder_forward( att_loss = mmodel.decoder_forward(
encoder_memory, encoder_memory,
memory_mask, memory_mask,
@ -519,7 +523,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -654,7 +660,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -75,7 +75,9 @@ class CtcTrainingGraphCompiler(object):
# NOTE: k2.compose runs on CUDA only when treat_epsilons_specially # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
# is False, so we add epsilon self-loops here # is False, so we add epsilon self-loops here
fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa) fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
transcript_fsa
)
fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops) fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)