Fix convert_texts_into_ids() in the tedlium3 recipe.

This commit is contained in:
Fangjun Kuang 2025-04-24 14:50:51 +08:00
parent 5ec95e5482
commit d1312e4d3e
2 changed files with 2 additions and 2 deletions

View File

@ -422,7 +422,7 @@ def compute_loss(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
unk_id = params.unk_id unk_id = params.unk_id
y = convert_texts_into_ids(texts, unk_id, sp=sp) y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):

View File

@ -397,7 +397,7 @@ def compute_loss(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
unk_id = params.unk_id unk_id = params.unk_id
y = convert_texts_into_ids(texts, unk_id, sp=sp) y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):