mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

commit f0f35e6671 (parent f28951f2b6)

    black
@@ -359,7 +359,9 @@ def compute_loss(
         # Works with a phone lexicon
         decoding_graph = graph_compiler.compile(texts)
     else:
-        raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
+        raise ValueError(
+            f"Unsupported type of graph compiler: {type(graph_compiler)}"
+        )
 
     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
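For context, the rewrapped `raise` sits in the CTC loss computation: the compiled decoding graph and a dense FSA built from the network output are intersected by k2.ctc_loss. A minimal sketch under the assumption that nnet_output holds (N, T, C) log-probs and supervision_segments is icefall's usual (num_utts, 3) int32 tensor; the output_beam value is illustrative, not taken from this commit.

    import k2

    # Graph over the transcripts (the phone-lexicon path shown in the hunk above).
    decoding_graph = graph_compiler.compile(texts)

    # Wrap the log-probs as a vector of dense FSAs, one per supervision segment.
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    # Intersect graph and dense FSAs to get the (negated) CTC log-likelihood.
    ctc_loss = k2.ctc_loss(
        decoding_graph=decoding_graph,
        dense_fsa_vec=dense_fsa_vec,
        output_beam=10.0,  # illustrative; the script reads this from params
        reduction="sum",
    )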
@@ -386,7 +388,9 @@ def compute_loss(
         #
         # See https://github.com/k2-fsa/icefall/issues/97
         # for more details
-        unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+        unsorted_token_ids = graph_compiler.texts_to_ids(
+            supervisions["text"]
+        )
         att_loss = mmodel.decoder_forward(
             encoder_memory,
             memory_mask,
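The "unsorted" in the name is the point of the linked issue: the supervisions get sorted by decreasing duration for k2's dense intersection, while encoder_memory and memory_mask keep the original batch order, so the attention-decoder targets must be regenerated from the unsorted texts. A sketch of the intent, with keyword names assumed rather than read from this diff:

    # Token ids regenerated from the *unsorted* texts so the targets line up
    # with encoder_memory / memory_mask, which were never re-ordered.
    unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])

    att_loss = mmodel.decoder_forward(
        encoder_memory,
        memory_mask,
        token_ids=unsorted_token_ids,  # assumed parameter name
        sos_id=graph_compiler.sos_id,  # assumed compiler attributes
        eos_id=graph_compiler.eos_id,
    )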
@@ -519,7 +523,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
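write_summary here belongs to icefall's MetricsTracker helper; only the call sites appear in this diff, so the class body below is an assumed sketch of what such a method does: one TensorBoard scalar per tracked metric, normalized by the frame count.

    from torch.utils.tensorboard import SummaryWriter

    class MetricsTracker(dict):
        """Accumulates e.g. {"frames": ..., "ctc_loss": ..., "att_loss": ...}."""

        def write_summary(
            self, tb_writer: SummaryWriter, prefix: str, batch_idx: int
        ) -> None:
            num_frames = self.get("frames", 1)
            for name, value in self.items():
                if name == "frames":
                    continue
                # Produces tags like "train/current_ctc_loss" or "train/tot_att_loss".
                tb_writer.add_scalar(prefix + name, value / num_frames, batch_idx)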
@@ -654,7 +660,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
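optimizer._rate is the learning rate cached by icefall's Noam optimizer wrapper on each step. Assuming it follows the standard Transformer schedule (Vaswani et al., 2017), the value being logged is:

    def noam_rate(
        step: int, model_size: int = 512, factor: float = 1.0, warmup: int = 4000
    ) -> float:
        # Linear warmup for `warmup` steps, then decay proportional to step**-0.5.
        step = max(step, 1)  # guard against step == 0
        return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)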
@@ -75,7 +75,9 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
+            transcript_fsa
+        )
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
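The rewrapped call is part of CtcTrainingGraphCompiler.compile(). A sketch of the full sequence, assuming ctc_topo was built with k2.ctc_topo() and arc-sorted beforehand: epsilons are removed and replaced with self-loops precisely so that the subsequent composition can run with treat_epsilons_specially=False and stay on CUDA.

    import k2

    # Replace epsilons with self-loops so composition can treat epsilons
    # as ordinary symbols (required for the CUDA code path of k2.compose).
    fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
    fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)

    decoding_graph = k2.compose(
        ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
    )
    decoding_graph = k2.connect(decoding_graph)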