This commit is contained in:
Piotr Żelasko 2022-01-21 17:27:02 -05:00
parent f0f35e6671
commit 1d5fe8afa4

View File

@ -38,11 +38,11 @@ from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -595,8 +595,9 @@ def run(rank, world_size, args):
) )
elif "lang_phone" in params.lang_dir: elif "lang_phone" in params.lang_dir:
assert params.att_rate == 0, ( assert params.att_rate == 0, (
"Attention decoder training does not support phone lang dirs at this time due to a missing " "Attention decoder training does not support phone lang dirs "
"<sos/eos> symbol. Set --att-rate=0 for pure CTC training when using a phone-based lang dir." "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
"for pure CTC training when using a phone-based lang dir."
) )
graph_compiler = CtcTrainingGraphCompiler( graph_compiler = CtcTrainingGraphCompiler(
lexicon, lexicon,
@ -608,8 +609,8 @@ def run(rank, world_size, args):
graph_compiler.eos_id = 1 graph_compiler.eos_id = 1
else: else:
raise ValueError( raise ValueError(
f"Unsupported type of lang dir (we expected it to have 'lang_bpe' or 'lang_phone' " f"Unsupported type of lang dir (we expected it to have "
f"in its name): {params.lang_dir}" f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
) )
logging.info("About to create model") logging.info("About to create model")