From a7cb40a0a0885cef2a33cbeafaa3a7392bbbded8 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 16 Nov 2023 11:41:29 +0800 Subject: [PATCH] Fix deocde --- egs/librispeech/ASR/pruned_transducer_stateless4/decode.py | 4 ++-- egs/librispeech/ASR/zipformer/decode.py | 4 ++-- egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 524366068..5195a4ef6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -927,9 +927,9 @@ def main(): if os.path.exists(params.context_file): contexts = [] for line in open(params.context_file).readlines(): - contexts.append(line.strip()) + contexts.append((sp.encode(line.strip()), 0.0)) context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) + context_graph.build(contexts) else: context_graph = None else: diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 3531d657f..339e253e6 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -1001,9 +1001,9 @@ def main(): if os.path.exists(params.context_file): contexts = [] for line in open(params.context_file).readlines(): - contexts.append(line.strip()) + contexts.append((sp.encode(line.strip()), 0.0)) context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) + context_graph.build(contexts) else: context_graph = None else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 36b8a4b67..d665f3364 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -868,7 +868,7 @@ def main(): contexts_text.append(line.strip()) contexts = graph_compiler.texts_to_ids(contexts_text) context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) + context_graph.build([(c, 0.0) for c in contexts]) else: context_graph = None else: