diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index d57d693a0..ff0c8f9f4 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -359,7 +359,9 @@ def compute_loss(
         # Works with a phone lexicon
         decoding_graph = graph_compiler.compile(texts)
     else:
-        raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
+        raise ValueError(
+            f"Unsupported type of graph compiler: {type(graph_compiler)}"
+        )
 
     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
@@ -386,7 +388,9 @@
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -519,7 +523,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -654,7 +660,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index e2ff03f61..570ed7d7a 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,7 +75,9 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
+            transcript_fsa
+        )
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)