mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

Minor fixes to shallow fusion

This commit is contained in:
parent e4fa25a780
commit 6a0e41b539
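
Summary of the change: the neural-LM shallow-fusion options --lm-type/--lm-scale are renamed to --nnlm-type/--nnlm-scale (the scale now defaults to 0, i.e. disabled); new options --backoff-id, --lodr-ngram, --lodr-lm-scale, --context-score, and --context-file are added; the NNLM, LODR n-gram, and context graph are built only when their scales are nonzero and are threaded through decode_one_batch/decode_dataset; _step_worker and ctc_prefix_beam_search_shallow_fussion are updated to the new names, including the three-value return of ContextGraph.forward_one_step and keyword arguments in the _step_worker call.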
@@ -111,6 +111,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -288,7 +289,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lm-type",
+        "--nnlm-type",
         type=str,
         default="rnn",
         help="Type of NN LM",
@@ -296,10 +297,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lm-scale",
+        "--nnlm-scale",
         type=float,
-        default=0.3,
-        help="""The scale of the neural network LM
+        default=0,
+        help="""The scale of the neural network LM; 0 means don't use NNLM shallow fusion.
         Used only when `--use-shallow-fusion` is set to True.
         """,
     )
@@ -321,6 +322,47 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the LODR ngram",
+    )
+
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="The scale of the LODR ngram; should be less than 0. 0 means don't use LODR.",
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        0 means don't use contextual biasing.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path to the context biasing list; one word/phrase per line.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
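
For orientation, a minimal sketch of the decoding options introduced above, using the same names and defaults as in the diff; the parsed values on the last two lines are purely hypothetical. For background, LODR subtracts a scaled source-domain n-gram estimate from the fused score, which is why --lodr-lm-scale is expected to be negative.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--nnlm-type", type=str, default="rnn")     # e.g. "rnn"
    parser.add_argument("--nnlm-scale", type=float, default=0)      # 0 disables NNLM fusion
    parser.add_argument("--backoff-id", type=int, default=500)      # backoff symbol id in the ngram LM
    parser.add_argument("--lodr-ngram", type=str)                   # path to the LODR token-level ngram
    parser.add_argument("--lodr-lm-scale", type=float, default=0)   # should be < 0; 0 disables LODR
    parser.add_argument("--context-score", type=float, default=0)   # per-token bonus; 0 disables biasing
    parser.add_argument("--context-file", type=str, default="")     # one biasing word/phrase per line
    args = parser.parse_args(["--nnlm-scale", "0.3", "--lodr-lm-scale", "-0.16"])
    print(args.nnlm_type, args.nnlm_scale, args.lodr_lm_scale)      # rnn 0.3 -0.16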
@@ -358,7 +400,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -466,7 +510,10 @@ def decode_one_batch(
         token_ids = ctc_prefix_beam_search_shallow_fussion(
             ctc_output=ctc_output,
             encoder_out_lens=encoder_out_lens,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
         )
         # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
         hyps = bpe_model.decode(token_ids)
@@ -649,7 +696,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -700,7 +749,9 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )

         for name, hyps in hyps_dict.items():
@@ -835,7 +886,12 @@ def main():
     if "prefix-beam-search" in params.decoding_method:
         params.suffix += f"_beam-{params.beam}"
         if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-            params.suffix += f"_lm-scale-{params.lm_scale}"
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context_score-{params.context_score}"

     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
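
As a worked example of the suffix assembled above: with hypothetical values beam=4, nnlm_scale=0.3, lodr_lm_scale=-0.16, and context_score=1.5, the result-file suffix would contain "_beam-4_nnlm-scale-0.3_lodr-scale--0.16_context_score-1.5" (the double hyphen comes from formatting the negative LODR scale).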
@@ -947,17 +1003,49 @@ def main():
     G = None

     # only load the neural network LM if required
-    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-        LM = LmScorer(
-            lm_type=params.lm_type,
+    NNLM = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.nnlm_scale != 0
+    ):
+        NNLM = LmScorer(
+            lm_type=params.nnlm_type,
             params=params,
             device=device,
-            lm_scale=params.lm_scale,
+            lm_scale=params.nnlm_scale,
         )
-        LM.to(device)
-        LM.eval()
-    else:
-        LM = None
+        NNLM.to(device)
+        NNLM.eval()
+
+    LODR_lm = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.lodr_lm_scale != 0
+    ):
+        assert os.path.exists(
+            params.lodr_ngram
+        ), f"LODR ngram does not exist, given path: {params.lodr_ngram}"
+        logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
+        LODR_lm = NgramLm(
+            params.lodr_ngram,
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {LODR_lm.lm.num_states}")
+
+    context_graph = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.context_score != 0
+    ):
+        assert os.path.exists(
+            params.context_file
+        ), f"context_file does not exist, given path: {params.context_file}"
+        contexts = []
+        for line in open(params.context_file).readlines():
+            contexts.append(bpe_model.encode(line.strip()))
+        context_graph = ContextGraph(params.context_score)
+        context_graph.build(contexts)

     logging.info("About to create model")
     model = get_model(params)
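
To illustrate what the context-biasing branch above consumes: one word/phrase per line in the --context-file, each encoded to BPE ids before building the graph. A minimal sketch; "bpe.model" and "hotwords.txt" are placeholder paths and 2.0 a hypothetical --context-score:

    import sentencepiece as spm

    from icefall.context_graph import ContextGraph

    bpe_model = spm.SentencePieceProcessor(model_file="bpe.model")  # placeholder
    contexts = []
    with open("hotwords.txt") as f:  # placeholder; one biasing word/phrase per line
        for line in f:
            contexts.append(bpe_model.encode(line.strip()))  # List[int] per phrase
    context_graph = ContextGraph(2.0)  # per-token bonus applied during decoding
    context_graph.build(contexts)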
@@ -1068,7 +1156,9 @@ def main():
         bpe_model=bpe_model,
         word_table=lexicon.word_table,
         G=G,
-        LM=LM,
+        NNLM=NNLM,
+        LODR_lm=LODR_lm,
+        context_graph=context_graph,
     )

     save_asr_output(

@@ -1736,7 +1736,7 @@ def _step_worker(
     B: HypothesisList,
     beam: int = 4,
     blank_id: int = 0,
-    lm_scale: float = 0,
+    nnlm_scale: float = 0,
     LODR_lm_scale: float = 0,
     context_graph: Optional[ContextGraph] = None,
 ) -> HypothesisList:
@@ -1815,14 +1815,16 @@ def _step_worker(
         if update_prefix:
             lm_score = hyp.lm_score
             if hyp.lm_log_probs is not None:
-                lm_score += hyp.lm_log_probs[new_token] * lm_scale
+                lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale
                 new_hyp.lm_log_probs = None

             if context_graph is not None and hyp.context_state is not None:
-                context_score, new_context_state = context_graph.forward_one_step(
-                    hyp.context_state, new_token
-                )
-                lm_score += context_score
+                (
+                    context_score,
+                    new_context_state,
+                    matched_state,
+                ) = context_graph.forward_one_step(hyp.context_state, new_token)
+                lm_score = lm_score + context_score
                 new_hyp.context_state = new_context_state

             if hyp.LODR_state is not None:
@@ -1833,7 +1835,7 @@ def _step_worker(
                     state_cost.lm_score,
                     hyp.LODR_state.lm_score,
                 )
-                lm_score += LODR_lm_scale * current_ngram_score
+                lm_score = lm_score + LODR_lm_scale * current_ngram_score
                 new_hyp.LODR_state = state_cost

             new_hyp.lm_score = lm_score
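
The hunks above keep the fused hypothesis score a plain sum of scaled log-probabilities plus the context bonus. A minimal sketch of that arithmetic, with hypothetical scale values (0.3 for the NNLM, -0.16 for LODR):

    def fused_increment(
        nnlm_log_prob: float,   # log-prob of the new token under the neural LM
        ngram_log_prob: float,  # log-prob under the LODR source-domain ngram
        context_bonus: float,   # ContextGraph bonus (0 when nothing matches)
        nnlm_scale: float = 0.3,
        lodr_lm_scale: float = -0.16,
    ) -> float:
        # NNLM shallow fusion adds; LODR subtracts via its negative scale;
        # context biasing adds a per-token bonus.
        return nnlm_scale * nnlm_log_prob + lodr_lm_scale * ngram_log_prob + context_bonus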
@@ -1944,7 +1946,7 @@ def ctc_prefix_beam_search_shallow_fussion(
    blank_id: int = 0,
    LODR_lm: Optional[NgramLm] = None,
    LODR_lm_scale: Optional[float] = 0,
-   LM: Optional[LmScorer] = None,
+   NNLM: Optional[LmScorer] = None,
    context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """Implement prefix search decoding in "Connectionist Temporal Classification:
@@ -1981,17 +1983,16 @@ def ctc_prefix_beam_search_shallow_fussion(
     encoder_out_lens = encoder_out_lens.tolist()
     device = ctc_output.device

-    lm_scale = 0
+    nnlm_scale = 0
     init_scores = None
     init_states = None

-    if LM is not None:
-        lm_scale = LM.lm_scale
-        sos_id = getattr(LM, "sos_id", 1)
+    if NNLM is not None:
+        nnlm_scale = NNLM.lm_scale
+        sos_id = getattr(NNLM, "sos_id", 1)
         # get initial lm score and lm state by scoring the "sos" token
         sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
         lens = torch.tensor([1]).to(device)
-        init_scores, init_states = LM.score_token(sos_token, lens)
+        init_scores, init_states = NNLM.score_token(sos_token, lens)
         init_scores, init_states = init_scores.cpu(), (
             init_states[0].cpu(),
             init_states[1].cpu(),
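
For readers unfamiliar with LmScorer, the contract this code relies on, inferred purely from the calls visible in this diff (not from LmScorer's documentation): score_token takes int64 tokens of shape (N, 1) and their lengths, optionally an LSTM state tuple (h, c), and returns log-probabilities over the vocabulary plus the new state. A runnable mock with that shape:

    import torch

    class MockScorer:
        # stand-in with the same call shape as the scorer used in this diff
        lm_type = "rnn"
        lm_scale = 0.3
        vocab_size, hidden, layers = 500, 8, 2

        def score_token(self, tokens, lens, state=None):
            n = tokens.size(0)
            if state is None:
                state = (
                    torch.zeros(self.layers, n, self.hidden),
                    torch.zeros(self.layers, n, self.hidden),
                )
            log_probs = torch.log_softmax(torch.randn(n, self.vocab_size), dim=-1)
            return log_probs, state

    scorer = MockScorer()
    scores, (h, c) = scorer.score_token(torch.tensor([[1]]), torch.tensor([1]))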
@@ -2016,16 +2017,16 @@ def ctc_prefix_beam_search_shallow_fussion(
             if j < encoder_out_lens[i]:
                 log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
                 B[i] = _step_worker(
-                    log_probs,
-                    indexes,
-                    B[i],
-                    beam,
-                    blank_id,
-                    lm_scale=lm_scale,
+                    log_probs=log_probs,
+                    indexes=indexes,
+                    B=B[i],
+                    beam=beam,
+                    blank_id=blank_id,
+                    nnlm_scale=nnlm_scale,
                     LODR_lm_scale=LODR_lm_scale,
                     context_graph=context_graph,
                 )
-        if LM is None:
+        if NNLM is None:
             continue
         # update lm_log_probs
         token_list = []  # a list of list
@@ -2035,7 +2036,7 @@ def ctc_prefix_beam_search_shallow_fussion(
         for batch_idx, hyps in enumerate(B):
             for hyp in hyps:
                 if hyp.lm_log_probs is None:  # those hyps whose prefix changed
-                    if LM.lm_type == "rnn":
+                    if NNLM.lm_type == "rnn":
                         token_list.append([hyp.ys[-1]])
                         # store the LSTM states
                         hs.append(hyp.state[0])
@@ -2046,7 +2047,7 @@ def ctc_prefix_beam_search_shallow_fussion(
                 indexes.append((batch_idx, hyp.key))
         if len(token_list) != 0:
             x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
-            if LM.lm_type == "rnn":
+            if NNLM.lm_type == "rnn":
                 tokens_to_score = (
                     torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                 )
@@ -2065,13 +2066,13 @@ def ctc_prefix_beam_search_shallow_fussion(
                 )
                 state = None

-            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
+            scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state)
             scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
             assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
             for i in range(scores.size(0)):
                 batch_idx, key = indexes[i]
                 B[batch_idx][key].lm_log_probs = scores[i]
-                if LM.lm_type == "rnn":
+                if NNLM.lm_type == "rnn":
                     state = (
                         lm_states[0][:, i, :].unsqueeze(1),
                         lm_states[1][:, i, :].unsqueeze(1),