mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge pull request #185 from pzelasko/feature/libri-conformer-phone-ctc
Fix using `lang_phone` in conformer CTC training
This commit is contained in:
commit
8e6fd97c6b
@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
|
|||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -123,6 +124,15 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-decoder-layers",
|
||||||
|
type=int,
|
||||||
|
default=6,
|
||||||
|
help="""Number of decoder layer of transformer decoder.
|
||||||
|
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr-factor",
|
"--lr-factor",
|
||||||
type=float,
|
type=float,
|
||||||
@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
|
|||||||
"use_feat_batchnorm": True,
|
"use_feat_batchnorm": True,
|
||||||
"attention_dim": 512,
|
"attention_dim": 512,
|
||||||
"nhead": 8,
|
"nhead": 8,
|
||||||
"num_decoder_layers": 6,
|
|
||||||
# parameters for loss
|
# parameters for loss
|
||||||
"beam_size": 10,
|
"beam_size": 10,
|
||||||
"reduction": "sum",
|
"reduction": "sum",
|
||||||
@ -357,9 +366,17 @@ def compute_loss(
|
|||||||
supervisions, subsampling_factor=params.subsampling_factor
|
supervisions, subsampling_factor=params.subsampling_factor
|
||||||
)
|
)
|
||||||
|
|
||||||
token_ids = graph_compiler.texts_to_ids(texts)
|
if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
|
||||||
|
# Works with a BPE model
|
||||||
decoding_graph = graph_compiler.compile(token_ids)
|
token_ids = graph_compiler.texts_to_ids(texts)
|
||||||
|
decoding_graph = graph_compiler.compile(token_ids)
|
||||||
|
elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
|
||||||
|
# Works with a phone lexicon
|
||||||
|
decoding_graph = graph_compiler.compile(texts)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported type of graph compiler: {type(graph_compiler)}"
|
||||||
|
)
|
||||||
|
|
||||||
dense_fsa_vec = k2.DenseFsaVec(
|
dense_fsa_vec = k2.DenseFsaVec(
|
||||||
nnet_output,
|
nnet_output,
|
||||||
@ -584,12 +601,38 @@ def run(rank, world_size, args):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
if "lang_bpe" in params.lang_dir:
|
||||||
params.lang_dir,
|
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||||
device=device,
|
params.lang_dir,
|
||||||
sos_token="<sos/eos>",
|
device=device,
|
||||||
eos_token="<sos/eos>",
|
sos_token="<sos/eos>",
|
||||||
)
|
eos_token="<sos/eos>",
|
||||||
|
)
|
||||||
|
elif "lang_phone" in params.lang_dir:
|
||||||
|
assert params.att_rate == 0, (
|
||||||
|
"Attention decoder training does not support phone lang dirs "
|
||||||
|
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
|
||||||
|
"for pure CTC training when using a phone-based lang dir."
|
||||||
|
)
|
||||||
|
assert params.num_decoder_layers == 0, (
|
||||||
|
"Attention decoder training does not support phone lang dirs "
|
||||||
|
"at this time due to a missing <sos/eos> symbol. "
|
||||||
|
"Set --num-decoder-layers=0 for pure CTC training when using "
|
||||||
|
"a phone-based lang dir."
|
||||||
|
)
|
||||||
|
graph_compiler = CtcTrainingGraphCompiler(
|
||||||
|
lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# Manually add the sos/eos ID with their default values
|
||||||
|
# from the BPE recipe which we're adapting here.
|
||||||
|
graph_compiler.sos_id = 1
|
||||||
|
graph_compiler.eos_id = 1
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported type of lang dir (we expected it to have "
|
||||||
|
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = Conformer(
|
model = Conformer(
|
||||||
@ -607,7 +650,9 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
model = DDP(model, device_ids=[rank])
|
# Note: find_unused_parameters=True is needed in case we
|
||||||
|
# want to set params.att_rate = 0 (i.e. att decoder is not trained)
|
||||||
|
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||||
|
|
||||||
optimizer = Noam(
|
optimizer = Noam(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
|
@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
|
|||||||
|
|
||||||
return decoding_graph
|
return decoding_graph
|
||||||
|
|
||||||
|
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
|
||||||
|
"""Convert a list of texts to a list-of-list of word IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts:
|
||||||
|
It is a list of strings. Each string consists of space(s)
|
||||||
|
separated words. An example containing two strings is given below:
|
||||||
|
|
||||||
|
['HELLO ICEFALL', 'HELLO k2']
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of word IDs.
|
||||||
|
"""
|
||||||
|
word_ids_list = []
|
||||||
|
for text in texts:
|
||||||
|
word_ids = []
|
||||||
|
for word in text.split():
|
||||||
|
if word in self.word_table:
|
||||||
|
word_ids.append(self.word_table[word])
|
||||||
|
else:
|
||||||
|
word_ids.append(self.oov_id)
|
||||||
|
word_ids_list.append(word_ids)
|
||||||
|
return word_ids_list
|
||||||
|
|
||||||
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
|
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
|
||||||
"""Convert a list of transcript texts to an FsaVec.
|
"""Convert a list of transcript texts to an FsaVec.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user