Minor fixes to shallow fusion

pkufool 2024-10-09 11:16:38 +08:00
parent e4fa25a780
commit 6a0e41b539
2 changed files with 134 additions and 43 deletions

File 1 of 2

@@ -111,6 +111,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -288,7 +289,7 @@ def get_parser():
     )
     parser.add_argument(
-        "--lm-type",
+        "--nnlm-type",
         type=str,
         default="rnn",
         help="Type of NN lm",
@@ -296,10 +297,10 @@ def get_parser():
     )
     parser.add_argument(
-        "--lm-scale",
+        "--nnlm-scale",
         type=float,
-        default=0.3,
-        help="""The scale of the neural network LM
+        default=0,
+        help="""The scale of the neural network LM; 0 means don't use NNLM shallow fusion.
         Used only when `--use-shallow-fusion` is set to True.
         """,
     )
@@ -321,6 +322,47 @@ def get_parser():
         """,
     )
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the LODR ngram",
+    )
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="The scale of the LODR ngram; it should be less than 0. 0 means don't use LODR.",
+    )
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        0 means don't use contextual biasing.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
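
For reference, a hypothetical invocation exercising the new options. The script name, paths, and scale values are placeholders, not taken from this diff; the flags themselves are the ones added above:

    ./zipformer/ctc_decode.py \
        --decoding-method ctc-prefix-beam-search-shallow-fussion \
        --nnlm-type rnn \
        --nnlm-scale 0.3 \
        --lodr-ngram data/lm/2gram.fst.txt \
        --lodr-lm-scale -0.16 \
        --backoff-id 500 \
        --context-score 2.0 \
        --context-file contexts.txt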
@@ -358,7 +400,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -466,7 +510,10 @@ def decode_one_batch(
         token_ids = ctc_prefix_beam_search_shallow_fussion(
             ctc_output=ctc_output,
             encoder_out_lens=encoder_out_lens,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
         )
         # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
         hyps = bpe_model.decode(token_ids)
@@ -649,7 +696,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -700,7 +749,9 @@ def decode_dataset(
                 batch=batch,
                 word_table=word_table,
                 G=G,
-                LM=LM,
+                NNLM=NNLM,
+                LODR_lm=LODR_lm,
+                context_graph=context_graph,
             )
         for name, hyps in hyps_dict.items():
@@ -835,7 +886,12 @@ def main():
     if "prefix-beam-search" in params.decoding_method:
         params.suffix += f"_beam-{params.beam}"
         if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-            params.suffix += f"_lm-scale-{params.lm_scale}"
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context_score-{params.context_score}"
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -947,17 +1003,49 @@ def main():
     G = None
     # only load the neural network LM if required
-    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-        LM = LmScorer(
-            lm_type=params.lm_type,
+    NNLM = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.nnlm_scale != 0
+    ):
+        NNLM = LmScorer(
+            lm_type=params.nnlm_type,
             params=params,
             device=device,
-            lm_scale=params.lm_scale,
+            lm_scale=params.nnlm_scale,
         )
-        LM.to(device)
-        LM.eval()
-    else:
-        LM = None
+        NNLM.to(device)
+        NNLM.eval()
+
+    LODR_lm = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.lodr_lm_scale != 0
+    ):
+        assert os.path.exists(
+            params.lodr_ngram
+        ), f"LODR ngram does not exist, given path: {params.lodr_ngram}"
+        logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
+        LODR_lm = NgramLm(
+            params.lodr_ngram,
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {LODR_lm.lm.num_states}")
+
+    context_graph = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.context_score != 0
+    ):
+        assert os.path.exists(
+            params.context_file
+        ), f"context_file does not exist, given path: {params.context_file}"
+        contexts = []
+        for line in open(params.context_file).readlines():
+            contexts.append(bpe_model.encode(line.strip()))
+        context_graph = ContextGraph(params.context_score)
+        context_graph.build(contexts)
     logging.info("About to create model")
     model = get_model(params)
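
To make the context-biasing setup above concrete, here is a minimal sketch of building and querying a context graph. The import path, BPE model path, phrases, and score are assumptions for illustration; the ContextGraph calls mirror the ones appearing in this commit:

    import sentencepiece as spm
    from icefall.context_graph import ContextGraph  # assumed import path

    bpe_model = spm.SentencePieceProcessor(model_file="bpe.model")  # placeholder path
    phrases = ["HELLO WORLD", "ICEFALL"]  # stand-in biasing phrases
    contexts = [bpe_model.encode(p) for p in phrases]  # token ids, one list per phrase

    context_graph = ContextGraph(2.0)  # per-token bonus, as with --context-score
    context_graph.build(contexts)

    # During beam search each accepted token advances the graph state and
    # returns a bonus, as in _step_worker below: forward_one_step yields
    # (score, next_state, matched_state).
    state = context_graph.root
    score, state, matched = context_graph.forward_one_step(state, contexts[0][0])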
@@ -1068,7 +1156,9 @@ def main():
         bpe_model=bpe_model,
         word_table=lexicon.word_table,
         G=G,
-        LM=LM,
+        NNLM=NNLM,
+        LODR_lm=LODR_lm,
+        context_graph=context_graph,
     )
     save_asr_output(

File 2 of 2

@@ -1736,7 +1736,7 @@ def _step_worker(
     B: HypothesisList,
     beam: int = 4,
     blank_id: int = 0,
-    lm_scale: float = 0,
+    nnlm_scale: float = 0,
     LODR_lm_scale: float = 0,
     context_graph: Optional[ContextGraph] = None,
 ) -> HypothesisList:
@@ -1815,14 +1815,16 @@ def _step_worker(
         if update_prefix:
             lm_score = hyp.lm_score
             if hyp.lm_log_probs is not None:
-                lm_score += hyp.lm_log_probs[new_token] * lm_scale
+                lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale
                 new_hyp.lm_log_probs = None
             if context_graph is not None and hyp.context_state is not None:
-                context_score, new_context_state = context_graph.forward_one_step(
-                    hyp.context_state, new_token
-                )
-                lm_score += context_score
+                (
+                    context_score,
+                    new_context_state,
+                    matched_state,
+                ) = context_graph.forward_one_step(hyp.context_state, new_token)
+                lm_score = lm_score + context_score
                 new_hyp.context_state = new_context_state
             if hyp.LODR_state is not None:
@@ -1833,7 +1835,7 @@ def _step_worker(
                     state_cost.lm_score,
                     hyp.LODR_state.lm_score,
                 )
-                lm_score += LODR_lm_scale * current_ngram_score
+                lm_score = lm_score + LODR_lm_scale * current_ngram_score
                 new_hyp.LODR_state = state_cost
             new_hyp.lm_score = lm_score
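
Taken together, the three updates above accumulate a single fused score per hypothesis. Schematically, in my own notation (the source does not write this formula out):

    \Delta\text{lm\_score}(y_t) = \lambda_{\text{nnlm}} \log p_{\text{nnlm}}(y_t \mid y_{<t})
        + s_{\text{ctx}}(y_t)
        + \lambda_{\text{LODR}} \log p_{\text{ngram}}(y_t \mid y_{<t})

where \lambda_{\text{nnlm}} is nnlm_scale, s_{\text{ctx}} is the per-token context bonus, and \lambda_{\text{LODR}} is LODR_lm_scale. The latter is negative, so the token-level n-gram estimate is subtracted in density-ratio style, which is why --lodr-lm-scale is documented as "should be less than 0".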
@@ -1944,7 +1946,7 @@ def ctc_prefix_beam_search_shallow_fussion(
     blank_id: int = 0,
     LODR_lm: Optional[NgramLm] = None,
     LODR_lm_scale: Optional[float] = 0,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
     context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """Implement prefix search decoding in "Connectionist Temporal Classification:
@@ -1981,17 +1983,16 @@ def ctc_prefix_beam_search_shallow_fussion(
     encoder_out_lens = encoder_out_lens.tolist()
     device = ctc_output.device
-    lm_scale = 0
+    nnlm_scale = 0
     init_scores = None
     init_states = None
-
-    if LM is not None:
-        lm_scale = LM.lm_scale
-        sos_id = getattr(LM, "sos_id", 1)
+    if NNLM is not None:
+        nnlm_scale = NNLM.lm_scale
+        sos_id = getattr(NNLM, "sos_id", 1)
         # get initial lm score and lm state by scoring the "sos" token
         sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
         lens = torch.tensor([1]).to(device)
-        init_scores, init_states = LM.score_token(sos_token, lens)
+        init_scores, init_states = NNLM.score_token(sos_token, lens)
         init_scores, init_states = init_scores.cpu(), (
             init_states[0].cpu(),
             init_states[1].cpu(),
@@ -2016,16 +2017,16 @@ def ctc_prefix_beam_search_shallow_fussion(
             if j < encoder_out_lens[i]:
                 log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
                 B[i] = _step_worker(
-                    log_probs,
-                    indexes,
-                    B[i],
-                    beam,
-                    blank_id,
-                    lm_scale=lm_scale,
+                    log_probs=log_probs,
+                    indexes=indexes,
+                    B=B[i],
+                    beam=beam,
+                    blank_id=blank_id,
+                    nnlm_scale=nnlm_scale,
                     LODR_lm_scale=LODR_lm_scale,
                     context_graph=context_graph,
                 )
-        if LM is None:
+        if NNLM is None:
             continue
         # update lm_log_probs
         token_list = []  # a list of list
@@ -2035,7 +2036,7 @@ def ctc_prefix_beam_search_shallow_fussion(
         for batch_idx, hyps in enumerate(B):
             for hyp in hyps:
                 if hyp.lm_log_probs is None:  # those hyps that prefix changes
-                    if LM.lm_type == "rnn":
+                    if NNLM.lm_type == "rnn":
                         token_list.append([hyp.ys[-1]])
                         # store the LSTM states
                         hs.append(hyp.state[0])
@@ -2046,7 +2047,7 @@ def ctc_prefix_beam_search_shallow_fussion(
                     indexes.append((batch_idx, hyp.key))
         if len(token_list) != 0:
             x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
-            if LM.lm_type == "rnn":
+            if NNLM.lm_type == "rnn":
                 tokens_to_score = (
                     torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                 )
@@ -2065,13 +2066,13 @@ def ctc_prefix_beam_search_shallow_fussion(
                 )
                 state = None
-            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
+            scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state)
             scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
             assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
             for i in range(scores.size(0)):
                 batch_idx, key = indexes[i]
                 B[batch_idx][key].lm_log_probs = scores[i]
-                if LM.lm_type == "rnn":
+                if NNLM.lm_type == "rnn":
                     state = (
                         lm_states[0][:, i, :].unsqueeze(1),
                         lm_states[1][:, i, :].unsqueeze(1),