mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Display losses for GigaSpeech and LibriSpeech separately.
This commit is contained in:
parent
018d03cd08
commit
1930d72b17
@ -516,6 +516,8 @@ def train_one_epoch(
|
|||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
libri_tot_loss = MetricsTracker()
|
||||||
|
giga_tot_loss = MetricsTracker()
|
||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
# index 0: for LibriSpeech
|
# index 0: for LibriSpeech
|
||||||
@ -542,6 +544,8 @@ def train_one_epoch(
|
|||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
|
libri = is_libri(batch["supervisions"]["cut"][0])
|
||||||
|
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
@ -551,6 +555,16 @@ def train_one_epoch(
|
|||||||
)
|
)
|
||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
|
if libri:
|
||||||
|
libri_tot_loss = (
|
||||||
|
libri_tot_loss * (1 - 1 / params.reset_interval)
|
||||||
|
) + loss_info
|
||||||
|
prefix = "libri" # for logging only
|
||||||
|
else:
|
||||||
|
giga_tot_loss = (
|
||||||
|
giga_tot_loss * (1 - 1 / params.reset_interval)
|
||||||
|
) + loss_info
|
||||||
|
prefix = "giga"
|
||||||
|
|
||||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||||
# in the batch and there is no normalization to it so far.
|
# in the batch and there is no normalization to it so far.
|
||||||
@ -563,19 +577,29 @@ def train_one_epoch(
|
|||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, "
|
f"Epoch {params.cur_epoch}, "
|
||||||
f"batch {batch_idx}, loss[{loss_info}], "
|
f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
|
||||||
f"tot_loss[{tot_loss}], batch size: {batch_size}"
|
f"tot_loss[{tot_loss}], "
|
||||||
|
f"libri_tot_loss[{libri_tot_loss}], "
|
||||||
|
f"giga_tot_loss[{giga_tot_loss}], "
|
||||||
|
f"batch size: {batch_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
loss_info.write_summary(
|
loss_info.write_summary(
|
||||||
tb_writer, "train/current_", params.batch_idx_train
|
tb_writer,
|
||||||
|
f"train/current_{prefix}_",
|
||||||
|
params.batch_idx_train,
|
||||||
)
|
)
|
||||||
tot_loss.write_summary(
|
tot_loss.write_summary(
|
||||||
tb_writer, "train/tot_", params.batch_idx_train
|
tb_writer, "train/tot_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
libri_tot_loss.write_summary(
|
||||||
|
tb_writer, "train/libri_tot_", params.batch_idx_train
|
||||||
|
)
|
||||||
|
giga_tot_loss.write_summary(
|
||||||
|
tb_writer, "train/giga_tot_", params.batch_idx_train
|
||||||
|
)
|
||||||
|
|
||||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
@ -738,7 +762,9 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts += librispeech.dev_other_cuts()
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
|
valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
for dl in [train_dl, giga_train_dl]:
|
# It's time-consuming to include `giga_train_dl` here, so only `train_dl` is scanned
|
||||||
|
# for dl in [train_dl, giga_train_dl]:
|
||||||
|
for dl in [train_dl]:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
train_dl=dl,
|
train_dl=dl,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user