Mirror of https://github.com/k2-fsa/icefall.git
Merge pull request #185 from pzelasko/feature/libri-conformer-phone-ctc
Fix using `lang_phone` in conformer CTC training
Commit: 8e6fd97c6b
The changes to the conformer CTC training script (`conformer_ctc/train.py`): first, import the phone-based graph compiler alongside the BPE one.

```diff
@@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
```
A new `--num-decoder-layers` flag makes the attention decoder optional:

```diff
@@ -123,6 +124,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layers of the transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model).
+        """,
+    )
+
     parser.add_argument(
         "--lr-factor",
         type=float,
```
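Since the decoder is now optional, the model constructor can skip building it entirely. Below is a minimal sketch of what `--num-decoder-layers=0` implies, using a hypothetical `TinyCtcModel` rather than the actual `Conformer` code:

```python
import torch


class TinyCtcModel(torch.nn.Module):
    """Toy model illustrating how num_decoder_layers == 0 can mean "no decoder"."""

    def __init__(self, num_decoder_layers: int = 6):
        super().__init__()
        self.encoder = torch.nn.Linear(80, 256)  # stand-in for the real encoder
        if num_decoder_layers > 0:
            layer = torch.nn.TransformerDecoderLayer(d_model=256, nhead=8)
            self.decoder = torch.nn.TransformerDecoder(layer, num_decoder_layers)
        else:
            self.decoder = None  # pure CTC: no attention decoder is built


model = TinyCtcModel(num_decoder_layers=0)
print(model.decoder)  # -> None
```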
The hard-coded value is removed from `get_params()` accordingly:

```diff
@@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
             "use_feat_batchnorm": True,
             "attention_dim": 512,
             "nhead": 8,
-            "num_decoder_layers": 6,
             # parameters for loss
             "beam_size": 10,
             "reduction": "sum",
```
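The default moves out of `get_params()` because it now comes from the command line. A small sketch of how the parsed value ends up in `params`; the `params.update(vars(args))` merge is an assumption about the recipe's plumbing, not a quote from it:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-decoder-layers", type=int, default=6)
args = parser.parse_args(["--num-decoder-layers", "0"])

params = {"attention_dim": 512, "nhead": 8}  # stand-in for get_params()
params.update(vars(args))  # the CLI value joins (and can override) the defaults
print(params["num_decoder_layers"])  # -> 0
```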
`compute_loss` now dispatches on the compiler type:

```diff
@@ -357,9 +366,17 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )
 
-    token_ids = graph_compiler.texts_to_ids(texts)
-
-    decoding_graph = graph_compiler.compile(token_ids)
+    if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+        # Works with a BPE model
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+    elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+        # Works with a phone lexicon
+        decoding_graph = graph_compiler.compile(texts)
+    else:
+        raise ValueError(
+            f"Unsupported type of graph compiler: {type(graph_compiler)}"
+        )
 
     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
```
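The dispatch is needed because the two compiler types have different input contracts: the BPE compiler consumes token IDs, while the phone compiler consumes raw transcripts. A self-contained illustration with toy stand-ins (`MockBpeCompiler` and `MockPhoneCompiler` are hypothetical classes, not icefall's):

```python
from typing import List


class MockBpeCompiler:
    """Stand-in: consumes token IDs produced by texts_to_ids()."""

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        return [[len(w) for w in t.split()] for t in texts]  # fake token IDs

    def compile(self, token_ids: List[List[int]]) -> str:
        return f"CTC graph over {len(token_ids)} token-ID sequences"


class MockPhoneCompiler:
    """Stand-in: consumes the raw transcripts directly."""

    def compile(self, texts: List[str]) -> str:
        return f"CTC graph over {len(texts)} transcripts"


def build_graph(compiler, texts: List[str]) -> str:
    # Mirrors the isinstance() dispatch in compute_loss above.
    if isinstance(compiler, MockBpeCompiler):
        return compiler.compile(compiler.texts_to_ids(texts))
    elif isinstance(compiler, MockPhoneCompiler):
        return compiler.compile(texts)
    else:
        raise ValueError(f"Unsupported type of graph compiler: {type(compiler)}")


print(build_graph(MockBpeCompiler(), ["HELLO ICEFALL", "HELLO k2"]))
print(build_graph(MockPhoneCompiler(), ["HELLO ICEFALL", "HELLO k2"]))
```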
`run()` chooses the graph compiler from the lang dir name, guarding against unsupported phone + attention-decoder combinations:

```diff
@@ -584,12 +601,38 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    if "lang_bpe" in params.lang_dir:
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="<sos/eos>",
+            eos_token="<sos/eos>",
+        )
+    elif "lang_phone" in params.lang_dir:
+        assert params.att_rate == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
+            "for pure CTC training when using a phone-based lang dir."
+        )
+        assert params.num_decoder_layers == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. "
+            "Set --num-decoder-layers=0 for pure CTC training when using "
+            "a phone-based lang dir."
+        )
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos ID with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
 
     logging.info("About to create model")
     model = Conformer(
```
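For quick reference, how the lang-dir check behaves for the two layouts; the directory names below are the usual icefall conventions and are assumed here, not taken from this diff:

```python
for lang_dir in ["data/lang_bpe_500", "data/lang_phone"]:
    if "lang_bpe" in lang_dir:
        print(lang_dir, "-> BpeCtcTrainingGraphCompiler")
    elif "lang_phone" in lang_dir:
        # Requires --att-rate=0 and --num-decoder-layers=0 (see the asserts
        # above); sos_id/eos_id are then patched to 1, the BPE recipe default.
        print(lang_dir, "-> CtcTrainingGraphCompiler")
```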
DDP needs `find_unused_parameters=True` once the attention decoder can go untrained:

```diff
@@ -607,7 +650,9 @@ def run(rank, world_size, args):
 
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        # Note: find_unused_parameters=True is needed in case we
+        # want to set params.att_rate = 0 (i.e. att decoder is not trained)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer = Noam(
         model.parameters(),
```
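Why that flag matters: with `att_rate == 0` the attention-decoder parameters never receive gradients, and DDP's reducer complains about parameters that did not participate in producing the loss. A single-process sketch of the situation, assuming a CPU `gloo` process group (this setup is illustrative, not icefall's):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.decoder = torch.nn.Linear(4, 4)  # plays the role of the att decoder

    def forward(self, x):
        return self.encoder(x)  # decoder is skipped, as when att_rate == 0


# Without find_unused_parameters=True, DDP's reducer would eventually error
# because self.decoder's parameters never get gradients.
model = DDP(Toy(), find_unused_parameters=True)
model(torch.randn(2, 4)).sum().backward()
dist.destroy_process_group()
```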
Finally, in `icefall/graph_compiler.py`, `CtcTrainingGraphCompiler` gains a `texts_to_ids` helper mirroring the BPE compiler's interface:

```diff
@@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
 
         return decoding_graph
 
+    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+        """Convert a list of texts to a list-of-list of word IDs.
+
+        Args:
+          texts:
+            It is a list of strings. Each string consists of space(s)
+            separated words. An example containing two strings is given below:
+
+                ['HELLO ICEFALL', 'HELLO k2']
+        Returns:
+          Return a list-of-list of word IDs.
+        """
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(self.oov_id)
+            word_ids_list.append(word_ids)
+        return word_ids_list
+
     def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
         """Convert a list of transcript texts to an FsaVec.
 
```
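A quick usage sketch of the new helper; `word_table` is mocked as a plain dict here (in icefall it is a `k2.SymbolTable`, which supports the same `in` and `[]` lookups), and the `oov_id` value is made up:

```python
from typing import Dict, List


def texts_to_ids_demo(
    texts: List[str], word_table: Dict[str, int], oov_id: int
) -> List[List[int]]:
    """Standalone copy of the texts_to_ids logic above."""
    word_ids_list = []
    for text in texts:
        word_ids = [word_table.get(word, oov_id) for word in text.split()]
        word_ids_list.append(word_ids)
    return word_ids_list


word_table = {"HELLO": 10, "ICEFALL": 11}  # toy table; real ones come from Lexicon
print(texts_to_ids_demo(["HELLO ICEFALL", "HELLO k2"], word_table, oov_id=1))
# -> [[10, 11], [10, 1]]
```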