mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
removed batch_name
to fix a KeyError with "uttid" (#1172)
This commit is contained in:
parent
5ed6fc0e6d
commit
4ab7d61008
@ -675,7 +675,6 @@ def train_one_epoch(
|
|||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
batch_name = batch["supervisions"]["uttid"]
|
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
@ -698,10 +697,7 @@ def train_one_epoch(
|
|||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "CUDA out of memory" in str(e):
|
if "CUDA out of memory" in str(e):
|
||||||
logging.error(
|
logging.error(f"failing batch size:{batch_size} ")
|
||||||
f"failing batch size:{batch_size} "
|
|
||||||
f"failing batch names {batch_name}"
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
scheduler.step_batch(params.batch_idx_train)
|
scheduler.step_batch(params.batch_idx_train)
|
||||||
@ -756,10 +752,7 @@ def train_one_epoch(
|
|||||||
if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float(
|
if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float(
|
||||||
"inf"
|
"inf"
|
||||||
):
|
):
|
||||||
logging.error(
|
logging.error("Your loss contains inf, something goes wrong")
|
||||||
"Your loss contains inf, something goes wrong"
|
|
||||||
f"failing batch names {batch_name}"
|
|
||||||
)
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||||
|
Loading…
x
Reference in New Issue
Block a user