WIP: Add doc for the LibriSpeech recipe.

Fangjun Kuang 2021-08-24 15:23:44 +08:00
parent 01da00dca0
commit 5b3cd5debd
5 changed files with 59 additions and 11 deletions

View File

@@ -3,7 +3,7 @@
 You can adapt this file completely to your liking, but it should at least
 contain the root `toctree` directive.
 
-icefall
+Icefall
 =======
 
 .. image:: _static/logo.png

View File

@@ -1,2 +1,10 @@
 LibriSpeech
 ===========
+
+We provide the following models for the LibriSpeech dataset:
+
+.. toctree::
+   :maxdepth: 2
+
+   librispeech/tdnn_lstm_ctc
+   librispeech/conformer_ctc
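The two toctree entries are file paths relative to this document, minus the .rst extension, so they resolve to the librispeech/tdnn_lstm_ctc.rst and librispeech/conformer_ctc.rst stubs created below; this is standard Sphinx toctree behavior.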

View File

@@ -0,0 +1,2 @@
+Conformer CTC
+=============

View File

@@ -0,0 +1,2 @@
+TDNN LSTM CTC
+=============

View File

@@ -75,6 +75,23 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=20,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
     return parser
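The new --start-epoch flag encodes the resume convention described in its help string: a positive value means "continue from the checkpoint written at the end of epoch start_epoch - 1". Below is a minimal sketch of that logic, assuming a checkpoint file that stores the model state under a "model" key; the helper name and checkpoint layout are illustrative, not the commit's actual code.

    from pathlib import Path

    import torch

    def maybe_resume(start_epoch: int, exp_dir: Path, model: torch.nn.Module) -> None:
        # A value of 0 means "train from scratch"; nothing to load.
        if start_epoch <= 0:
            return
        # Per the --start-epoch help text, resume from epoch-{start_epoch-1}.pt.
        ckpt = torch.load(exp_dir / f"epoch-{start_epoch - 1}.pt", map_location="cpu")
        model.load_state_dict(ckpt["model"])

For example, maybe_resume(args.start_epoch, Path("tdnn_lstm_ctc/exp"), model) with --start-epoch=5 would pick up epoch-4.pt.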
@@ -104,11 +121,6 @@ def get_params() -> AttributeDict:
     - subsampling_factor: The subsampling factor for the model.
 
-    - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
-      and continue training from that checkpoint.
-
-    - num_epochs: Number of epochs to train.
-
     - best_train_loss: Best training loss so far. It is used to select
       the model that has the lowest training loss. It is
       updated during the training.
@@ -127,6 +139,8 @@ def get_params() -> AttributeDict:
     - log_interval: Print training loss if batch_idx % log_interval is 0
 
+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
     - valid_interval: Run validation if batch_idx % valid_interval is 0
 
     - beam_size: It is used in k2.ctc_loss
@@ -143,14 +157,13 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "weight_decay": 5e-4,
             "subsampling_factor": 3,
-            "start_epoch": 0,
-            "num_epochs": 10,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 200,
             "valid_interval": 1000,
             "beam_size": 10,
             "reduction": "sum",
@@ -398,8 +411,12 @@ def train_one_epoch(
     """
     model.train()
 
-    tot_loss = 0.0  # sum of losses over all batches
-    tot_frames = 0.0  # sum of frames over all batches
+    tot_loss = 0.0  # reset after params.reset_interval of batches
+    tot_frames = 0.0  # reset after params.reset_interval of batches
+
+    params.tot_loss = 0.0
+    params.tot_frames = 0.0
+
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
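The accumulation is now two-level: the local tot_loss / tot_frames become a moving window that is cleared every params.reset_interval batches (see the reset further down), so the logged "total avg loss" reflects recent batches rather than the whole epoch, while the new params.tot_loss / params.tot_frames keep growing across the epoch so that the final epoch-level loss stays exact.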
@@ -426,6 +443,9 @@ def train_one_epoch(
         tot_loss += loss_cpu
         tot_avg_loss = tot_loss / tot_frames
 
+        params.tot_frames += params.train_frames
+        params.tot_loss += loss_cpu
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
@@ -433,6 +453,22 @@ def train_one_epoch(
                 f"total avg loss: {tot_avg_loss:.4f}, "
                 f"batch size: {batch_size}"
             )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/current_loss",
+                    loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/tot_avg_loss",
+                    tot_avg_loss,
+                    params.batch_idx_train,
+                )
+
+        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+            tot_loss = 0
+            tot_frames = 0
+
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
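The two add_scalar calls use the standard torch.utils.tensorboard API (tag, scalar value, global step). Here is a self-contained sketch of the writer these calls expect, together with the windowed-average bookkeeping this hunk introduces; the log directory and batch statistics below are made up for illustration.

    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter(log_dir="tdnn_lstm_ctc/exp/tensorboard")  # path assumed

    tot_loss = 0.0
    tot_frames = 0.0
    for batch_idx in range(1000):
        frames, loss = 100.0, 55.0  # stand-ins for real per-batch statistics
        tot_loss += loss
        tot_frames += frames
        if batch_idx % 10 == 0:  # log_interval
            tb_writer.add_scalar("train/current_loss", loss / frames, batch_idx)
            tb_writer.add_scalar("train/tot_avg_loss", tot_loss / tot_frames, batch_idx)
        if batch_idx > 0 and batch_idx % 200 == 0:  # reset_interval
            # Clearing the window keeps tot_avg_loss an average over recent
            # batches instead of the whole epoch.
            tot_loss = 0.0
            tot_frames = 0.0

In TensorBoard, train/current_loss is the per-batch average frame loss and train/tot_avg_loss its smoothed, windowed counterpart.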
@@ -449,7 +485,7 @@ def train_one_epoch(
         f"best valid epoch: {params.best_valid_epoch}"
     )
 
-    params.train_loss = tot_loss / tot_frames
+    params.train_loss = params.tot_loss / params.tot_frames
 
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
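This last change is the necessary counterpart of the windowed reset above: once tot_loss / tot_frames are periodically zeroed, dividing them at the end of the epoch would only average the final window, so the epoch-level train_loss (used to track best_train_loss / best_train_epoch) must come from the epoch-wide params.tot_loss / params.tot_frames instead.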