Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Add librispeech prefix-beam-search

This commit is contained in:
  parent cef16574f7
  commit 02e00ff504
@@ -123,6 +123,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params

from icefall.context_graph import ContextGraph, ContextState
from icefall.ngram_lm import NgramLm, NgramLmStateCost
from icefall.lm_wrapper import LmScorer

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -131,6 +135,9 @@ from icefall.checkpoint import (
)
from icefall.decode import (
    ctc_greedy_search,
    ctc_prefix_beam_search,
    ctc_prefix_beam_search_attention_decoder_rescoring,
    ctc_prefix_beam_search_shallow_fussion,
    get_lattice,
    nbest_decoding,
    nbest_oracle,
@@ -280,6 +287,23 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--lm-type",
        type=str,
        default="rnn",
        help="Type of NN lm",
        choices=["rnn", "transformer"],
    )

    parser.add_argument(
        "--lm-scale",
        type=float,
        default=0.3,
        help="""The scale of the neural network LM
        Used only when `--use-shallow-fusion` is set to True.
        """,
    )

    parser.add_argument(
        "--hlg-scale",
        type=float,
@@ -301,7 +325,7 @@ def get_parser():
        "--skip-scoring",
        type=str2bool,
        default=False,
        help="""Skip scoring, but still save the ASR output (for eval sets)."""
        help="""Skip scoring, but still save the ASR output (for eval sets).""",
    )

    add_model_arguments(parser)
@@ -314,8 +338,9 @@ def get_decoding_params() -> AttributeDict:
    params = AttributeDict(
        {
            "frame_shift_ms": 10,
            "search_beam": 20,
            "output_beam": 8,
            "search_beam": 20,  # for k2 fsa composition
            "output_beam": 8,  # for k2 fsa composition
            "beam": 4,  # for prefix-beam-search
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
@@ -333,6 +358,7 @@ def decode_one_batch(
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -377,10 +403,7 @@ def decode_one_batch(
      Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
    """
    if HLG is not None:
        device = HLG.device
    else:
        device = H.device
    device = params.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
@@ -411,6 +434,48 @@ def decode_one_batch(
        key = "ctc-greedy-search"
        return {key: hyps}

    if params.decoding_method == "prefix-beam-search":
        token_ids = ctc_prefix_beam_search(
            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
        )
        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
        hyps = bpe_model.decode(token_ids)

        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
        hyps = [s.split() for s in hyps]
        key = "prefix-beam-search"
        return {key: hyps}

    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
            ctc_output=ctc_output,
            attention_decoder=model.attention_decoder,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        ans = dict()
        for a_scale_str, token_ids in best_path_dict.items():
            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
            hyps = bpe_model.decode(token_ids)
            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
            hyps = [s.split() for s in hyps]
            ans[a_scale_str] = hyps
        return ans

    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
        token_ids = ctc_prefix_beam_search_shallow_fussion(
            ctc_output=ctc_output,
            encoder_out_lens=encoder_out_lens,
            LM=LM,
        )
        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
        hyps = bpe_model.decode(token_ids)

        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
        hyps = [s.split() for s in hyps]
        key = "prefix-beam-search-shallow-fussion"
        return {key: hyps}

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
@@ -584,6 +649,7 @@ def decode_dataset(
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

@@ -634,6 +700,7 @@ def decode_dataset(
            batch=batch,
            word_table=word_table,
            G=G,
            LM=LM,
        )

        for name, hyps in hyps_dict.items():
@@ -664,9 +731,7 @@ def save_asr_output(
    """
    for key, results in results_dict.items():

        recogs_filename = (
            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        )
        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"

        results = sorted(results)
        store_transcripts(filename=recogs_filename, texts=results)
@@ -680,7 +745,8 @@ def save_wer_results(
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    if params.decoding_method in (
        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
        "attention-decoder-rescoring-with-ngram",
        "whole-lattice-rescoring",
    ):
        # Set it to False since there are too many logs.
        enable_log = False
@@ -721,6 +787,7 @@ def save_wer_results(
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    LmScorer.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
@@ -735,8 +802,11 @@ def main():
    set_caching_enabled(True)  # lhotse

    assert params.decoding_method in (
        "ctc-greedy-search",
        "ctc-decoding",
        "ctc-greedy-search",
        "prefix-beam-search",
        "ctc-prefix-beam-search-attention-decoder-rescoring",
        "ctc-prefix-beam-search-shallow-fussion",
        "1best",
        "nbest",
        "nbest-rescoring",
@@ -762,6 +832,11 @@ def main():
        params.suffix += f"_chunk-{params.chunk_size}"
        params.suffix += f"_left-context-{params.left_context_frames}"

    if "prefix-beam-search" in params.decoding_method:
        params.suffix += f"_beam-{params.beam}"
        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
            params.suffix += f"_lm-scale-{params.lm_scale}"

    if params.use_averaged_model:
        params.suffix += "_use-averaged-model"

@@ -772,6 +847,8 @@ def main():
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    params.device = device

    logging.info(f"Device: {device}")
    logging.info(params)

@@ -786,9 +863,19 @@ def main():
    params.sos_id = 1

    if params.decoding_method in [
        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
        "ctc-greedy-search",
        "ctc-decoding",
        "attention-decoder-rescoring-no-ngram",
        "prefix-beam-search",
        "ctc-prefix-beam-search-attention-decoder-rescoring",
        "ctc-prefix-beam-search-shallow-fussion",
    ]:
        HLG = None
        H = None
        if params.decoding_method in [
            "ctc-decoding",
            "attention-decoder-rescoring-no-ngram",
        ]:
            H = k2.ctc_topo(
                max_token=max_token_id,
                modified=False,
@@ -844,7 +931,8 @@ def main():
        G = k2.Fsa.from_dict(d)

        if params.decoding_method in [
            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
            "whole-lattice-rescoring",
            "attention-decoder-rescoring-with-ngram",
        ]:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
@@ -858,6 +946,19 @@ def main():
    else:
        G = None

    # only load the neural network LM if required
    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
        LM = LmScorer(
            lm_type=params.lm_type,
            params=params,
            device=device,
            lm_scale=params.lm_scale,
        )
        LM.to(device)
        LM.eval()
    else:
        LM = None

    logging.info("About to create model")
    model = get_model(params)

@@ -967,6 +1068,7 @@ def main():
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
            LM=LM,
        )

        save_asr_output(
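Note (not part of the commit): a minimal sketch of what the new decode_one_batch branches hand back, assuming `bpe_model` is a loaded sentencepiece processor and `token_ids` came back from ctc_prefix_beam_search; the model path below is illustrative only.

import sentencepiece as spm

bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe_500/bpe.model")  # illustrative path

token_ids = [[23, 7, 115], [42]]       # List[List[int]] from ctc_prefix_beam_search
hyps = bpe_model.decode(token_ids)     # list of str, e.g. ['xxx yyy zzz', ...]
hyps = [s.split() for s in hyps]       # list of list of str
result = {"prefix-beam-search": hyps}  # the dict shape decode_one_batch returns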
@@ -1511,30 +1511,43 @@ def ctc_greedy_search(
class Hypothesis:
    # The predicted tokens so far.
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]
    ys: List[int] = field(default_factory=list)

    # The log prob of ys.
    # It contains only one entry.
    log_prob_blank: torch.Tensor
    log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)

    log_prob_non_blank: torch.Tensor
    log_prob_non_blank: torch.Tensor = torch.tensor(
        [float("-inf")], dtype=torch.float32
    )

    # timestamp[i] is the frame index after subsampling
    # on which ys[i] is decoded
    timestamp: List[int] = field(default_factory=list)

    # the lm score for next token given the current ys
    lm_score: Optional[torch.Tensor] = None
    # The lm score of ys
    # It contains only one entry
    lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)

    # the lm log_probs for next token given the history ys
    lm_log_probs: Optional[torch.Tensor] = None

    # the RNNLM states (h and c in LSTM)
    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    # N-gram LM state
    state_cost: Optional[NgramLmStateCost] = None
    LODR_state: Optional[NgramLmStateCost] = None

    # N-gram LM state
    Ngram_state: Optional[NgramLmStateCost] = None

    # Context graph state
    context_state: Optional[ContextState] = None

    @property
    def tot_score(self) -> torch.Tensor:
        return self.log_prob + self.lm_score

    @property
    def log_prob(self) -> torch.Tensor:
        return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
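Note (not part of the commit): log_prob_blank and log_prob_non_blank track the two ways a prefix can end (on a blank frame vs. on its last non-blank token), and the log_prob property above merges them with a log-sum-exp; a tiny numeric sketch:

import torch

log_prob_blank = torch.tensor([-1.2])      # mass of the prefix ending in blank
log_prob_non_blank = torch.tensor([-2.3])  # mass of the prefix ending in its last token
# log(exp(-1.2) + exp(-2.3)) ≈ -0.913, the total log-probability of the prefix
total = torch.logaddexp(log_prob_non_blank, log_prob_blank)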
@@ -1544,6 +1557,20 @@ class Hypothesis:
        """Return a tuple representation of self.ys"""
        return tuple(self.ys)

    def clone(self) -> "Hypothesis":
        return Hypothesis(
            ys=self.ys,
            log_prob_blank=self.log_prob_blank,
            log_prob_non_blank=self.log_prob_non_blank,
            timestamp=self.timestamp,
            lm_log_probs=self.lm_log_probs,
            lm_score=self.lm_score,
            state=self.state,
            LODR_state=self.LODR_state,
            Ngram_state=self.Ngram_state,
            context_state=self.context_state,
        )


class HypothesisList(object):
    def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None:
@@ -1597,9 +1624,9 @@ class HypothesisList(object):
          Return the hypothesis that has the largest `log_prob`.
        """
        if length_norm:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
            return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
        else:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
            return max(self._data.values(), key=lambda hyp: hyp.tot_score)

    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.
@@ -1629,7 +1656,7 @@ class HypothesisList(object):
        """
        ans = HypothesisList()
        for _, hyp in self._data.items():
            if hyp.log_prob > threshold:
            if hyp.tot_score > threshold:
                ans.add(hyp)  # shallow copy
        return ans

@@ -1645,17 +1672,20 @@ class HypothesisList(object):

        if length_norm:
            hyps = sorted(
                hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
                hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True
            )[:k]
        else:
            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
            hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k]

        ans = HypothesisList(dict(hyps))
        return ans

    def __contains__(self, key: str):
    def __contains__(self, key: tuple):
        return key in self._data

    def __getitem__(self, key: tuple):
        return self._data[key]

    def __iter__(self):
        return iter(self._data.values())

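Note (not part of the commit): the hunks above switch ranking from the acoustic-only log_prob to tot_score = log_prob + lm_score, so an LM or context-graph bonus can reorder hypotheses; a small sketch with made-up numbers, assuming Hypothesis and HypothesisList are importable from icefall.decode as the recipe imports suggest:

import torch
from icefall.decode import Hypothesis, HypothesisList

neg_inf = torch.tensor([float("-inf")])
a = Hypothesis(ys=[5, 9], log_prob_blank=torch.tensor([-3.0]),
               log_prob_non_blank=neg_inf, lm_score=torch.tensor([-0.2]))
b = Hypothesis(ys=[5, 7], log_prob_blank=torch.tensor([-2.8]),
               log_prob_non_blank=neg_inf, lm_score=torch.tensor([-1.5]))

hyps = HypothesisList()
hyps.add(a)
hyps.add(b)
# b wins on raw log_prob (-2.8 vs -3.0), but a wins on tot_score (-3.2 vs -4.3),
# which is what get_most_probable() and topk() now compare.
best = hyps.get_most_probable()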
@@ -1694,64 +1724,96 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
    return ans


def _step_worker(log_probs, indexes, B, beam, blank_id):
def _step_worker(
    log_probs,
    indexes,
    B,
    beam,
    blank_id,
    lm_scale: float = 0,
    LODR_lm_scale: float = 0,
    context_graph: Optional[ContextGraph] = None,
):
    A = list(B)
    B = HypothesisList()
    for h in range(len(A)):
        hyp = A[h]
        for k in range(log_probs.size(0)):
            log_prob, index = log_probs[k], indexes[k]
            if index == blank_id:
            new_token = index.item()
            update_prefix = False
            new_hyp = hyp.clone()
            if new_token == blank_id:
                # Case 0: *a + ε => *a
                #         *aε + ε => *a
                # Prefix does not change, update log_prob of blank
                new_hyp = Hypothesis(
                    ys=hyp.ys[:],
                    log_prob_non_blank=torch.tensor(
                new_hyp.log_prob_non_blank = torch.tensor(
                    [float("-inf")], dtype=torch.float32
                    ),
                    log_prob_blank=hyp.log_prob + log_prob,
                )
                new_hyp.log_prob_blank = hyp.log_prob + log_prob
                B.add(new_hyp)
            elif len(hyp.ys) > 0 and hyp.ys[-1] == index:
            elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token:
                # Case 1: *a + a => *a
                # Prefix does not change, update log_prob of non_blank
                new_hyp = Hypothesis(
                    ys=hyp.ys[:],
                    log_prob_non_blank=hyp.log_prob_non_blank + log_prob,
                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
                new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob
                new_hyp.log_prob_blank = torch.tensor(
                    [float("-inf")], dtype=torch.float32
                )
                B.add(new_hyp)

                # Case 2: *aε + a => *aa
                # Prefix changes, update log_prob of blank
                new_hyp = Hypothesis(
                    ys=hyp.ys[:] + [index.item()],
                    log_prob_non_blank=hyp.log_prob_blank + log_prob,
                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
                new_hyp = hyp.clone()
                # Caution: DO NOT use append, as clone is shallow copy
                new_hyp.ys = hyp.ys + [new_token]
                new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob
                new_hyp.log_prob_blank = torch.tensor(
                    [float("-inf")], dtype=torch.float32
                )
                B.add(new_hyp)
                update_prefix = True
            else:
                # Case 3: *a + b => *ab, *aε + b => *ab
                # Prefix changes, update log_prob of non_blank
                new_hyp = Hypothesis(
                    ys=hyp.ys[:] + [index.item()],
                    log_prob_non_blank=hyp.log_prob + log_prob,
                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
                # Caution: DO NOT use append, as clone is shallow copy
                new_hyp.ys = hyp.ys + [new_token]
                new_hyp.log_prob_non_blank = hyp.log_prob + log_prob
                new_hyp.log_prob_blank = torch.tensor(
                    [float("-inf")], dtype=torch.float32
                )
                update_prefix = True

            if update_prefix:
                lm_score = hyp.lm_score
                if hyp.lm_log_probs is not None:
                    lm_score += hyp.lm_log_probs[new_token] * lm_scale
                    new_hyp.lm_log_probs = None

                if context_graph is not None and hyp.context_state is not None:
                    context_score, new_context_state = context_graph.forward_one_step(
                        hyp.context_state, new_token
                    )
                    lm_score += context_score
                    new_hyp.context_state = new_context_state

                if hyp.LODR_state is not None:
                    state_cost = hyp.LODR_state.forward_one_step(new_token)
                    # calculate the score of the latest token
                    current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score
                    assert current_ngram_score <= 0.0, (
                        state_cost.lm_score,
                        hyp.LODR_state.lm_score,
                    )
                    lm_score += LODR_lm_scale * current_ngram_score
                    new_hyp.LODR_state = state_cost

                new_hyp.lm_score = lm_score
                B.add(new_hyp)
    B = B.topk(beam)
    return B

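Note (not part of the commit): a toy single-frame walk-through of the cases above, starting from the empty prefix exactly as _batch_worker seeds it; the log-probabilities are made up and blank_id is assumed to be 0.

import torch
from icefall.decode import Hypothesis, HypothesisList, _step_worker

B = HypothesisList()
B.add(Hypothesis())  # ys=[], log_prob_blank=0.0, log_prob_non_blank=-inf

# One frame whose top-2 symbols are token 2 (log-prob -0.4) and blank (log-prob -1.1).
log_probs = torch.tensor([-0.4, -1.1])
indexes = torch.tensor([2, 0])

B = _step_worker(log_probs, indexes, B, beam=4, blank_id=0)
# Case 3 extended the empty prefix to [2] with non-blank mass -0.4;
# Case 0 kept the empty prefix alive through blank with blank mass -1.1.
for hyp in B:
    print(hyp.ys, hyp.log_prob)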
def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
    B.add(
        Hypothesis(
            ys=[],
            log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
            log_prob_blank=torch.zeros(1, dtype=torch.float32),
        )
    )
    B.add(Hypothesis())
    for j in range(encoder_out_lens):
        log_probs, indexes = topk_values[j], topk_indexes[j]
        B = _step_worker(log_probs, indexes, B, beam, blank_id)
@@ -1763,11 +1825,11 @@ def ctc_prefix_beam_search(
    encoder_out_lens: torch.Tensor,
    beam: int = 4,
    blank_id: int = 0,
    context_graph: Optional[ContextGraph] = None,
    process_pool: Optional[Pool] = None,
    return_nbest: Optional[bool] = False,
) -> Union[List[List[int]], List[HypothesisList]]:
    batch_size, num_frames, vocab_size = ctc_output.shape

    # TODO: using a larger beam for first pass pruning
    topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
    topk_values = topk_values.cpu()
@@ -1800,6 +1862,136 @@ def ctc_prefix_beam_search(
    return [hyp.ys for hyp in best_hyps]

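Note (not part of the commit): a minimal, self-contained sketch of calling ctc_prefix_beam_search on random per-frame log-probabilities; shapes follow the (batch, frames, vocab) convention used above, and per the Union return type in the signature, return_nbest=True would yield per-utterance HypothesisList objects instead of token-id lists.

import torch
from icefall.decode import ctc_prefix_beam_search

batch_size, num_frames, vocab_size = 2, 30, 500
# e.g. the output of a CTC head after log_softmax
ctc_output = torch.randn(batch_size, num_frames, vocab_size).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([30, 22])

token_ids = ctc_prefix_beam_search(
    ctc_output=ctc_output,
    encoder_out_lens=encoder_out_lens,
    beam=4,
)
print(token_ids)  # one List[int] of token ids per utterance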
def ctc_prefix_beam_search_shallow_fussion(
    ctc_output: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: int = 4,
    blank_id: int = 0,
    LODR_lm: Optional[NgramLm] = None,
    LODR_lm_scale: Optional[float] = 0,
    LM: Optional[LmScorer] = None,
    context_graph: Optional[ContextGraph] = None,
) -> List[List[int]]:
    batch_size, num_frames, vocab_size = ctc_output.shape
    # TODO: using a larger beam for first pass pruning
    topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
    topk_values = topk_values.cpu()
    topk_indexes = topk_indexes.cpu()
    encoder_out_lens = encoder_out_lens.tolist()
    device = ctc_output.device

    lm_scale = 0
    init_scores = None
    init_states = None

    if LM is not None:
        lm_scale = LM.lm_scale
        sos_id = getattr(LM, "sos_id", 1)
        # get initial lm score and lm state by scoring the "sos" token
        sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
        lens = torch.tensor([1]).to(device)
        init_scores, init_states = LM.score_token(sos_token, lens)
        init_scores, init_states = init_scores.cpu(), (
            init_states[0].cpu(),
            init_states[1].cpu(),
        )

    B = [HypothesisList() for _ in range(batch_size)]
    for i in range(batch_size):
        B[i].add(
            Hypothesis(
                ys=[],
                log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
                log_prob_blank=torch.zeros(1, dtype=torch.float32),
                lm_score=torch.zeros(1, dtype=torch.float32),
                state=init_states,
                lm_log_probs=None if init_scores is None else init_scores.reshape(-1),
                LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm),
                context_state=None if context_graph is None else context_graph.root,
            )
        )
    for j in range(num_frames):
        for i in range(batch_size):
            if j < encoder_out_lens[i]:
                log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
                B[i] = _step_worker(
                    log_probs,
                    indexes,
                    B[i],
                    beam,
                    blank_id,
                    lm_scale=lm_scale,
                    LODR_lm_scale=LODR_lm_scale,
                    context_graph=context_graph,
                )
        if LM is None:
            continue
        # update lm_score
        token_list = []  # a list of list
        hs = []
        cs = []
        indexes = []  # (batch_idx, key)
        for batch_idx, hyps in enumerate(B):
            for hyp in hyps:
                if hyp.lm_log_probs is None:
                    if LM.lm_type == "rnn":
                        token_list.append([hyp.ys[-1]])
                        # store the LSTM states
                        hs.append(hyp.state[0])
                        cs.append(hyp.state[1])
                    else:
                        # for transformer LM
                        token_list.append([sos_id] + hyp.ys[:])
                    indexes.append((batch_idx, hyp.key))
        if len(token_list) != 0:
            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
            if LM.lm_type == "rnn":
                tokens_to_score = (
                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                )
                hs = torch.cat(hs, dim=1).to(device)
                cs = torch.cat(cs, dim=1).to(device)
                state = (hs, cs)
            else:
                # for transformer LM
                tokens_list = [torch.tensor(tokens) for tokens in token_list]
                tokens_to_score = (
                    torch.nn.utils.rnn.pad_sequence(
                        tokens_list, batch_first=True, padding_value=0.0
                    )
                    .to(device)
                    .to(torch.int64)
                )
                state = None

            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
            scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
            assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
            for i in range(scores.size(0)):
                batch_idx, key = indexes[i]
                B[batch_idx][key].lm_log_probs = scores[i]
                if LM.lm_type == "rnn":
                    state = (
                        lm_states[0][:, i, :].unsqueeze(1),
                        lm_states[1][:, i, :].unsqueeze(1),
                    )
                    B[batch_idx][key].state = state

    # finalize context_state, if the matched contexts do not reach final state
    # we need to add the score on the corresponding backoff arc
    if context_graph is not None:
        for hyps in B:
            for hyp in hyps:
                context_score, new_context_state = context_graph.finalize(
                    hyp.context_state
                )
                hyp.lm_score += context_score
                hyp.context_state = new_context_state

    best_hyps = [b.get_most_probable() for b in B]
    return [hyp.ys for hyp in best_hyps]


def ctc_prefix_beam_search_attention_decoder_rescoring(
    ctc_output: torch.Tensor,
    attention_decoder: torch.nn.Module,
@@ -19,8 +19,10 @@

import argparse
import collections
import json
import logging
import os
import pathlib
import re
import subprocess
from collections import defaultdict
@@ -178,6 +180,15 @@ class AttributeDict(dict):
            return
        raise AttributeError(f"No such attribute '{key}'")

    def __str__(self, indent: int = 2):
        tmp = {}
        for k, v in self.items():
            # PosixPath is not JSON serializable
            if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
                v = str(v)
            tmp[k] = v
        return json.dumps(tmp, indent=indent, sort_keys=True)


def encode_supervisions(
    supervisions: dict,
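Note (not part of the commit): the new AttributeDict.__str__ makes params printable as JSON even when it holds Path or torch.device values (as main() now stores params.device before logging.info(params)); a quick sketch, assuming AttributeDict is imported from icefall.utils as the recipes do:

from pathlib import Path

import torch
from icefall.utils import AttributeDict

params = AttributeDict(
    {
        "exp_dir": Path("zipformer/exp"),
        "device": torch.device("cuda", 0),
        "beam": 4,
    }
)
# Path and torch.device values are converted to str before json.dumps,
# so printing params yields sorted, indented JSON instead of raising a TypeError.
print(params)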