Add need_repeat_flag in phone based ctc graph compiler (#727)

* Fix is_repeat_token in icefall

* Fix phone based recipe

* Update egs/librispeech/ASR/conformer_ctc3/train.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Fix black

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
Wei Kang, 2022-12-04 17:20:17 +08:00, committed by GitHub
parent e6a6727012 · commit c25c8c6ad1
2 changed files with 15 additions and 4 deletions

egs/librispeech/ASR/conformer_ctc3/train.py

@@ -890,6 +890,7 @@ def run(rank, world_size, args):
     graph_compiler = CtcTrainingGraphCompiler(
         lexicon,
         device=device,
+        need_repeat_flag=params.delay_penalty > 0,
     )
     # Manually add the sos/eos ID with their default values
     # from the BPE recipe which we're adapting here.
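For context, here is a minimal usage sketch (not part of this commit) of constructing the compiler with the flag enabled; the lang directory path below is illustrative, and `Lexicon`/`CtcTrainingGraphCompiler` are the icefall classes touched by this diff.

# A minimal sketch, not part of this commit. Assumes a lang directory
# prepared by the recipe (the path below is illustrative).
from pathlib import Path

import torch

from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lexicon = Lexicon(Path("data/lang_phone"))
graph_compiler = CtcTrainingGraphCompiler(
    lexicon,
    device=device,
    # Mirror the train.py change above: request the repeat flag only
    # when a delay penalty is actually in use.
    need_repeat_flag=True,
)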

icefall/graph_compiler.py

@@ -29,6 +29,7 @@ class CtcTrainingGraphCompiler(object):
         lexicon: Lexicon,
         device: torch.device,
         oov: str = "<UNK>",
+        need_repeat_flag: bool = False,
     ):
         """
         Args:
@@ -39,6 +40,13 @@ class CtcTrainingGraphCompiler(object):
           oov:
             Out of vocabulary word. When a word in the transcript
             does not exist in the lexicon, it is replaced with `oov`.
+          need_repeat_flag:
+            If True, adds an attribute named `_is_repeat_token_` to ctc_topo
+            indicating whether each token is a repeat token in the CTC graph.
+            This attribute is needed to implement the delay penalty for
+            phone-based CTC loss. See https://github.com/k2-fsa/k2/pull/1086
+            for details. Note: your k2 installation MUST include that change
+            for this flag to take effect.
         """
         L_inv = lexicon.L_inv.to(device)
         assert L_inv.requires_grad is False
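To make "repeat token" concrete: in the standard CTC topology, the self-loop that models a repeated emission carries a real token as its input label but epsilon (0) as its aux_label, so `labels != aux_labels` is True exactly on those arcs. A small sketch using stock k2 (not part of this commit):

# A small sketch, not part of this commit: inspect which arcs of a
# toy CTC topology are repeat arcs.
import k2

max_token_id = 2  # tokens 1 and 2, plus blank (0)
ctc_topo = k2.ctc_topo(max_token_id, modified=False)

is_repeat = ctc_topo.labels != ctc_topo.aux_labels
print(ctc_topo.labels)      # per-arc input labels (-1 marks final arcs)
print(ctc_topo.aux_labels)  # per-arc output labels
print(is_repeat)            # True only on the repeat self-loops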
@@ -53,6 +61,12 @@ class CtcTrainingGraphCompiler(object):
         ctc_topo = k2.ctc_topo(max_token_id, modified=False)
         self.ctc_topo = ctc_topo.to(device)
+
+        if need_repeat_flag:
+            self.ctc_topo._is_repeat_token_ = (
+                self.ctc_topo.labels != self.ctc_topo.aux_labels
+            )
+
         self.device = device

     def compile(self, texts: List[str]) -> k2.Fsa:
@@ -79,10 +93,6 @@ class CtcTrainingGraphCompiler(object):
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)

-        self.ctc_topo._is_repeat_token_ = (
-            self.ctc_topo.labels != self.ctc_topo.aux_labels
-        ).int()
-
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
         )
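Note the design change: the `_is_repeat_token_` computation that previously ran on every compile() call now happens once in __init__, guarded by the flag. A usage sketch (not part of this commit; `graph_compiler` is the hypothetical instance from the earlier sketch, and the transcripts are illustrative):

# A usage sketch, not part of this commit.
texts = ["HELLO WORLD", "FOO BAR"]
decoding_graph = graph_compiler.compile(texts)

# compile() returns an FsaVec with one decoding graph per utterance,
# built by composing ctc_topo with the transcript FSA.
assert decoding_graph.shape[0] == len(texts)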