mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
removed batch_name to fix a KeyError with "uttid" (#1172)
This commit is contained in:
parent
5ed6fc0e6d
commit
4ab7d61008
@@ -675,7 +675,6 @@ def train_one_epoch(
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
batch_name = batch["supervisions"]["uttid"]
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
@@ -698,10 +697,7 @@ def train_one_epoch(
|
||||
scaler.scale(loss).backward()
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
f"failing batch size:{batch_size} "
|
||||
f"failing batch names {batch_name}"
|
||||
)
|
||||
logging.error(f"failing batch size:{batch_size} ")
|
||||
raise
|
||||
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
@@ -756,10 +752,7 @@ def train_one_epoch(
|
||||
if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float(
|
||||
"inf"
|
||||
):
|
||||
logging.error(
|
||||
"Your loss contains inf, something goes wrong"
|
||||
f"failing batch names {batch_name}"
|
||||
)
|
||||
logging.error("Your loss contains inf, something goes wrong")
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
|
Loading…
x
Reference in New Issue
Block a user