mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Update train_char.py
This commit is contained in:
parent
4413713a05
commit
9bf88ac3b1
@ -97,6 +97,7 @@ from icefall.utils import (
|
|||||||
get_parameter_groups_with_lrs,
|
get_parameter_groups_with_lrs,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
|
tokenize_by_CJK_char,
|
||||||
)
|
)
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -356,7 +357,7 @@ def compute_loss(
|
|||||||
batch_idx_train = params.batch_idx_train
|
batch_idx_train = params.batch_idx_train
|
||||||
warm_step = params.warm_step
|
warm_step = params.warm_step
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = [tokenize_by_CJK_char(text) for text in batch["supervisions"]["text"]]
|
||||||
y = graph_compiler.texts_to_ids(texts)
|
y = graph_compiler.texts_to_ids(texts)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user