mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add need_repeat_flag in phone based ctc graph compiler (#727)
* Fix is_repeat_token in icefall
* Fix phone based recipe
* Update egs/librispeech/ASR/conformer_ctc3/train.py
  Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
* Fix black
  Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
This commit is contained in:
parent
e6a6727012
commit
c25c8c6ad1
@ -890,6 +890,7 @@ def run(rank, world_size, args):
|
||||
graph_compiler = CtcTrainingGraphCompiler(
|
||||
lexicon,
|
||||
device=device,
|
||||
need_repeat_flag=params.delay_penalty > 0,
|
||||
)
|
||||
# Manually add the sos/eos ID with their default values
|
||||
# from the BPE recipe which we're adapting here.
|
||||
|
@ -29,6 +29,7 @@ class CtcTrainingGraphCompiler(object):
|
||||
lexicon: Lexicon,
|
||||
device: torch.device,
|
||||
oov: str = "<UNK>",
|
||||
need_repeat_flag: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -39,6 +40,13 @@ class CtcTrainingGraphCompiler(object):
|
||||
oov:
|
||||
Out of vocabulary word. When a word in the transcript
|
||||
does not exist in the lexicon, it is replaced with `oov`.
|
||||
need_repeat_flag:
|
||||
If True, will add an attribute named `_is_repeat_token_` to ctc_topo
|
||||
indicating whether this token is a repeat token in ctc graph.
|
||||
This attribute is needed to implement delay-penalty for phone-based
|
||||
ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more
|
||||
details. Note: The above change MUST be included in k2 to open this
|
||||
flag.
|
||||
"""
|
||||
L_inv = lexicon.L_inv.to(device)
|
||||
assert L_inv.requires_grad is False
|
||||
@ -53,6 +61,12 @@ class CtcTrainingGraphCompiler(object):
|
||||
ctc_topo = k2.ctc_topo(max_token_id, modified=False)
|
||||
|
||||
self.ctc_topo = ctc_topo.to(device)
|
||||
|
||||
if need_repeat_flag:
|
||||
self.ctc_topo._is_repeat_token_ = (
|
||||
self.ctc_topo.labels != self.ctc_topo.aux_labels
|
||||
)
|
||||
|
||||
self.device = device
|
||||
|
||||
def compile(self, texts: List[str]) -> k2.Fsa:
|
||||
@ -79,10 +93,6 @@ class CtcTrainingGraphCompiler(object):
|
||||
|
||||
fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
|
||||
|
||||
self.ctc_topo._is_repeat_token_ = (
|
||||
self.ctc_topo.labels != self.ctc_topo.aux_labels
|
||||
).int()
|
||||
|
||||
decoding_graph = k2.compose(
|
||||
self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user