mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-11 10:04:21 +00:00
Add context biasing for zipformer recipe
This commit is contained in:
parent
1ee251c8b3
commit
88067f7566
@ -97,6 +97,7 @@ Usage:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@ -122,7 +123,7 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall import LmScorer, NgramLm
|
from icefall import ContextGraph, LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -347,6 +348,25 @@ def get_parser():
|
|||||||
help="ID of the backoff symbol in the ngram LM",
|
help="ID of the backoff symbol in the ngram LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-score",
|
||||||
|
type=float,
|
||||||
|
default=2,
|
||||||
|
help="""
|
||||||
|
The bonus score of each token for the context biasing words/phrases.
|
||||||
|
Used only when --decoding_method is modified_beam_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The path of the context biasing lists, one word/phrase each line
|
||||||
|
Used only when --decoding_method is modified_beam_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -359,6 +379,7 @@ def decode_one_batch(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
LM: Optional[LmScorer] = None,
|
LM: Optional[LmScorer] = None,
|
||||||
ngram_lm=None,
|
ngram_lm=None,
|
||||||
ngram_lm_scale: float = 0.0,
|
ngram_lm_scale: float = 0.0,
|
||||||
@ -493,6 +514,7 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -503,6 +525,7 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -515,6 +538,7 @@ def decode_one_batch(
|
|||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -527,6 +551,7 @@ def decode_one_batch(
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
lm_scale_list=lm_scale_list,
|
lm_scale_list=lm_scale_list,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
||||||
lm_scale_list = [0.02 * i for i in range(2, 30)]
|
lm_scale_list = [0.02 * i for i in range(2, 30)]
|
||||||
@ -539,6 +564,7 @@ def decode_one_batch(
|
|||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
lm_scale_list=lm_scale_list,
|
lm_scale_list=lm_scale_list,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
@ -578,16 +604,24 @@ def decode_one_batch(
|
|||||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||||
|
|
||||||
return {key: hyps}
|
return {key: hyps}
|
||||||
elif params.decoding_method in (
|
elif "modified_beam_search" in params.decoding_method:
|
||||||
"modified_beam_search_lm_rescore",
|
prefix = f"beam_size_{params.beam_size}"
|
||||||
"modified_beam_search_lm_rescore_LODR",
|
if params.has_contexts:
|
||||||
):
|
prefix += f"-context-score-{params.context_score}"
|
||||||
ans = dict()
|
else:
|
||||||
assert ans_dict is not None
|
prefix += "-no-context-words"
|
||||||
for key, hyps in ans_dict.items():
|
if params.decoding_method in (
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyps]
|
"modified_beam_search_lm_rescore",
|
||||||
ans[f"beam_size_{params.beam_size}_{key}"] = hyps
|
"modified_beam_search_lm_rescore_LODR",
|
||||||
return ans
|
):
|
||||||
|
ans = dict()
|
||||||
|
assert ans_dict is not None
|
||||||
|
for key, hyps in ans_dict.items():
|
||||||
|
hyps = [sp.decode(hyp).split() for hyp in hyps]
|
||||||
|
ans[f"{prefix}_{key}"] = hyps
|
||||||
|
return ans
|
||||||
|
else:
|
||||||
|
return {prefix: hyps}
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
@ -599,6 +633,7 @@ def decode_dataset(
|
|||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
LM: Optional[LmScorer] = None,
|
LM: Optional[LmScorer] = None,
|
||||||
ngram_lm=None,
|
ngram_lm=None,
|
||||||
ngram_lm_scale: float = 0.0,
|
ngram_lm_scale: float = 0.0,
|
||||||
@ -649,6 +684,7 @@ def decode_dataset(
|
|||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
@ -744,6 +780,11 @@ def main():
|
|||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
params.has_contexts = True
|
||||||
|
else:
|
||||||
|
params.has_contexts = False
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
else:
|
else:
|
||||||
@ -770,6 +811,10 @@ def main():
|
|||||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
if params.has_contexts:
|
||||||
|
params.suffix += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
params.suffix += "-no-context-words"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -952,6 +997,18 @@ def main():
|
|||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
word_table = None
|
word_table = None
|
||||||
|
|
||||||
|
if "modified_beam_search" in params.decoding_method:
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
contexts = []
|
||||||
|
for line in open(params.context_file).readlines():
|
||||||
|
contexts.append(line.strip())
|
||||||
|
context_graph = ContextGraph(params.context_score)
|
||||||
|
context_graph.build(sp.encode(contexts))
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -976,6 +1033,7 @@ def main():
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
ngram_lm=ngram_lm,
|
ngram_lm=ngram_lm,
|
||||||
ngram_lm_scale=ngram_lm_scale,
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user