diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index c1fa814c0..cb0bd5c2d 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -123,6 +124,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layers in the transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model).
+        """,
+    )
+
     parser.add_argument(
         "--lr-factor",
         type=float,
@@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
             "use_feat_batchnorm": True,
             "attention_dim": 512,
             "nhead": 8,
-            "num_decoder_layers": 6,
             # parameters for loss
             "beam_size": 10,
             "reduction": "sum",
@@ -357,9 +366,17 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )
 
-    token_ids = graph_compiler.texts_to_ids(texts)
-
-    decoding_graph = graph_compiler.compile(token_ids)
+    if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+        # Works with a BPE model
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+    elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+        # Works with a phone lexicon
+        decoding_graph = graph_compiler.compile(texts)
+    else:
+        raise ValueError(
+            f"Unsupported type of graph compiler: {type(graph_compiler)}"
+        )
 
     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
@@ -584,12 +601,38 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    if "lang_bpe" in params.lang_dir:
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="<sos/eos>",
+            eos_token="<sos/eos>",
+        )
+    elif "lang_phone" in params.lang_dir:
+        assert params.att_rate == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
+            "for pure CTC training when using a phone-based lang dir."
+        )
+        assert params.num_decoder_layers == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. "
+            "Set --num-decoder-layers=0 for pure CTC training when using "
+            "a phone-based lang dir."
+        )
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos ID with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
 
     logging.info("About to create model")
     model = Conformer(
@@ -607,7 +650,9 @@
 
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        # Note: find_unused_parameters=True is needed in case we
+        # want to set params.att_rate = 0 (i.e. att decoder is not trained)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer = Noam(
         model.parameters(),
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index b4c87d964..570ed7d7a 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
 
         return decoding_graph
 
+    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+        """Convert a list of texts to a list-of-list of word IDs.
+
+        Args:
+          texts:
+            It is a list of strings. Each string consists of
+            space-separated words. An example containing two strings
+            is given below:
+
+                ['HELLO ICEFALL', 'HELLO k2']
+        Returns:
+          Return a list-of-list of word IDs.
+        """
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(self.oov_id)
+            word_ids_list.append(word_ids)
+        return word_ids_list
+
     def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
         """Convert a list of transcript texts to an FsaVec.
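
Usage note: with this change, the recipe picks the graph compiler from the lang dir name. Passing a directory whose name contains "lang_phone" (conventionally data/lang_phone) selects the phone-based CtcTrainingGraphCompiler and requires --att-rate=0 and --num-decoder-layers=0, per the new assertions in run(); directories containing "lang_bpe" keep the previous BPE behavior unchanged.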
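
For reference, a minimal sketch (not part of the diff) of how the new phone-based path behaves; the lang dir path is a hypothetical example, and the constructor/usage follows the calls shown in the diff above:

import torch

from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

# Hypothetical phone lang dir; the real path depends on your data prep.
lexicon = Lexicon("data/lang_phone")
graph_compiler = CtcTrainingGraphCompiler(lexicon, device=torch.device("cpu"))

# Words present in words.txt map to their IDs; out-of-vocabulary words
# (e.g. "k2" here, if it is not in the lexicon) fall back to
# graph_compiler.oov_id.
word_ids = graph_compiler.texts_to_ids(["HELLO ICEFALL", "HELLO k2"])

# Unlike the BPE compiler, which compiles the graph from token IDs, the
# phone compiler builds the decoding graph directly from the transcripts,
# which is why compute_loss() now dispatches on the compiler type.
decoding_graph = graph_compiler.compile(["HELLO ICEFALL", "HELLO k2"])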