Minor fixes to the decode.py

pkufool 2023-06-02 16:25:44 +08:00
parent 8a94476fd9
commit 57c1d762b6
6 changed files with 109 additions and 66 deletions

View File

@@ -58,6 +58,7 @@ Usage:
 import argparse
 import logging
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -76,6 +77,8 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -211,6 +214,26 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser
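
The two new flags behave like any other argparse options. A minimal self-contained sketch (flag names and defaults taken from the hunk above; "hotwords.txt" and the score value are illustrative, not part of the commit):

    import argparse

    # Mirror only the two options added above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--context-score", type=float, default=2)
    parser.add_argument("--context-file", type=str, default="")

    args = parser.parse_args(
        ["--context-score", "1.5", "--context-file", "hotwords.txt"]
    )
    assert args.context_score == 1.5
    # The default empty --context-file means no biasing list is loaded.

The context file is expected to hold one biasing word/phrase per line.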
@@ -222,6 +245,7 @@ def decode_one_batch(
     token_table: k2.SymbolTable,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -285,6 +309,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
     else:
         hyp_tokens = []
@ -324,7 +349,12 @@
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


 def decode_dataset(
@@ -333,6 +363,7 @@
     model: nn.Module,
     token_table: k2.SymbolTable,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -377,6 +408,7 @@
             model=model,
             token_table=token_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             batch=batch,
         )
@@ -407,16 +439,17 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        # we compute CER for aishell dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        store_transcripts(filename=recog_path, texts=results_char)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -457,6 +490,12 @@ def main():
         "fast_beam_search",
         "modified_beam_search",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
@@ -470,6 +509,10 @@
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-contexts-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -490,6 +533,11 @@ def main():
     params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)

     logging.info("About to create model")
@@ -586,6 +634,19 @@ def main():
     else:
         decoding_graph = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -608,6 +669,7 @@
             model=model,
             token_table=lexicon.token_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )

         save_results(

View File

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple, Union

View File

@@ -131,8 +131,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
-import kaldifst
-import graphviz
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -576,7 +574,10 @@ def decode_one_batch(
         return {key: (hyps, timestamps)}
     else:
         key = f"beam_size_{params.beam_size}"
-        key += f"-context-score-{params.context_score}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
         return {key: (hyps, timestamps)}
@@ -626,7 +627,7 @@ def decode_dataset(
     if params.decoding_method == "greedy_search":
         log_interval = 50
     else:
-        log_interval = 1
+        log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -759,6 +760,12 @@ def main():
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
@@ -781,7 +788,10 @@
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        params.suffix += f"-context-score-{params.context_score}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -938,14 +948,8 @@ def main():
     test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
     test_other_dl = librispeech.test_dataloaders(test_other_cuts)

-    test_book_cuts = librispeech.test_book_cuts()
-    test_book_dl = librispeech.test_dataloaders(test_book_cuts)
-    test_book2_cuts = librispeech.test_book2_cuts()
-    test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)
-
-    test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
-    test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

View File

@@ -389,7 +389,6 @@ class LibriSpeechAsrDataModule:
             )
             sampler = DynamicBucketingSampler(
                 cuts,
-                num_buckets=2,
                 max_duration=self.args.max_duration,
                 shuffle=False,
             )
@@ -468,25 +467,6 @@
             self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
         )

-    @lru_cache()
-    def test_book_cuts(self) -> CutSet:
-        logging.info("About to get test-books cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
-
-    @lru_cache()
-    def test_book_test_cuts(self) -> CutSet:
-        logging.info("About to get test-books cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "libri_book_test_feats.jsonl.gz"
-        )
-
-    @lru_cache()
-    def test_book2_cuts(self) -> CutSet:
-        logging.info("About to get test-books2 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
-        )
-
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")

View File

@@ -396,21 +396,14 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")

     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")

     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
-
-    @lru_cache()
-    def test_car_cuts(self) -> List[CutSet]:
-        logging.info("About to get TEST_CAR cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")

View File

@@ -533,9 +533,12 @@ def decode_one_batch(
             ): hyps
         }
     else:
-        return {
-            f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
-        }
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


 def decode_dataset(
@@ -674,6 +677,12 @@ def main():
         "modified_beam_search_lm_shallow_fusion",
         "modified_beam_search_LODR",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
@@ -683,7 +692,10 @@
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
-        params.suffix += f"-context-score-{params.context_score}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-contexts-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -851,14 +863,10 @@ def main():
     if params.decoding_method == "modified_beam_search":
         if os.path.exists(params.context_file):
-            contexts = []
+            contexts_text = []
             for line in open(params.context_file).readlines():
-                context_list = graph_compiler.texts_to_ids(line.strip())
-                tmp = []
-                for context in context_list:
-                    for x in context:
-                        tmp.append(x)
-                contexts.append(tmp)
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
             context_graph = ContextGraph(params.context_score)
             context_graph.build(contexts)
         else:
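
The rewrite above replaces the per-line texts_to_ids() calls plus manual flattening with a single batched call, matching the aishell version earlier in this commit. A toy stand-in showing the shape texts_to_ids() is assumed to return (the real ids come from the lexicon):

    from typing import List

    def texts_to_ids(texts: List[str]) -> List[List[int]]:
        # Stand-in: one fake id per character; only the shape matters here.
        return [[ord(ch) for ch in text] for text in texts]

    # One token-id list per phrase -- exactly what ContextGraph.build() expects.
    assert texts_to_ids(["AB", "CD"]) == [[65, 66], [67, 68]]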
@@ -882,11 +890,8 @@ def main():
     test_meeting_cuts = wenetspeech.test_meeting_cuts()
     test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

-    test_car_cuts = wenetspeech.test_car_cuts()
-    test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)
-
-    test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
-    test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
+    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]

     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(