From bdd890bab9dea20b448a4d01743303f886d8262c Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Wed, 29 Sep 2021 19:23:11 +0800
Subject: [PATCH] Add files via upload

---
 egs/librispeech/ASR/conformer_ctc/train.py | 63 ++++++----------
 1 file changed, 17 insertions(+), 46 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 34b99cd2d..98bd47bc1 100644
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -59,10 +59,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -80,10 +77,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=35,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=35, help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -230,10 +224,7 @@ def load_checkpoint_if_available(
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
 
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
@@ -335,9 +326,7 @@ def compute_loss(
         decoding_graph = graph_compiler.compile(token_ids)
 
         dense_fsa_vec = k2.DenseFsaVec(
-            nnet_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
+            nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
         )
 
         ctc_loss = k2.ctc_loss(
@@ -374,12 +363,12 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
-        info['att_loss'] = att_loss.detach().cpu().item()
+        info["att_loss"] = att_loss.detach().cpu().item()
 
-    info['loss'] = loss.detach().cpu().item()
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
 
@@ -410,7 +399,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
@@ -489,15 +478,9 @@ def train_one_epoch(
 
             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -509,17 +492,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -563,10 +542,7 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
 
     graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        params.lang_dir, device=device, sos_token="<sos/eos>", eos_token="<sos/eos>",
     )
 
     logging.info("About to create model")
@@ -607,9 +583,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
@@ -629,10 +603,7 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, rank=rank,
         )
 
     logging.info("Done!")