WIP: Add doc for the LibriSpeech recipe.

parent 01da00dca0
commit 5b3cd5debd
@@ -3,7 +3,7 @@
    You can adapt this file completely to your liking, but it should at least
    contain the root `toctree` directive.
 
-icefall
+Icefall
 =======
 
 .. image:: _static/logo.png
@@ -1,2 +1,10 @@
 LibriSpeech
 ===========
+
+We provide the following models for the LibriSpeech dataset:
+
+.. toctree::
+   :maxdepth: 2
+
+   librispeech/tdnn_lstm_ctc
+   librispeech/conformer_ctc
docs/source/recipes/librispeech/conformer_ctc.rst (new file)
@@ -0,0 +1,2 @@
+Conformer CTC
+=============
docs/source/recipes/librispeech/tdnn_lstm_ctc.rst (new file)
@@ -0,0 +1,2 @@
+TDNN LSTM CTC
+=============
@@ -75,6 +75,23 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=20,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
     return parser
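For context, a minimal, self-contained sketch of how --start-epoch could drive resumption, following the checkpoint path pattern given in the help text. The helper name and the "model" key inside the checkpoint are assumptions for illustration, not part of this commit:

    import torch

    def maybe_resume(model: torch.nn.Module,
                     start_epoch: int,
                     exp_dir: str = "tdnn_lstm_ctc/exp") -> None:
        # Restore weights from epoch-{start_epoch-1}.pt when start_epoch > 0,
        # matching the behavior described in the --start-epoch help string.
        if start_epoch <= 0:
            return
        ckpt_path = f"{exp_dir}/epoch-{start_epoch - 1}.pt"
        checkpoint = torch.load(ckpt_path, map_location="cpu")  # assumed dict layout
        model.load_state_dict(checkpoint["model"])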
@@ -104,11 +121,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
-          and continue training from that checkpoint.
-
-        - num_epochs: Number of epochs to train.
-
         - best_train_loss: Best training loss so far. It is used to select
           the model that has the lowest training loss. It is
           updated during the training.
@@ -127,6 +139,8 @@ def get_params() -> AttributeDict:
 
         - log_interval: Print training loss if batch_idx % log_interval is 0
 
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
         - valid_interval: Run validation if batch_idx % valid_interval is 0
 
         - beam_size: It is used in k2.ctc_loss
@@ -143,14 +157,13 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "weight_decay": 5e-4,
             "subsampling_factor": 3,
-            "start_epoch": 0,
-            "num_epochs": 10,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 200,
             "valid_interval": 1000,
             "beam_size": 10,
             "reduction": "sum",
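get_params returns an AttributeDict, which is essentially a dict whose keys can also be read and written as attributes. A sketch of the idea, not necessarily icefall's exact implementation:

    class AttributeDict(dict):
        """A dict whose entries are also accessible as attributes (a sketch)."""

        def __getattr__(self, key):
            try:
                return self[key]
            except KeyError:
                raise AttributeError(f"No such attribute: {key}")

        def __setattr__(self, key, value):
            self[key] = value

    params = AttributeDict({"log_interval": 10})
    params.reset_interval = 200  # attribute writes land in the dict
    assert params["reset_interval"] == 200 and params.log_interval == 10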
@@ -398,8 +411,12 @@ def train_one_epoch(
     """
     model.train()
 
-    tot_loss = 0.0  # sum of losses over all batches
-    tot_frames = 0.0  # sum of frames over all batches
+    tot_loss = 0.0  # reset after params.reset_interval of batches
+    tot_frames = 0.0  # reset after params.reset_interval of batches
+
+    params.tot_loss = 0.0
+    params.tot_frames = 0.0
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -426,6 +443,9 @@ def train_one_epoch(
         tot_loss += loss_cpu
         tot_avg_loss = tot_loss / tot_frames
 
+        params.tot_frames += params.train_frames
+        params.tot_loss += loss_cpu
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
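The two pairs of counters serve different purposes: the local tot_loss/tot_frames form a recent-window average that is cleared every reset_interval batches (see the next hunk), while params.tot_loss/params.tot_frames accumulate over the whole epoch and feed params.train_loss at the end. A toy illustration with made-up loss/frame numbers:

    reset_interval = 2
    tot_loss = tot_frames = 0.0      # rolling window, cleared periodically
    epoch_loss = epoch_frames = 0.0  # whole-epoch totals, never cleared
    batches = [(12.0, 4), (9.0, 3), (20.0, 5), (8.0, 4)]  # (loss, frames), made up
    for batch_idx, (loss, frames) in enumerate(batches):
        tot_loss += loss
        tot_frames += frames
        epoch_loss += loss
        epoch_frames += frames
        print(f"batch {batch_idx}: window avg {tot_loss / tot_frames:.3f}, "
              f"epoch avg {epoch_loss / epoch_frames:.3f}")
        if batch_idx > 0 and batch_idx % reset_interval == 0:
            tot_loss = tot_frames = 0.0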
@@ -433,6 +453,22 @@ def train_one_epoch(
                 f"total avg loss: {tot_avg_loss:.4f}, "
                 f"batch size: {batch_size}"
             )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/current_loss",
+                    loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+
+                tb_writer.add_scalar(
+                    "train/tot_avg_loss",
+                    tot_avg_loss,
+                    params.batch_idx_train,
+                )
+
+        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+            tot_loss = 0
+            tot_frames = 0
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
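tb_writer is presumably a torch.utils.tensorboard.SummaryWriter; the same add_scalar(tag, value, step) pattern in isolation (the log directory below is an assumption, not taken from this commit):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter("tdnn_lstm_ctc/exp/tensorboard")  # assumed log dir
    for step in range(1, 101):
        writer.add_scalar("train/current_loss", 1.0 / step, step)  # placeholder values
    writer.close()
    # View with: tensorboard --logdir tdnn_lstm_ctc/exp/tensorboard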
@@ -449,7 +485,7 @@ def train_one_epoch(
             f"best valid epoch: {params.best_valid_epoch}"
         )
 
-    params.train_loss = tot_loss / tot_frames
+    params.train_loss = params.tot_loss / params.tot_frames
 
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch