Mirror of https://github.com/k2-fsa/icefall.git

WIP: Add doc for the LibriSpeech recipe.

commit 5b3cd5debd (parent 01da00dca0)
@@ -3,7 +3,7 @@
 You can adapt this file completely to your liking, but it should at least
 contain the root `toctree` directive.

-icefall
+Icefall
 =======

 .. image:: _static/logo.png
@@ -1,2 +1,10 @@
 LibriSpeech
 ===========
+
+We provide the following models for the LibriSpeech dataset:
+
+.. toctree::
+   :maxdepth: 2
+
+   librispeech/tdnn_lstm_ctc
+   librispeech/conformer_ctc
docs/source/recipes/librispeech/conformer_ctc.rst (new file)
@@ -0,0 +1,2 @@
+Conformer CTC
+=============
docs/source/recipes/librispeech/tdnn_lstm_ctc.rst (new file)
@@ -0,0 +1,2 @@
+TDNN LSTM CTC
+=============
@@ -75,6 +75,23 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=20,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )

     return parser
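The new --start-epoch option documents a resume path of the form
tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt. Below is a minimal sketch of how
such an option can be turned into a checkpoint load; the helper name and the
"model" key are assumptions for illustration, not the verified icefall code
path:

from pathlib import Path

import torch


def maybe_resume(model: torch.nn.Module, start_epoch: int,
                 exp_dir: str = "tdnn_lstm_ctc/exp") -> None:
    """If start_epoch is positive, load epoch-{start_epoch-1}.pt and resume."""
    if start_epoch <= 0:
        return  # fresh run, nothing to load
    ckpt_path = Path(exp_dir) / f"epoch-{start_epoch - 1}.pt"
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])  # "model" key is an assumption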
@@ -104,11 +121,6 @@ def get_params() -> AttributeDict:
     - subsampling_factor: The subsampling factor for the model.

-    - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
-      and continue training from that checkpoint.
-
-    - num_epochs: Number of epochs to train.
-
     - best_train_loss: Best training loss so far. It is used to select
       the model that has the lowest training loss. It is
       updated during the training.
@@ -127,6 +139,8 @@ def get_params() -> AttributeDict:
     - log_interval: Print training loss if batch_idx % log_interval is 0

+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
     - valid_interval: Run validation if batch_idx % valid_interval is 0

     - beam_size: It is used in k2.ctc_loss
@@ -143,14 +157,13 @@ def get_params() -> AttributeDict:
         "feature_dim": 80,
         "weight_decay": 5e-4,
         "subsampling_factor": 3,
-        "start_epoch": 0,
-        "num_epochs": 10,
         "best_train_loss": float("inf"),
         "best_valid_loss": float("inf"),
         "best_train_epoch": -1,
         "best_valid_epoch": -1,
         "batch_idx_train": 0,
         "log_interval": 10,
+        "reset_interval": 200,
         "valid_interval": 1000,
         "beam_size": 10,
         "reduction": "sum",
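Since start_epoch and num_epochs are no longer defaults in this dict but
command-line options, the parsed arguments have to be merged into the
AttributeDict returned by get_params(). A small sketch of one way to do that
merge, assuming AttributeDict supports the ordinary dict update() method (an
assumption, not a quote of the script):

# Hypothetical merge of CLI options into the training parameters.
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))  # assumed: AttributeDict behaves like a dict

print(params.num_epochs, params.start_epoch)  # now reflect --num-epochs / --start-epoch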
@@ -398,8 +411,12 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # sum of losses over all batches
-    tot_frames = 0.0  # sum of frames over all batches
+    tot_loss = 0.0  # reset after params.reset_interval of batches
+    tot_frames = 0.0  # reset after params.reset_interval of batches
+
+    params.tot_loss = 0.0
+    params.tot_frames = 0.0
+
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
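The hunk above keeps two sets of running sums: tot_loss/tot_frames form a
short window that is cleared every params.reset_interval batches, so the
printed average tracks recent batches, while params.tot_loss/params.tot_frames
accumulate over the whole epoch and later feed params.train_loss. A
stripped-down, self-contained sketch of the same bookkeeping with made-up
per-batch numbers:

# Illustrative per-batch (loss, num_frames) pairs.
batches = [(250.0, 100.0), (230.0, 100.0), (210.0, 100.0), (190.0, 100.0)]
reset_interval = 2  # tiny value so the reset actually triggers here

window_loss = window_frames = 0.0  # cleared every reset_interval batches
epoch_loss = epoch_frames = 0.0    # accumulated over the whole epoch

for batch_idx, (loss, frames) in enumerate(batches):
    window_loss += loss
    window_frames += frames
    epoch_loss += loss
    epoch_frames += frames
    print(f"batch {batch_idx}: windowed avg loss {window_loss / window_frames:.3f}")

    if batch_idx > 0 and batch_idx % reset_interval == 0:
        window_loss = window_frames = 0.0

print(f"epoch-level train loss: {epoch_loss / epoch_frames:.3f}")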
@@ -426,6 +443,9 @@ def train_one_epoch(
         tot_loss += loss_cpu
         tot_avg_loss = tot_loss / tot_frames

+        params.tot_frames += params.train_frames
+        params.tot_loss += loss_cpu
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
@@ -433,6 +453,22 @@ def train_one_epoch(
                 f"total avg loss: {tot_avg_loss:.4f}, "
                 f"batch size: {batch_size}"
             )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/current_loss",
+                    loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+
+                tb_writer.add_scalar(
+                    "train/tot_avg_loss",
+                    tot_avg_loss,
+                    params.batch_idx_train,
+                )
+
+        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+            tot_loss = 0
+            tot_frames = 0

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
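The tb_writer.add_scalar calls above log the current-batch average loss and
the windowed average under separate tags. For reference, a minimal sketch of
creating such a writer with PyTorch's TensorBoard support and emitting
scalars; the log directory is illustrative:

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="tdnn_lstm_ctc/exp/tensorboard")  # illustrative path
tb_writer.add_scalar("train/current_loss", 0.42, 100)  # tag, value, global step
tb_writer.add_scalar("train/tot_avg_loss", 0.40, 100)
tb_writer.close()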
@@ -449,7 +485,7 @@ def train_one_epoch(
                 f"best valid epoch: {params.best_valid_epoch}"
             )

-    params.train_loss = tot_loss / tot_frames
+    params.train_loss = params.tot_loss / params.tot_frames

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
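With params.train_loss now computed from the whole-epoch sums rather than the
windowed ones, the best-loss bookkeeping compares like with like across
epochs. A short sketch of the epoch-end step implied by the last lines of the
hunk; the assignment of best_train_loss itself is not shown in the diff, so it
is completed here as an assumption:

def update_best(params) -> None:
    """Record the best epoch-level training loss seen so far (sketch)."""
    params.train_loss = params.tot_loss / params.tot_frames
    if params.train_loss < params.best_train_loss:
        params.best_train_loss = params.train_loss  # assumed; not shown in the diff
        params.best_train_epoch = params.cur_epoch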