Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)

Use LossRecord to record and print the loss for the training process (#62)

* Update index.rst (AS->ASR)
* Update conformer_ctc.rst (pretraind->pretrained)
* Fix some spelling errors.
* Fix some spelling errors.
* Use LossRecord to record and print loss in the training process
* Change the name "LossRecord" to "MetricsTracker"

This commit is contained in:
parent beb54ddb61
commit 597c5efdb1
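Beyond the rename, the main behavioral change in this commit: the old training loops zeroed their running loss totals every `reset_interval` batches, while the new code decays the running totals by `1 - 1/reset_interval` on every batch. A small self-contained sketch (mine, not code from the commit) of why the decayed sum behaves like the old windowed sum, without the abrupt reset:

reset_interval = 200
tot = 0.0  # running, decayed total of per-batch summed losses
for batch_loss in [1.0] * 2000:
    tot = tot * (1 - 1 / reset_interval) + batch_loss
# For a constant per-batch loss the total converges to
# reset_interval * batch_loss, so total / frames still tracks
# the recent per-frame loss.
print(tot)  # ~200.0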
(file: egs/librispeech/ASR/conformer_ctc/train.py, inferred from its imports)

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                            Wei Kang)
+#                                            Wei Kang
+#                                            Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,13 +22,15 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from torch import Tensor
+
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
@@ -43,6 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -287,7 +291,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.

@@ -367,15 +371,17 @@ def compute_loss(
         loss = ctc_loss
         att_loss = torch.tensor([0])

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
-
     assert loss.requires_grad == is_training

-    return loss, ctc_loss.detach(), att_loss.detach()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info


 def compute_validation_loss(
@@ -384,18 +390,14 @@ def compute_validation_loss(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
-    """
+) -> MetricsTracker:
+    """Run the validation process."""
     model.eval()

-    tot_loss = 0.0
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = MetricsTracker()
+
     for batch_idx, batch in enumerate(valid_dl):
-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -403,36 +405,17 @@ def compute_validation_loss(
             is_training=False,
         )
         assert loss.requires_grad is False
-        assert ctc_loss.requires_grad is False
-        assert att_loss.requires_grad is False
-
-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-
-        tot_ctc_loss += ctc_loss.detach().cpu().item()
-        tot_att_loss += att_loss.detach().cpu().item()
-
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor(
-            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
-            device=loss.device,
-        )
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_ctc_loss = s[1]
-        tot_att_loss = s[2]
-        tot_frames = s[3]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
-    params.valid_ctc_loss = tot_ctc_loss / tot_frames
-    params.valid_att_loss = tot_att_loss / tot_frames
-
-    if params.valid_loss < params.best_valid_loss:
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -471,24 +454,21 @@ def train_one_epoch(
     """
    model.train()

-    tot_loss = 0.0  # sum of losses over all batches
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
+    tot_loss = MetricsTracker()

-    tot_frames = 0.0  # sum of frames over all batches
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
@@ -498,75 +478,26 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-        ctc_loss_cpu = ctc_loss.detach().cpu().item()
-        att_loss_cpu = att_loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss_cpu
-        tot_att_loss += att_loss_cpu
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
-        tot_avg_loss = tot_loss / tot_frames
-        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
-        tot_avg_att_loss = tot_att_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
-                f"total avg att loss: {tot_avg_att_loss:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

         if batch_idx % 10 == 0:
-
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_ctc_loss",
-                    ctc_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_att_loss",
-                    att_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_ctc_loss",
-                    tot_avg_ctc_loss,
-                    params.batch_idx_train,
-                )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_att_loss",
-                    tot_avg_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0.0  # sum of losses over all batches
-            tot_ctc_loss = 0.0
-            tot_att_loss = 0.0
-
-            tot_frames = 0.0  # sum of frames over all batches
-
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -574,33 +505,14 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"valid ctc loss {params.valid_ctc_loss:.4f},"
-                f"valid att loss {params.valid_att_loss:.4f},"
-                f"valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_ctc_loss",
-                    params.valid_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_att_loss",
-                    params.valid_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
-                    params.batch_idx_train,
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )

-    params.train_loss = params.tot_loss / params.tot_frames
-
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
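A toy illustration (my sketch, not code from this commit) of the calling convention the file above now follows: compute_loss returns the differentiable summed loss for the optimizer, plus a MetricsTracker of detached scalars for logging, so callers never touch graph-bearing tensors when printing:

from typing import Tuple

import torch
from torch import Tensor

from icefall.utils import MetricsTracker  # class added at the end of this commit


def compute_loss_toy(
    summed_loss: Tensor, num_frames: int
) -> Tuple[Tensor, MetricsTracker]:
    # keep the autograd graph on `summed_loss`; detach scalars for printing
    info = MetricsTracker()
    info["frames"] = num_frames  # the normalization key used by norm_items()
    info["loss"] = summed_loss.detach().cpu().item()
    return summed_loss, info


loss, info = compute_loss_toy(torch.tensor(250.0, requires_grad=True), 500)
print(info)  # loss=0.5, over 500 frames.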
(file: egs/librispeech/ASR/prepare.sh, inferred from the stage contents)

@@ -57,13 +57,13 @@ log() {
 log "dl_dir: $dl_dir"

 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-  log "stage -1: Download LM"
+  log "Stage -1: Download LM"
   [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
   ./local/download_lm.py --out-dir=$dl_dir/lm
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"

   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
@@ -126,7 +126,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "State 6: Prepare BPE based lang"
+  log "Stage 6: Prepare BPE based lang"

   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
(file: egs/librispeech/ASR/tdnn_lstm_ctc/train.py, inferred from the TdnnLstm import)

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,14 +21,15 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor

 from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
@@ -43,6 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -267,7 +270,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.

@@ -324,13 +327,11 @@ def compute_loss(

     assert loss.requires_grad == is_training

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()

-    return loss
+    return loss, info


 def compute_validation_loss(
@@ -339,16 +340,16 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()

-    tot_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = MetricsTracker()
+
     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -357,22 +358,18 @@ def compute_validation_loss(
         )
         assert loss.requires_grad is False

-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_frames = s[1]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]

-    if params.valid_loss < params.best_valid_loss:
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -411,67 +408,45 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # reset after params.reset_interval of batches
-    tot_frames = 0.0  # reset after params.reset_interval of batches
-
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
+        # summary stats.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         optimizer.zero_grad()
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_avg_loss = tot_loss / tot_frames
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
         if batch_idx % 10 == 0:
-
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )

-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0
-            tot_frames = 0
-
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -479,13 +454,16 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
+                    params.batch_idx_train,
+                )

-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
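The `tot_loss.reduce(loss.device)` calls above replace the hand-rolled pack/all-reduce/unpack that each script previously repeated. For reference, a generic sketch (mine) of what that one call encapsulates; it assumes an initialized torch.distributed process group:

import torch
import torch.distributed as dist


def reduce_metrics(metrics: dict, device: torch.device) -> None:
    # pack values in a stable key order, sum across workers, unpack
    keys = sorted(metrics.keys())
    s = torch.tensor([float(metrics[k]) for k in keys], device=device)
    dist.all_reduce(s, op=dist.ReduceOp.SUM)
    for k, v in zip(keys, s.cpu().tolist()):
        metrics[k] = v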
(file: egs/yesno/ASR/prepare.sh, inferred from the waves_yesno reference)

@@ -24,7 +24,7 @@ log() {
 log "dl_dir: $dl_dir"

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"
   mkdir -p $dl_dir

   if [ ! -f $dl_dir/waves_yesno/.completed ]; then
(file: egs/yesno/ASR/tdnn/train.py, inferred from the YesNoAsrDataModule and Tdnn imports)

@@ -4,14 +4,14 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
@@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -122,6 +122,8 @@ def get_params() -> AttributeDict:

     - valid_interval: Run validation if batch_idx % valid_interval` is 0

+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
     - beam_size: It is used in k2.ctc_loss

     - reduction: It is used in k2.ctc_loss
@@ -142,6 +144,7 @@ def get_params() -> AttributeDict:
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 20,
             "valid_interval": 10,
             "beam_size": 10,
             "reduction": "sum",
@@ -245,7 +248,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.

@@ -305,13 +308,11 @@ def compute_loss(

     assert loss.requires_grad == is_training

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()

-    return loss
+    return loss, info


 def compute_validation_loss(
@@ -320,16 +321,16 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()

-    tot_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = MetricsTracker()
+
     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -338,22 +339,18 @@ def compute_validation_loss(
         )
         assert loss.requires_grad is False

-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_frames = s[1]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]

-    if params.valid_loss < params.best_valid_loss:
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -392,57 +389,45 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # sum of losses over all batches
-    tot_frames = 0.0  # sum of frames over all batches
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
+        # summary stats.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         optimizer.zero_grad()
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_avg_loss = tot_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
            )
         if batch_idx % 10 == 0:
-
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -450,19 +435,16 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
+                valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
                     params.batch_idx_train,
                 )

-    params.train_loss = tot_loss / tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
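To make the new log lines concrete: MetricsTracker.__str__ (defined in the next file) prints each metric divided by the accumulated frame count. A small sketch (mine) of what the `loss[{loss_info}]` fragments in the loops above render as:

from icefall.utils import MetricsTracker

info = MetricsTracker()
info["frames"] = 10000
info["ctc_loss"] = 4200.0
info["att_loss"] = 1800.0
info["loss"] = 3480.0
print(f"loss[{info}]")
# loss[ctc_loss=0.42, att_loss=0.18, loss=0.348, over 10000 frames.]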
(file: icefall/utils.py, inferred from the `from icefall.utils import MetricsTracker` lines above)

@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -17,6 +18,7 @@

 import argparse
 import logging
+import collections
 import os
 import subprocess
 from collections import defaultdict
@@ -29,6 +31,7 @@ import k2
 import kaldialign
 import torch
 import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter

 Pathlike = Union[str, Path]

@@ -166,8 +169,8 @@ def encode_supervisions(
     supervisions: dict, subsampling_factor: int
 ) -> Tuple[torch.Tensor, List[str]]:
     """
-    Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
-    and a list of transcription strings.
+    Encodes Lhotse's ``batch["supervisions"]`` dict into
+    a pair of torch Tensor, and a list of transcription strings.

     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
@@ -272,13 +275,13 @@ def write_error_stats(
         Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
         reference words (2337 correct)

-    - The difference between the reference transcript and predicted results.
+    - The difference between the reference transcript and predicted result.
       An instance is given below::

         THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

-    The above example shows that the reference word is `EDISON`, but it is
-    predicted to `ADDISON` (a substitution error).
+    The above example shows that the reference word is `EDISON`,
+    but it is predicted to `ADDISON` (a substitution error).

     Another example is::

@@ -419,3 +422,76 @@ def write_error_stats(

     print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     return float(tot_err_rate)
+
+
+class MetricsTracker(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        # This class will play a role as metrics tracker.
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)
+
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ""
+        for k, v in self.norm_items():
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        frames = str(self["frames"])
+        ans += "over " + frames + " frames."
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self["frames"] if "frames" in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != "frames":
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(
+        self,
+        tb_writer: SummaryWriter,
+        prefix: str,
+        batch_idx: int,
+    ) -> None:
+        """Add logging information to a TensorBoard writer.
+
+        Args:
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_",
+                or "train/current_"
+            batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
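Putting the pieces together, a minimal end-to-end sketch (mine, outside any recipe; the "tb" log directory is arbitrary) of the pattern the training loops in this commit share — decayed accumulation, per-frame normalization, and TensorBoard summaries:

from torch.utils.tensorboard import SummaryWriter

from icefall.utils import MetricsTracker

reset_interval = 200
tot_loss = MetricsTracker()
tb_writer = SummaryWriter("tb")

for batch_idx in range(3):
    loss_info = MetricsTracker()
    loss_info["frames"] = 500  # frames in this batch
    loss_info["loss"] = 250.0  # loss summed over the batch
    tot_loss = (tot_loss * (1 - 1 / reset_interval)) + loss_info
    loss_info.write_summary(tb_writer, "train/current_", batch_idx)
    tot_loss.write_summary(tb_writer, "train/tot_", batch_idx)

print(tot_loss["loss"] / tot_loss["frames"])  # per-frame loss: 0.5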