Add context biasing for zipformer recipe

This commit is contained in:
pkufool 2023-08-08 11:28:21 +08:00
parent 1ee251c8b3
commit 88067f7566

View File

@ -97,6 +97,7 @@ Usage:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -122,7 +123,7 @@ from beam_search import (
)
from train import add_model_arguments, get_model, get_params
from icefall import LmScorer, NgramLm
from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -347,6 +348,25 @@ def get_parser():
help="ID of the backoff symbol in the ngram LM",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score of each token for the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing lists, one word/phrase each line
Used only when --decoding_method is modified_beam_search.
""",
)
add_model_arguments(parser)
return parser
@ -359,6 +379,7 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
@ -493,6 +514,7 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -503,6 +525,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -515,6 +538,7 @@ def decode_one_batch(
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -527,6 +551,7 @@ def decode_one_batch(
beam=params.beam_size,
LM=LM,
lm_scale_list=lm_scale_list,
context_graph=context_graph,
)
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
lm_scale_list = [0.02 * i for i in range(2, 30)]
@ -539,6 +564,7 @@ def decode_one_batch(
LODR_lm=ngram_lm,
sp=sp,
lm_scale_list=lm_scale_list,
context_graph=context_graph,
)
else:
batch_size = encoder_out.size(0)
@ -578,7 +604,13 @@ def decode_one_batch(
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
elif params.decoding_method in (
elif "modified_beam_search" in params.decoding_method:
prefix = f"beam_size_{params.beam_size}"
if params.has_contexts:
prefix += f"-context-score-{params.context_score}"
else:
prefix += "-no-context-words"
if params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
):
@ -586,8 +618,10 @@ def decode_one_batch(
assert ans_dict is not None
for key, hyps in ans_dict.items():
hyps = [sp.decode(hyp).split() for hyp in hyps]
ans[f"beam_size_{params.beam_size}_{key}"] = hyps
ans[f"{prefix}_{key}"] = hyps
return ans
else:
return {prefix: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -599,6 +633,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
@ -649,6 +684,7 @@ def decode_dataset(
model=model,
sp=sp,
decoding_graph=decoding_graph,
context_graph=context_graph,
word_table=word_table,
batch=batch,
LM=LM,
@ -744,6 +780,11 @@ def main():
)
params.res_dir = params.exp_dir / params.decoding_method
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
@ -770,6 +811,10 @@ def main():
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += "-no-context-words"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -952,6 +997,18 @@ def main():
decoding_graph = None
word_table = None
if "modified_beam_search" in params.decoding_method:
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -976,6 +1033,7 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,