diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index fb3b740c1..ac489af9e 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -890,6 +890,7 @@ def run(rank, world_size, args):
         graph_compiler = CtcTrainingGraphCompiler(
             lexicon,
             device=device,
+            need_repeat_flag=params.delay_penalty > 0,
         )
         # Manually add the sos/eos ID with their default values
         # from the BPE recipe which we're adapting here.
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 0dcd777ad..d26ddbbd1 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -29,6 +29,7 @@ class CtcTrainingGraphCompiler(object):
         lexicon: Lexicon,
         device: torch.device,
         oov: str = "<UNK>",
+        need_repeat_flag: bool = False,
     ):
         """
         Args:
@@ -39,6 +40,13 @@ class CtcTrainingGraphCompiler(object):
           oov:
             Out of vocabulary word. When a word in the transcript
             does not exist in the lexicon, it is replaced with `oov`.
+          need_repeat_flag:
+            If True, an attribute named `_is_repeat_token_` will be added to
+            ctc_topo, indicating whether each token is a repeat token in the
+            CTC graph. This attribute is needed to implement the delay penalty
+            for phone-based CTC loss. See https://github.com/k2-fsa/k2/pull/1086
+            for more details. Note: that change MUST be included in your k2
+            installation before this flag can be enabled.
         """
         L_inv = lexicon.L_inv.to(device)
         assert L_inv.requires_grad is False
@@ -53,6 +61,12 @@ class CtcTrainingGraphCompiler(object):
         ctc_topo = k2.ctc_topo(max_token_id, modified=False)

         self.ctc_topo = ctc_topo.to(device)
+
+        if need_repeat_flag:
+            self.ctc_topo._is_repeat_token_ = (
+                self.ctc_topo.labels != self.ctc_topo.aux_labels
+            )
+
         self.device = device

     def compile(self, texts: List[str]) -> k2.Fsa:
@@ -79,10 +93,6 @@ class CtcTrainingGraphCompiler(object):
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)

-        self.ctc_topo._is_repeat_token_ = (
-            self.ctc_topo.labels != self.ctc_topo.aux_labels
-        ).int()
-
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
         )
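
For context, a minimal sketch of what the `_is_repeat_token_` computation above does, assuming a k2 installation that includes https://github.com/k2-fsa/k2/pull/1086; `max_token_id = 3` is a hypothetical value chosen for illustration only.

import k2

max_token_id = 3  # hypothetical vocabulary size, for illustration only

ctc_topo = k2.ctc_topo(max_token_id, modified=False)

# In the standard CTC topology, an arc that re-emits the token of its source
# state carries a non-epsilon label but an epsilon (0) aux_label, so comparing
# labels against aux_labels marks exactly the repeat arcs.
ctc_topo._is_repeat_token_ = ctc_topo.labels != ctc_topo.aux_labels

print(ctc_topo.labels)             # per-arc input labels
print(ctc_topo.aux_labels)         # per-arc output labels
print(ctc_topo._is_repeat_token_)  # True where the arc repeats a token

Computing the flag once in __init__ (rather than on every call to compile, as the removed hunk did) avoids redundant work, since the topology is fixed after construction.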