Merge pull request #185 from pzelasko/feature/libri-conformer-phone-ctc

Fix using `lang_phone` in conformer CTC training
commit 8e6fd97c6b
Author: Piotr Żelasko
Date:   2022-01-24 18:08:15 -05:00 (committed by GitHub)
2 changed files with 79 additions and 11 deletions


@@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -123,6 +124,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layers of the transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model).
+        """,
+    )
+
     parser.add_argument(
         "--lr-factor",
         type=float,
@@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
             "use_feat_batchnorm": True,
             "attention_dim": 512,
             "nhead": 8,
-            "num_decoder_layers": 6,
             # parameters for loss
             "beam_size": 10,
             "reduction": "sum",
@@ -357,9 +366,17 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )
 
-    token_ids = graph_compiler.texts_to_ids(texts)
-    decoding_graph = graph_compiler.compile(token_ids)
+    if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+        # Works with a BPE model
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+    elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+        # Works with a phone lexicon
+        decoding_graph = graph_compiler.compile(texts)
+    else:
+        raise ValueError(
+            f"Unsupported type of graph compiler: {type(graph_compiler)}"
+        )
 
     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
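For context, here is a minimal, self-contained sketch of the dispatch introduced in the hunk above. The two classes are toy stand-ins rather than the real icefall compilers; the point is only that the BPE path converts transcripts to token IDs before compiling, while the phone path hands the raw transcripts straight to compile().

# Toy stand-ins; the real classes live in icefall and are driven by a lang dir.
class BpeCtcTrainingGraphCompilerStub:
    def texts_to_ids(self, texts):
        # Fake "tokenization": one integer per character.
        return [[ord(c) for c in t] for t in texts]

    def compile(self, token_ids):
        return f"decoding graph over {len(token_ids)} token-ID sequences"


class CtcTrainingGraphCompilerStub:
    def compile(self, texts):
        return f"decoding graph over {len(texts)} raw transcripts"


def build_decoding_graph(graph_compiler, texts):
    if isinstance(graph_compiler, BpeCtcTrainingGraphCompilerStub):
        token_ids = graph_compiler.texts_to_ids(texts)
        return graph_compiler.compile(token_ids)
    elif isinstance(graph_compiler, CtcTrainingGraphCompilerStub):
        return graph_compiler.compile(texts)
    raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")


print(build_decoding_graph(BpeCtcTrainingGraphCompilerStub(), ["HELLO ICEFALL"]))
print(build_decoding_graph(CtcTrainingGraphCompilerStub(), ["HELLO ICEFALL"]))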
@@ -584,12 +601,38 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    if "lang_bpe" in params.lang_dir:
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="<sos/eos>",
+            eos_token="<sos/eos>",
+        )
+    elif "lang_phone" in params.lang_dir:
+        assert params.att_rate == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
+            "for pure CTC training when using a phone-based lang dir."
+        )
+        assert params.num_decoder_layers == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. "
+            "Set --num-decoder-layers=0 for pure CTC training when using "
+            "a phone-based lang dir."
+        )
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos IDs with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
 
     logging.info("About to create model")
     model = Conformer(
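In the training script this branch is reached by passing a lang dir whose name contains "lang_phone" together with --att-rate 0 and --num-decoder-layers 0. Below is a hedged sketch of exercising the phone branch on its own, assuming an installed icefall/k2 environment and an already prepared data/lang_phone directory; the path and the Lexicon construction are assumptions based on the surrounding recipe, while the CtcTrainingGraphCompiler call and the sos/eos assignments come from the diff above.

# Sketch only: requires icefall + k2 and a prepared phone lang dir (assumed path).
import torch
from icefall.lexicon import Lexicon
from icefall.graph_compiler import CtcTrainingGraphCompiler

lang_dir = "data/lang_phone"  # assumed layout; adjust to your setup
lexicon = Lexicon(lang_dir)   # assumption: the lexicon is built from the lang dir
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

graph_compiler = CtcTrainingGraphCompiler(lexicon, device=device)
# Mirror the training script: reuse the BPE recipe's default sos/eos IDs,
# since the phone lexicon has no <sos/eos> symbol of its own.
graph_compiler.sos_id = 1
graph_compiler.eos_id = 1

decoding_graph = graph_compiler.compile(["HELLO ICEFALL", "HELLO K2"])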
@@ -607,7 +650,9 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        # Note: find_unused_parameters=True is needed in case we
+        # want to set params.att_rate = 0 (i.e. the att decoder is not trained)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer = Noam(
         model.parameters(),


@@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
 
         return decoding_graph
 
+    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+        """Convert a list of texts to a list-of-list of word IDs.
+
+        Args:
+          texts:
+            It is a list of strings. Each string consists of space(s)
+            separated words. An example containing two strings is given below:
+
+                ['HELLO ICEFALL', 'HELLO k2']
+        Returns:
+          Return a list-of-list of word IDs.
+        """
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(self.oov_id)
+            word_ids_list.append(word_ids)
+        return word_ids_list
+
     def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
         """Convert a list of transcript texts to an FsaVec.