WIP: Add doc for the LibriSpeech recipe.

Fangjun Kuang 2021-08-24 15:23:44 +08:00
parent 01da00dca0
commit 5b3cd5debd
5 changed files with 59 additions and 11 deletions

View File

@@ -3,7 +3,7 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
icefall
Icefall
=======
.. image:: _static/logo.png

View File

@@ -1,2 +1,10 @@
LibriSpeech
===========
We provide the following models for the LibriSpeech dataset:
.. toctree::
:maxdepth: 2
librispeech/tdnn_lstm_ctc
librispeech/conformer_ctc

View File

@@ -0,0 +1,2 @@
Conformer CTC
=============

View File

@@ -0,0 +1,2 @@
TDNN LSTM CTC
=============

View File

@@ -75,6 +75,23 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=20,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
return parser
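
For readers following the new flags: below is a minimal sketch of how --start-epoch could drive resumption. The load_checkpoint helper and maybe_resume wrapper are hypothetical illustrations, not the recipe's actual API; only the epoch-{start_epoch-1}.pt layout comes from the help text above.

import argparse
from pathlib import Path

import torch


def load_checkpoint(filename: Path, model: torch.nn.Module) -> None:
    # Hypothetical helper: icefall ships its own checkpoint utilities,
    # whose exact signatures may differ from this sketch.
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint["model"])


def maybe_resume(args: argparse.Namespace, model: torch.nn.Module) -> None:
    # start_epoch == 0 trains from scratch; a positive value resumes from
    # the checkpoint written at the end of the previous epoch.
    if args.start_epoch > 0:
        ckpt = Path("tdnn_lstm_ctc/exp") / f"epoch-{args.start_epoch - 1}.pt"
        load_checkpoint(ckpt, model)
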
@@ -104,11 +121,6 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@@ -127,6 +139,8 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if `batch_idx % log_interval` is 0
- reset_interval: Reset statistics if `batch_idx % reset_interval` is 0
- valid_interval: Run validation if `batch_idx % valid_interval` is 0
- beam_size: It is used in `k2.ctc_loss`.
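
An editor's sketch (not part of this commit) of how the three intervals documented above interact inside the batch loop, using the default values listed below in get_params():

log_interval = 10       # log the running loss every 10 batches
reset_interval = 200    # restart the windowed statistics every 200 batches
valid_interval = 1000   # run validation every 1000 batches

for batch_idx in range(2001):  # stand-in for enumerate(train_dl)
    if batch_idx % log_interval == 0:
        pass  # log the current and windowed-average losses here
    if batch_idx > 0 and batch_idx % reset_interval == 0:
        pass  # zero tot_loss / tot_frames so the window starts fresh
    if batch_idx > 0 and batch_idx % valid_interval == 0:
        pass  # compute_validation_loss(...) would run here
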
@@ -143,14 +157,13 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"weight_decay": 5e-4,
"subsampling_factor": 3,
"start_epoch": 0,
"num_epochs": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"reset_interval": 200,
"valid_interval": 1000,
"beam_size": 10,
"reduction": "sum",
@@ -398,8 +411,12 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
tot_loss = 0.0 # reset every params.reset_interval batches
tot_frames = 0.0 # reset every params.reset_interval batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
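
The refactor in this hunk keeps two levels of accumulation: the local tot_loss/tot_frames feed a windowed average that is reset every reset_interval batches, while params.tot_loss/params.tot_frames keep growing so the epoch-level loss survives the resets. A self-contained toy illustration (fake per-batch losses, not recipe code):

reset_interval = 200

tot_loss = 0.0        # windowed: reset periodically for logging
tot_frames = 0.0
epoch_loss = 0.0      # epoch-wide: stands in for params.tot_loss
epoch_frames = 0.0    # stands in for params.tot_frames

for batch_idx in range(1000):
    loss, frames = 1.0 / (batch_idx + 1), 100.0  # made-up batch stats
    tot_loss += loss
    tot_frames += frames
    epoch_loss += loss
    epoch_frames += frames
    if batch_idx > 0 and batch_idx % reset_interval == 0:
        tot_loss = 0.0    # the windowed average starts over...
        tot_frames = 0.0

train_loss = epoch_loss / epoch_frames  # ...but the epoch average survives
print(f"epoch-level avg loss per frame: {train_loss:.6f}")
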
@@ -426,6 +443,9 @@
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
@@ -433,6 +453,22 @@
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0
tot_frames = 0
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
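
The add_scalar calls added above follow the standard torch.utils.tensorboard API. A runnable toy version, with an arbitrary log directory chosen for this sketch and fabricated loss values:

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="tdnn_lstm_ctc/exp/tensorboard")

batch_idx_train = 0
for step in range(100):
    batch_idx_train += 1
    current_loss = 1.0 / batch_idx_train  # fake per-batch average loss
    # Tag, scalar value, global step: the same pattern as in the diff.
    tb_writer.add_scalar("train/current_loss", current_loss, batch_idx_train)
    tb_writer.add_scalar("train/tot_avg_loss", current_loss, batch_idx_train)
tb_writer.close()
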
@@ -449,7 +485,7 @@
f"best valid epoch: {params.best_valid_epoch}"
)
params.train_loss = tot_loss / tot_frames
params.train_loss = params.tot_loss / params.tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
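
The epoch-level bookkeeping the hunk ends with reduces to: after each epoch, remember the epoch whose training loss is lowest so far. A toy distillation with made-up losses (the same pattern applies to best_valid_loss):

best_train_loss = float("inf")
best_train_epoch = -1

for cur_epoch, train_loss in enumerate([0.9, 0.7, 0.8, 0.6]):  # fake losses
    if train_loss < best_train_loss:
        best_train_loss = train_loss
        best_train_epoch = cur_epoch  # track the best epoch seen so far

assert best_train_epoch == 3
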