Removed batch_name to fix a KeyError on the "uttid" supervision field (#1172)

This commit is contained in:
zr_jin 2023-07-15 12:39:32 +08:00 committed by GitHub
parent 5ed6fc0e6d
commit 4ab7d61008
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -675,7 +675,6 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
batch_name = batch["supervisions"]["uttid"]
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
@ -698,10 +697,7 @@ def train_one_epoch(
scaler.scale(loss).backward()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
f"failing batch size:{batch_size} "
f"failing batch names {batch_name}"
)
logging.error(f"failing batch size:{batch_size} ")
raise
scheduler.step_batch(params.batch_idx_train)
@ -756,10 +752,7 @@ def train_one_epoch(
if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float(
"inf"
):
logging.error(
"Your loss contains inf, something goes wrong"
f"failing batch names {batch_name}"
)
logging.error("Your loss contains inf, something goes wrong")
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train