Update train_char.py

This commit is contained in:
jinzr 2024-03-13 12:01:34 +08:00
parent 4413713a05
commit 9bf88ac3b1

View File

@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
tokenize_by_CJK_char,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -356,7 +357,7 @@ def compute_loss(
batch_idx_train = params.batch_idx_train batch_idx_train = params.batch_idx_train
warm_step = params.warm_step warm_step = params.warm_step
texts = batch["supervisions"]["text"] texts = [tokenize_by_CJK_char(text) for text in batch["supervisions"]["text"]]
y = graph_compiler.texts_to_ids(texts) y = graph_compiler.texts_to_ids(texts)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)