Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Minor fixes to the decode.py
This commit is contained in: parent 8a94476fd9, commit 57c1d762b6
@@ -58,6 +58,7 @@ Usage:
import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -76,6 +77,8 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,

@@ -211,6 +214,26 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--context-score",
        type=float,
        default=2,
        help="""
        The bonus score of each token for the context biasing words/phrases.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    parser.add_argument(
        "--context-file",
        type=str,
        default="",
        help="""
        The path of the context biasing lists, one word/phrase each line
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    add_model_arguments(parser)

    return parser
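As a reference for the two options introduced above, a minimal stand-alone sketch (not part of the diff) of how they parse; the flag names and defaults come from the hunk, while the sample file name hotwords.txt is purely illustrative:

import argparse

parser = argparse.ArgumentParser()
# Same definitions that the hunk above adds to get_parser().
parser.add_argument("--context-score", type=float, default=2)
parser.add_argument("--context-file", type=str, default="")

# Example invocation; hotwords.txt is a hypothetical file containing one
# biasing word/phrase per line, as the --context-file help text describes.
args = parser.parse_args(["--context-score", "1.5", "--context-file", "hotwords.txt"])
print(args.context_score, args.context_file)  # 1.5 hotwords.txt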
@@ -222,6 +245,7 @@ def decode_one_batch(
    token_table: k2.SymbolTable,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

@@ -285,6 +309,7 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            context_graph=context_graph,
        )
    else:
        hyp_tokens = []

@@ -324,7 +349,12 @@ def decode_one_batch(
            ): hyps
        }
    else:
        return {f"beam_size_{params.beam_size}": hyps}
        key = f"beam_size_{params.beam_size}"
        if params.has_contexts:
            key += f"-context-score-{params.context_score}"
        else:
            key += "-no-context-words"
        return {key: hyps}


def decode_dataset(

@@ -333,6 +363,7 @@ def decode_dataset(
    model: nn.Module,
    token_table: k2.SymbolTable,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

@@ -377,6 +408,7 @@ def decode_dataset(
            model=model,
            token_table=token_table,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
            batch=batch,
        )

@@ -407,16 +439,17 @@ def save_results(
    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        # we compute CER for aishell dataset.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))

        store_transcripts(filename=recog_path, texts=results_char)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
        # we compute CER for aishell dataset.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -457,6 +490,12 @@ def main():
        "fast_beam_search",
        "modified_beam_search",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:

@@ -470,6 +509,10 @@ def main():
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-contexts-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -490,6 +533,11 @@ def main():
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    logging.info(params)

    logging.info("About to create model")

@@ -586,6 +634,19 @@ def main():
    else:
        decoding_graph = None

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts_text = []
            for line in open(params.context_file).readlines():
                contexts_text.append(line.strip())
            contexts = graph_compiler.texts_to_ids(contexts_text)
            context_graph = ContextGraph(params.context_score)
            context_graph.build(contexts)
        else:
            context_graph = None
    else:
        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
@@ -608,6 +669,7 @@ def main():
            model=model,
            token_table=lexicon.token_table,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
        )

        save_results(

@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union

@@ -131,8 +131,6 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import kaldifst
import graphviz
import sentencepiece as spm
import torch
import torch.nn as nn

@@ -576,7 +574,10 @@ def decode_one_batch(
        return {key: (hyps, timestamps)}
    else:
        key = f"beam_size_{params.beam_size}"
        key += f"-context-score-{params.context_score}"
        if params.has_contexts:
            key += f"-context-score-{params.context_score}"
        else:
            key += "-no-context-words"
        return {key: (hyps, timestamps)}

@@ -626,7 +627,7 @@ def decode_dataset(
    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 1
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
@@ -759,6 +760,12 @@ def main():
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:

@@ -781,7 +788,10 @@ def main():
        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
        params.suffix += f"-context-score-{params.context_score}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-context-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -938,14 +948,8 @@ def main():
    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_book_cuts = librispeech.test_book_cuts()
    test_book_dl = librispeech.test_dataloaders(test_book_cuts)

    test_book2_cuts = librispeech.test_book2_cuts()
    test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)

    test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
    test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
    test_sets = ["test-clean", "test-other"]
    test_dl = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(
@@ -389,7 +389,6 @@ class LibriSpeechAsrDataModule:
        )
        sampler = DynamicBucketingSampler(
            cuts,
            num_buckets=2,
            max_duration=self.args.max_duration,
            shuffle=False,
        )

@@ -468,25 +467,6 @@ class LibriSpeechAsrDataModule:
            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_book_cuts(self) -> CutSet:
        logging.info("About to get test-books cuts")
        return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")

    @lru_cache()
    def test_book_test_cuts(self) -> CutSet:
        logging.info("About to get test-books cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libri_book_test_feats.jsonl.gz"
        )

    @lru_cache()
    def test_book2_cuts(self) -> CutSet:
        logging.info("About to get test-books2 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")

@@ -396,21 +396,14 @@ class WenetSpeechAsrDataModule:
    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")

    @lru_cache()
    def test_net_cuts(self) -> List[CutSet]:
        logging.info("About to get TEST_NET cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")

    @lru_cache()
    def test_meeting_cuts(self) -> List[CutSet]:
        logging.info("About to get TEST_MEETING cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz"
        )

    @lru_cache()
    def test_car_cuts(self) -> List[CutSet]:
        logging.info("About to get TEST_CAR cuts")
        return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
@@ -533,9 +533,12 @@ def decode_one_batch(
            ): hyps
        }
    else:
        return {
            f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
        }
        key = f"beam_size_{params.beam_size}"
        if params.has_contexts:
            key += f"-context-score-{params.context_score}"
        else:
            key += "-no-context-words"
        return {key: hyps}


def decode_dataset(

@@ -674,6 +677,12 @@ def main():
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_LODR",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

@@ -683,7 +692,10 @@ def main():
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam_size}"
        params.suffix += f"-context-score-{params.context_score}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-contexts-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -851,14 +863,10 @@ def main():

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts = []
            contexts_text = []
            for line in open(params.context_file).readlines():
                context_list = graph_compiler.texts_to_ids(line.strip())
                tmp = []
                for context in context_list:
                    for x in context:
                        tmp.append(x)
                contexts.append(tmp)
                contexts_text.append(line.strip())
            contexts = graph_compiler.texts_to_ids(contexts_text)
            context_graph = ContextGraph(params.context_score)
            context_graph.build(contexts)
        else:

@@ -882,11 +890,8 @@ def main():
    test_meeting_cuts = wenetspeech.test_meeting_cuts()
    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

    test_car_cuts = wenetspeech.test_car_cuts()
    test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)

    test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
    test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
    test_dls = [dev_dl, test_net_dl, test_meeting_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(