Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-27 02:34:21 +00:00
Minor fixes to shallow fusion

parent e4fa25a780
commit 6a0e41b539
@@ -111,6 +111,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -288,7 +289,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lm-type",
+        "--nnlm-type",
         type=str,
         default="rnn",
         help="Type of NN lm",
@@ -296,10 +297,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lm-scale",
+        "--nnlm-scale",
         type=float,
-        default=0.3,
-        help="""The scale of the neural network LM
+        default=0,
+        help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion.
         Used only when `--use-shallow-fusion` is set to True.
         """,
     )
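For context: shallow fusion simply adds the neural LM's log-probability, weighted by this scale, to the CTC score during beam search, so the new default of 0 switches the NN LM term off entirely. A minimal sketch of the combination, not the icefall implementation (names are illustrative):

```python
# Illustrative only: per-token score under NNLM shallow fusion.
# nnlm_scale corresponds to --nnlm-scale; 0 disables the LM term.
def fused_token_score(ctc_log_prob: float, nnlm_log_prob: float, nnlm_scale: float) -> float:
    return ctc_log_prob + nnlm_scale * nnlm_log_prob
```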
@@ -321,6 +322,47 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the lodr ngram",
+    )
+
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.",
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        0 means don't use contextual biasing.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
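The two context-biasing options above feed a ContextGraph built from BPE token ids, as the loading code later in this commit shows. A minimal standalone sketch, assuming a SentencePiece model and a `context.txt` with one word/phrase per line (both paths are hypothetical):

```python
import sentencepiece as spm

from icefall.context_graph import ContextGraph

sp = spm.SentencePieceProcessor(model_file="data/lang_bpe_500/bpe.model")  # hypothetical path

contexts = []
with open("context.txt") as f:  # one biasing word/phrase per line
    for line in f:
        contexts.append(sp.encode(line.strip()))

context_graph = ContextGraph(context_score=1.5)  # per-token bonus, as in --context-score
context_graph.build(contexts)
```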
@@ -358,7 +400,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -466,7 +510,10 @@ def decode_one_batch(
         token_ids = ctc_prefix_beam_search_shallow_fussion(
             ctc_output=ctc_output,
             encoder_out_lens=encoder_out_lens,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
         )
         # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
         hyps = bpe_model.decode(token_ids)
@@ -649,7 +696,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -700,7 +749,9 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )

         for name, hyps in hyps_dict.items():
@@ -835,7 +886,12 @@ def main():
     if "prefix-beam-search" in params.decoding_method:
         params.suffix += f"_beam-{params.beam}"
         if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-            params.suffix += f"_lm-scale-{params.lm_scale}"
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context_score-{params.context_score}"

     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -947,17 +1003,49 @@ def main():
         G = None

     # only load the neural network LM if required
-    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-        LM = LmScorer(
-            lm_type=params.lm_type,
+    NNLM = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.nnlm_scale != 0
+    ):
+        NNLM = LmScorer(
+            lm_type=params.nnlm_type,
             params=params,
             device=device,
-            lm_scale=params.lm_scale,
+            lm_scale=params.nnlm_scale,
         )
-        LM.to(device)
-        LM.eval()
-    else:
-        LM = None
+        NNLM.to(device)
+        NNLM.eval()
+
+    LODR_lm = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.lodr_lm_scale != 0
+    ):
+        assert os.path.exists(
+            params.lodr_ngram
+        ), f"LODR ngram does not exists, given path : {params.lodr_ngram}"
+        logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
+        LODR_lm = NgramLm(
+            params.lodr_ngram,
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {LODR_lm.lm.num_states}")
+
+    context_graph = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.context_score != 0
+    ):
+        assert os.path.exists(
+            params.context_file
+        ), f"context_file does not exists, given path : {params.context_file}"
+        contexts = []
+        for line in open(params.context_file).readlines():
+            contexts.append(bpe_model.encode(line.strip()))
+        context_graph = ContextGraph(params.context_score)
+        context_graph.build(contexts)

     logging.info("About to create model")
     model = get_model(params)
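For readers unfamiliar with LODR (low-order density ratio): the token-level n-gram loaded here approximates the source-domain LM implicitly learned by the acoustic model, and its log-probability is subtracted during fusion, which is why `--lodr-lm-scale` is expected to be negative. A sketch of the resulting per-token score under the scales wired up above (illustrative only; the scale values are hypothetical):

```python
def lodr_fused_token_score(
    ctc_lp: float,                 # CTC log-prob of the candidate token
    nnlm_lp: float,                # neural LM log-prob
    ngram_lp: float,               # LODR n-gram log-prob
    context_bonus: float = 0.0,    # from the ContextGraph, if any
    nnlm_scale: float = 0.3,       # --nnlm-scale (hypothetical value)
    lodr_lm_scale: float = -0.15,  # --lodr-lm-scale, negative by convention
) -> float:
    return ctc_lp + nnlm_scale * nnlm_lp + lodr_lm_scale * ngram_lp + context_bonus
```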
@@ -1068,7 +1156,9 @@ def main():
             bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )

         save_asr_output(
@@ -1736,7 +1736,7 @@ def _step_worker(
     B: HypothesisList,
     beam: int = 4,
     blank_id: int = 0,
-    lm_scale: float = 0,
+    nnlm_scale: float = 0,
     LODR_lm_scale: float = 0,
     context_graph: Optional[ContextGraph] = None,
 ) -> HypothesisList:
@@ -1815,14 +1815,16 @@ def _step_worker(
         if update_prefix:
             lm_score = hyp.lm_score
             if hyp.lm_log_probs is not None:
-                lm_score += hyp.lm_log_probs[new_token] * lm_scale
+                lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale
                 new_hyp.lm_log_probs = None

             if context_graph is not None and hyp.context_state is not None:
-                context_score, new_context_state = context_graph.forward_one_step(
-                    hyp.context_state, new_token
-                )
-                lm_score += context_score
+                (
+                    context_score,
+                    new_context_state,
+                    matched_state,
+                ) = context_graph.forward_one_step(hyp.context_state, new_token)
+                lm_score = lm_score + context_score
                 new_hyp.context_state = new_context_state

             if hyp.LODR_state is not None:
@@ -1833,7 +1835,7 @@ def _step_worker(
                     state_cost.lm_score,
                     hyp.LODR_state.lm_score,
                 )
-                lm_score += LODR_lm_scale * current_ngram_score
+                lm_score = lm_score + LODR_lm_scale * current_ngram_score
                 new_hyp.LODR_state = state_cost

             new_hyp.lm_score = lm_score
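A plausible reason for rewriting `lm_score += ...` as `lm_score = lm_score + ...` in the two hunks above: when `lm_score` is a tensor, `+=` adds in place and would also mutate `hyp.lm_score`, which the local name merely aliases; the out-of-place form rebinds the name and leaves the parent hypothesis untouched. A quick demonstration of the difference:

```python
import torch

parent_score = torch.tensor([1.0])
child_score = parent_score       # aliases the same storage
child_score += 1.0               # in-place add: the parent is mutated too
print(parent_score)              # tensor([2.])

parent_score = torch.tensor([1.0])
child_score = parent_score
child_score = child_score + 1.0  # out-of-place: rebinds the local name only
print(parent_score)              # tensor([1.])
```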
@@ -1944,7 +1946,7 @@ def ctc_prefix_beam_search_shallow_fussion(
     blank_id: int = 0,
     LODR_lm: Optional[NgramLm] = None,
     LODR_lm_scale: Optional[float] = 0,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
     context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """Implement prefix search decoding in "Connectionist Temporal Classification:
@@ -1981,17 +1983,16 @@ def ctc_prefix_beam_search_shallow_fussion(
     encoder_out_lens = encoder_out_lens.tolist()
     device = ctc_output.device

-    lm_scale = 0
+    nnlm_scale = 0
     init_scores = None
     init_states = None
-
-    if LM is not None:
-        lm_scale = LM.lm_scale
-        sos_id = getattr(LM, "sos_id", 1)
+    if NNLM is not None:
+        nnlm_scale = NNLM.lm_scale
+        sos_id = getattr(NNLM, "sos_id", 1)
         # get initial lm score and lm state by scoring the "sos" token
         sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
         lens = torch.tensor([1]).to(device)
-        init_scores, init_states = LM.score_token(sos_token, lens)
+        init_scores, init_states = NNLM.score_token(sos_token, lens)
         init_scores, init_states = init_scores.cpu(), (
             init_states[0].cpu(),
             init_states[1].cpu(),
@@ -2016,16 +2017,16 @@ def ctc_prefix_beam_search_shallow_fussion(
             if j < encoder_out_lens[i]:
                 log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
                 B[i] = _step_worker(
-                    log_probs,
-                    indexes,
-                    B[i],
-                    beam,
-                    blank_id,
-                    lm_scale=lm_scale,
+                    log_probs=log_probs,
+                    indexes=indexes,
+                    B=B[i],
+                    beam=beam,
+                    blank_id=blank_id,
+                    nnlm_scale=nnlm_scale,
                     LODR_lm_scale=LODR_lm_scale,
                     context_graph=context_graph,
                 )
-        if LM is None:
+        if NNLM is None:
             continue
         # update lm_log_probs
         token_list = [] # a list of list
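Switching the `_step_worker` call above to keyword arguments looks like a defensive change: with several similarly typed positional parameters, a rename or reorder in the signature (as this commit does with `lm_scale` becoming `nnlm_scale`) can silently misbind arguments at positional call sites. A tiny illustration with hypothetical names:

```python
def step(log_probs, indexes, beam=4, blank_id=0, nnlm_scale=0.0):
    return beam, blank_id, nnlm_scale

# Positional call: swapping beam and blank_id is accepted without any error.
print(step([], [], 0, 4))                # (0, 4, 0.0)  -- wrong, but silent
# Keyword call: immune to parameter order, and typos fail loudly.
print(step([], [], beam=4, blank_id=0))  # (4, 0, 0.0)
```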
@@ -2035,7 +2036,7 @@ def ctc_prefix_beam_search_shallow_fussion(
         for batch_idx, hyps in enumerate(B):
             for hyp in hyps:
                 if hyp.lm_log_probs is None: # those hyps that prefix changes
-                    if LM.lm_type == "rnn":
+                    if NNLM.lm_type == "rnn":
                         token_list.append([hyp.ys[-1]])
                         # store the LSTM states
                         hs.append(hyp.state[0])
@@ -2046,7 +2047,7 @@ def ctc_prefix_beam_search_shallow_fussion(
                     indexes.append((batch_idx, hyp.key))
         if len(token_list) != 0:
             x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
-            if LM.lm_type == "rnn":
+            if NNLM.lm_type == "rnn":
                 tokens_to_score = (
                     torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                 )
@@ -2065,13 +2066,13 @@ def ctc_prefix_beam_search_shallow_fussion(
                 )
             state = None

-            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
+            scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state)
             scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
             assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
             for i in range(scores.size(0)):
                 batch_idx, key = indexes[i]
                 B[batch_idx][key].lm_log_probs = scores[i]
-                if LM.lm_type == "rnn":
+                if NNLM.lm_type == "rnn":
                     state = (
                         lm_states[0][:, i, :].unsqueeze(1),
                         lm_states[1][:, i, :].unsqueeze(1),