From 279dc74b4e49e0c6d940afb4ea0329e9976f26ef Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Wed, 29 Sep 2021 19:23:54 +0800
Subject: [PATCH] Add files via upload

---
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 53 ++++++----------
 1 file changed, 14 insertions(+), 39 deletions(-)

diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 016d51e2c..2b22e4e0f 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -58,10 +58,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -79,10 +76,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=20,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=20, help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -209,10 +203,7 @@ def load_checkpoint_if_available(
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
 
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
@@ -312,9 +303,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(texts)
 
     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
+        nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
     )
 
     loss = k2.ctc_loss(
@@ -328,8 +317,8 @@
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['loss'] = loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
 
@@ -363,7 +352,7 @@
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
 
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
@@ -439,15 +428,9 @@ def train_one_epoch(
 
             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
@@ -458,17 +441,13 @@
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train,
+                    tb_writer, "train/valid_", params.batch_idx_train,
                 )
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
 
     if params.train_loss < params.best_train_loss:
@@ -526,9 +505,7 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank])
 
     optimizer = optim.AdamW(
-        model.parameters(),
-        lr=params.lr,
-        weight_decay=params.weight_decay,
+        model.parameters(), lr=params.lr, weight_decay=params.weight_decay,
     )
 
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
@@ -548,9 +525,7 @@ def run(rank, world_size, args):
         if tb_writer is not None:
             tb_writer.add_scalar(
-                "train/lr",
-                scheduler.get_last_lr()[0],
-                params.batch_idx_train,
+                "train/lr", scheduler.get_last_lr()[0], params.batch_idx_train,
             )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
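Note on the loss bookkeeping this patch touches: per-batch statistics go into a LossRecord under "frames" and "loss" keys, are summed into tot_loss across batches, and are reported as the frame-normalized value tot_loss["loss"] / tot_loss["frames"]. The real LossRecord class lives elsewhere in icefall; the dict-subclass sketch below is an assumption-based illustration of only the behavior visible in this diff (dict-style keys, addition across batches, per-frame normalization), not the project's actual implementation.

# Assumption-based sketch of a LossRecord-like accumulator; the real icefall
# class differs (it also provides reduce() and write_summary(), seen above).
class LossRecord(dict):
    """Accumulates summed statistics such as "loss" and "frames"."""

    def __add__(self, other: "LossRecord") -> "LossRecord":
        ans = LossRecord(self)  # start from the current totals
        for k, v in other.items():
            ans[k] = ans.get(k, 0) + v
        return ans


tot_loss = LossRecord()
# Hypothetical per-batch numbers, for illustration only.
for batch_loss, batch_frames in [(120.0, 400), (90.0, 300)]:
    info = LossRecord()
    info["frames"] = batch_frames  # number of supervised frames in the batch
    info["loss"] = batch_loss      # summed CTC loss over the batch
    tot_loss = tot_loss + info

# Frame-normalized loss, as computed in the patch:
print(tot_loss["loss"] / tot_loss["frames"])  # 0.3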