Add context biasing for zipformer recipe

pkufool 2023-08-08 11:28:21 +08:00
parent 1ee251c8b3
commit 88067f7566


@@ -97,6 +97,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@ -122,7 +123,7 @@ from beam_search import (
) )
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall import LmScorer, NgramLm from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
@@ -347,6 +348,25 @@ def get_parser():
         help="ID of the backoff symbol in the ngram LM",
     )
 
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        Used only when --decoding-method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding-method is modified_beam_search.
+        """,
+    )
+
     add_model_arguments(parser)
 
     return parser
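For reference, the file passed to --context-file is plain text with one biasing word or phrase per line, and the new flags ride along with the usual decoding options. A hypothetical invocation (the script path, experiment directory, and file name below are illustrative, not part of this commit):

    # hotwords.txt -- one biasing word/phrase per line
    HUSBANDRY
    MARGUERITE

    ./zipformer/decode.py \
        --exp-dir zipformer/exp \
        --decoding-method modified_beam_search \
        --context-file hotwords.txt \
        --context-score 2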
@@ -359,6 +379,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,
@@ -493,6 +514,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -503,6 +525,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             LM=LM,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -515,6 +538,7 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -527,6 +551,7 @@ def decode_one_batch(
             beam=params.beam_size,
             LM=LM,
             lm_scale_list=lm_scale_list,
+            context_graph=context_graph,
         )
     elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
         lm_scale_list = [0.02 * i for i in range(2, 30)]
@@ -539,6 +564,7 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             sp=sp,
             lm_scale_list=lm_scale_list,
+            context_graph=context_graph,
         )
     else:
         batch_size = encoder_out.size(0)
@@ -578,16 +604,24 @@ def decode_one_batch(
         key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
         return {key: hyps}
-    elif params.decoding_method in (
-        "modified_beam_search_lm_rescore",
-        "modified_beam_search_lm_rescore_LODR",
-    ):
-        ans = dict()
-        assert ans_dict is not None
-        for key, hyps in ans_dict.items():
-            hyps = [sp.decode(hyp).split() for hyp in hyps]
-            ans[f"beam_size_{params.beam_size}_{key}"] = hyps
-        return ans
+    elif "modified_beam_search" in params.decoding_method:
+        prefix = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            prefix += f"-context-score-{params.context_score}"
+        else:
+            prefix += "-no-context-words"
+        if params.decoding_method in (
+            "modified_beam_search_lm_rescore",
+            "modified_beam_search_lm_rescore_LODR",
+        ):
+            ans = dict()
+            assert ans_dict is not None
+            for key, hyps in ans_dict.items():
+                hyps = [sp.decode(hyp).split() for hyp in hyps]
+                ans[f"{prefix}_{key}"] = hyps
+            return ans
+        else:
+            return {prefix: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -599,6 +633,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,
@@ -649,6 +684,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             word_table=word_table,
             batch=batch,
             LM=LM,
@@ -744,6 +780,11 @@ def main():
     )
 
     params.res_dir = params.exp_dir / params.decoding_method
 
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
@@ -770,6 +811,10 @@ def main():
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -952,6 +997,18 @@ def main():
         decoding_graph = None
         word_table = None
 
+    if "modified_beam_search" in params.decoding_method:
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -976,6 +1033,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             LM=LM,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,