From 5b3cd5debd1ebea15b45565cd22e5e0b5cd6b4cf Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 24 Aug 2021 15:23:44 +0800 Subject: [PATCH] WIP: Add doc for the LibriSpeech recipe. --- docs/source/index.rst | 2 +- docs/source/recipes/librispeech.rst | 8 +++ .../recipes/librispeech/conformer_ctc.rst | 2 + .../recipes/librispeech/tdnn_lstm_ctc.rst | 2 + egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 56 +++++++++++++++---- 5 files changed, 59 insertions(+), 11 deletions(-) create mode 100644 docs/source/recipes/librispeech/conformer_ctc.rst create mode 100644 docs/source/recipes/librispeech/tdnn_lstm_ctc.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index c5cd2e832..9313f1a67 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,7 +3,7 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -icefall +Icefall ======= .. image:: _static/logo.png diff --git a/docs/source/recipes/librispeech.rst b/docs/source/recipes/librispeech.rst index 5b6ca04d4..946b23407 100644 --- a/docs/source/recipes/librispeech.rst +++ b/docs/source/recipes/librispeech.rst @@ -1,2 +1,10 @@ LibriSpeech =========== + +We provide the following models for the LibriSpeech dataset: + +.. toctree:: + :maxdepth: 2 + + librispeech/tdnn_lstm_ctc + librispeech/conformer_ctc diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst new file mode 100644 index 000000000..4d531bf26 --- /dev/null +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -0,0 +1,2 @@ +Confromer CTC +============= diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst new file mode 100644 index 000000000..373bb5905 --- /dev/null +++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst @@ -0,0 +1,2 @@ +TDNN LSTM CTC +============= diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 23e224f76..4d45d197b 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -75,6 +75,23 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + return parser @@ -104,11 +121,6 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - start_epoch: If it is not zero, load checkpoint `start_epoch-1` - and continue training from that checkpoint. - - - num_epochs: Number of epochs to train. - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -127,6 +139,8 @@ def get_params() -> AttributeDict: - log_interval: Print training loss if batch_idx % log_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - valid_interval: Run validation if batch_idx % valid_interval` is 0 - beam_size: It is used in k2.ctc_loss @@ -143,14 +157,13 @@ def get_params() -> AttributeDict: "feature_dim": 80, "weight_decay": 5e-4, "subsampling_factor": 3, - "start_epoch": 0, - "num_epochs": 10, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, "log_interval": 10, + "reset_interval": 200, "valid_interval": 1000, "beam_size": 10, "reduction": "sum", @@ -398,8 +411,12 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_frames = 0.0 # sum of frames over all batches + tot_loss = 0.0 # reset after params.reset_interval of batches + tot_frames = 0.0 # reset after params.reset_interval of batches + + params.tot_loss = 0.0 + params.tot_frames = 0.0 + for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -426,6 +443,9 @@ def train_one_epoch( tot_loss += loss_cpu tot_avg_loss = tot_loss / tot_frames + params.tot_frames += params.train_frames + params.tot_loss += loss_cpu + if batch_idx % params.log_interval == 0: logging.info( f"Epoch {params.cur_epoch}, batch {batch_idx}, " @@ -433,6 +453,22 @@ def train_one_epoch( f"total avg loss: {tot_avg_loss:.4f}, " f"batch size: {batch_size}" ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/current_loss", + loss_cpu / params.train_frames, + params.batch_idx_train, + ) + + tb_writer.add_scalar( + "train/tot_avg_loss", + tot_avg_loss, + params.batch_idx_train, + ) + + if batch_idx > 0 and batch_idx % params.reset_interval == 0: + tot_loss = 0 + tot_frames = 0 if batch_idx > 0 and batch_idx % params.valid_interval == 0: compute_validation_loss( @@ -449,7 +485,7 @@ def train_one_epoch( f"best valid epoch: {params.best_valid_epoch}" ) - params.train_loss = tot_loss / tot_frames + params.train_loss = params.tot_loss / params.tot_frames if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch