mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Add documents to ctc prefix beam search

This commit is contained in:
parent 33fa9e8b00
commit 3a40c073e7
@@ -1513,10 +1513,12 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int] = field(default_factory=list)

-    # The log prob of ys.
+    # The log prob of ys that ends with blank token.
+    # It contains only one entry.
     log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)

+    # The log prob of ys that ends with non blank token.
+    # It contains only one entry.
     log_prob_non_blank: torch.Tensor = torch.tensor(
         [float("-inf")], dtype=torch.float32
     )
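For context on the two fields above: CTC prefix beam search tracks, for each prefix `ys`, the probability of alignments ending in blank separately from those ending in a non-blank token, because extending a prefix with a repeated token is only legal after a blank. A minimal standalone illustration of that bookkeeping (toy values, not the repo's code):

```python
import torch

# p_b: log prob over CTC alignments of the prefix that end in blank.
# p_nb: log prob over alignments of the prefix that end in a non-blank token.
p_b = torch.zeros(1, dtype=torch.float32)   # log(1); the empty prefix "ends in blank"
p_nb = torch.tensor([float("-inf")])        # log(0); it cannot end in non-blank yet

log_p_blank_frame = torch.tensor([-0.5])    # hypothetical log prob of blank at this frame

# Emitting blank extends the alignment but not the prefix; afterwards all of
# the prefix's probability mass "ends in blank".
p_b = torch.logaddexp(p_b, p_nb) + log_p_blank_frame
print(p_b)  # tensor([-0.5000])
```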
@@ -1526,16 +1528,18 @@ class Hypothesis:
     timestamp: List[int] = field(default_factory=list)

     # The lm score of ys
+    # May contain external LM score (including LODR score) and contextual biasing score
+    # It contains only one entry
     lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)

+    # the lm log_probs for next token given the history ys
+    # The number of elements should be equal to vocabulary size.
     lm_log_probs: Optional[torch.Tensor] = None

     # the RNNLM states (h and c in LSTM)
     state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

-    # N-gram LM state
+    # LODR (N-gram LM) state
     LODR_state: Optional[NgramLmStateCost] = None

     # N-gram LM state
@@ -1544,10 +1548,12 @@ class Hypothesis:
     # Context graph state
     context_state: Optional[ContextState] = None

+    # This is the total score of current path, acoustic plus external LM score.
     @property
     def tot_score(self) -> torch.Tensor:
         return self.log_prob + self.lm_score

+    # This is only the probability from model output (i.e. external LM score not included).
     @property
     def log_prob(self) -> torch.Tensor:
         return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
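The two properties combine as follows; a small numeric check with hypothetical values:

```python
import torch

log_prob_blank = torch.tensor([-1.0])
log_prob_non_blank = torch.tensor([-2.0])
lm_score = torch.tensor([-0.3])

# log_prob merges the two alignment families of the same prefix:
# log(exp(-1.0) + exp(-2.0)) ≈ -0.6867
log_prob = torch.logaddexp(log_prob_non_blank, log_prob_blank)

# tot_score additionally includes the external LM contribution.
tot_score = log_prob + lm_score
print(round(log_prob.item(), 4), round(tot_score.item(), 4))  # -0.6867 -0.9867
```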
@@ -1614,14 +1620,14 @@ class HypothesisList(object):

     def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
         """Get the most probable hypothesis, i.e., the one with
-        the largest `log_prob`.
+        the largest `tot_score`.

         Args:
           length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
+            If True, the `tot_score` of a hypothesis is normalized by the
             number of tokens in it.
         Returns:
-          Return the hypothesis that has the largest `log_prob`.
+          Return the hypothesis that has the largest `tot_score`.
         """
         if length_norm:
             return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
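Length normalization divides the score by the hypothesis length so the search does not systematically favor shorter outputs. A toy comparison (illustrative scores only):

```python
# Without length norm the shorter hypothesis wins; with length norm the
# longer, better-per-token one does.
hyps = [
    {"ys": [5, 9], "tot_score": -4.0},        # -2.0 per token
    {"ys": [5, 9, 2, 7], "tot_score": -6.0},  # -1.5 per token
]
best = max(hyps, key=lambda h: h["tot_score"])
best_norm = max(hyps, key=lambda h: h["tot_score"] / len(h["ys"]))
print(best["ys"], best_norm["ys"])  # [5, 9] [5, 9, 2, 7]
```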
@@ -1645,14 +1651,14 @@ class HypothesisList(object):
             del self._data[key]

     def filter(self, threshold: torch.Tensor) -> "HypothesisList":
-        """Remove all Hypotheses whose log_prob is less than threshold.
+        """Remove all Hypotheses whose tot_score is less than threshold.

         Caution:
           `self` is not modified. Instead, a new HypothesisList is returned.

         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
-          with `log_prob` being greater than the given `threshold`.
+          with `tot_score` being greater than the given `threshold`.
         """
         ans = HypothesisList()
         for _, hyp in self._data.items():
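The loop body is truncated by the diff context here; based on the docstring above, the remainder plausibly looks like this (a sketch, not the committed code):

```python
# Sketch of how `filter` plausibly completes: keep hypotheses whose
# tot_score exceeds the threshold, collected into a fresh list.
ans = HypothesisList()
for _, hyp in self._data.items():
    if hyp.tot_score > threshold:
        ans.add(hyp)
return ans
```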
@@ -1665,7 +1671,7 @@ class HypothesisList(object):

         Args:
           length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
+            If True, the `tot_score` of a hypothesis is normalized by the
             number of tokens in it.
         """
         hyps = list(self._data.items())
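The rest of this `topk` method is cut off by the diff context; a plausible continuation consistent with the docstring, assuming a count parameter named `k` (a sketch only):

```python
# Sketch: sort by (optionally length-normalized) tot_score, keep the best k.
hyps = list(self._data.items())
if length_norm:
    hyps = sorted(
        hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True
    )[:k]
else:
    hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k]
```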
@@ -1725,15 +1731,39 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:


 def _step_worker(
-    log_probs,
-    indexes,
-    B,
-    beam,
-    blank_id,
-):
+    log_probs: torch.Tensor,
+    indexes: torch.Tensor,
+    B: HypothesisList,
+    beam: int = 4,
+    blank_id: int = 0,
+    lm_scale: float = 0,
+    LODR_lm_scale: float = 0,
+    context_graph: Optional[ContextGraph] = None,
+) -> HypothesisList:
+    """The worker to decode one step.
+
+    Args:
+      log_probs:
+        topk log_probs of the current step (i.e. the kept tokens of first pass
+        pruning), the shape is (beam,)
+      indexes:
+        The indexes of the topk log_probs above, the shape is (beam,)
+      B:
+        An instance of HypothesisList containing the kept hypotheses.
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      lm_scale:
+        The scale of the neural network LM.
+      LODR_lm_scale:
+        The scale of the LODR_lm.
+      context_graph:
+        A ContextGraph instance containing contextual phrases.
+
+    Return:
+      Returns the updated HypothesisList.
+    """
     A = list(B)
     B = HypothesisList()
     for h in range(len(A)):
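The loop body is not shown in this hunk. The standard pruned CTC prefix beam search update it implements can be sketched over plain dicts, independent of the repo's `Hypothesis` class (all names here are illustrative):

```python
import math

def _logaddexp(a: float, b: float) -> float:
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

def toy_step(prefixes, log_probs, indexes, blank_id=0, beam=4):
    """One pruned CTC prefix beam search step.

    prefixes maps a token tuple to (log_p_blank, log_p_non_blank);
    log_probs/indexes are this frame's top-k log probs and token ids.
    """
    new = {}

    def acc(prefix, p_b, p_nb):
        old_b, old_nb = new.get(prefix, (-math.inf, -math.inf))
        new[prefix] = (_logaddexp(old_b, p_b), _logaddexp(old_nb, p_nb))

    for prefix, (p_b, p_nb) in prefixes.items():
        for lp, token in zip(log_probs, indexes):
            if token == blank_id:
                # Blank: prefix unchanged; the alignment now ends in blank.
                acc(prefix, _logaddexp(p_b, p_nb) + lp, -math.inf)
            elif prefix and token == prefix[-1]:
                # Repeated token: same prefix if no blank in between; a new
                # prefix only via alignments that ended in blank.
                acc(prefix, -math.inf, p_nb + lp)
                acc(prefix + (token,), -math.inf, p_b + lp)
            else:
                acc(prefix + (token,), -math.inf, _logaddexp(p_b, p_nb) + lp)

    best = sorted(new.items(), key=lambda kv: _logaddexp(*kv[1]), reverse=True)
    return dict(best[:beam])

# Usage: start from the empty prefix with p_b = log(1), p_nb = log(0).
print(toy_step({(): (0.0, -math.inf)}, [-0.1, -2.3], [0, 7]))
```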
@@ -1812,7 +1842,34 @@ def _step_worker(
     return B


-def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
+def _sequence_worker(
+    topk_values: torch.Tensor,
+    topk_indexes: torch.Tensor,
+    B: HypothesisList,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 4,
+    blank_id: int = 0,
+) -> HypothesisList:
+    """The worker to decode one sequence.
+
+    Args:
+      topk_values:
+        topk log_probs of the model output (i.e. the kept tokens of first pass
+        pruning), the shape is (T, beam)
+      topk_indexes:
+        The indexes of the topk_values above, the shape is (T, beam)
+      B:
+        An instance of HypothesisList containing the kept hypotheses.
+      encoder_out_lens:
+        The length (frames) of this sequence after subsampling.
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+
+    Return:
+      Returns the updated HypothesisList.
+    """
     B.add(Hypothesis())
     for j in range(encoder_out_lens):
         log_probs, indexes = topk_values[j], topk_indexes[j]
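The loop presumably feeds each frame's pruned candidates to `_step_worker`; a sketch of the truncated body (the LM and context arguments of `_step_worker` are left at their defaults here):

```python
# Sketch of the remainder of _sequence_worker: one prefix-expansion
# step per frame, starting from the empty prefix.
B.add(Hypothesis())
for j in range(encoder_out_lens):
    log_probs, indexes = topk_values[j], topk_indexes[j]
    B = _step_worker(log_probs, indexes, B, beam=beam, blank_id=blank_id)
return B
```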
@@ -1828,6 +1885,24 @@ def ctc_prefix_beam_search(
     process_pool: Optional[Pool] = None,
-):
+    return_nbest: Optional[bool] = False,
+) -> Union[List[List[int]], List[HypothesisList]]:
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks".
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability), the shape is (B, T, V)
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling, the shape is (B,)
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      process_pool:
+        The process pool for parallel decoding; if not provided, all your CPU
+        cores are used by default.
+      return_nbest:
+        If True, return a list of HypothesisList; otherwise return a list of
+        lists of decoded token ids.
+    """
     batch_size, num_frames, vocab_size = ctc_output.shape

     # TODO: using a larger beam for first pass pruning
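The first-pass pruning the TODO refers to keeps only the `beam` most likely tokens per frame before the expensive prefix expansion; an illustration with hypothetical sizes:

```python
import torch

batch, frames, vocab = 2, 5, 500  # hypothetical sizes
ctc_output = torch.randn(batch, frames, vocab).log_softmax(dim=-1)
beam = 4

# Keep only the top-`beam` tokens per frame; _sequence_worker then only
# ever considers these candidates instead of all `vocab` tokens.
topk_values, topk_indexes = ctc_output.topk(beam)  # both (batch, frames, beam)
```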
@@ -1850,7 +1925,7 @@ def ctc_prefix_beam_search(
                     blank_id,
                 )
             )
-        async_results = pool.starmap_async(_batch_worker, arguments)
+        async_results = pool.starmap_async(_sequence_worker, arguments)
         B = list(async_results.get())
         if process_pool is None:
             pool.close()
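The dispatch pattern above builds one argument tuple per utterance and decodes them in parallel; a self-contained sketch (the tuple layout is assumed from `_sequence_worker`'s signature, not copied from the committed code):

```python
from multiprocessing import Pool

def decode_batch(topk_values, topk_indexes, encoder_out_lens, beam, blank_id):
    # One argument tuple per utterance; results come back in submission order.
    arguments = [
        (topk_values[i], topk_indexes[i], HypothesisList(),
         encoder_out_lens[i], beam, blank_id)
        for i in range(len(encoder_out_lens))
    ]
    with Pool() as pool:  # Pool() defaults to os.cpu_count() workers
        return list(pool.starmap_async(_sequence_worker, arguments).get())
```

Note the design choice in the diff: the pool is closed only when the function created it (`if process_pool is None`), so a caller-supplied pool can be reused across batches.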
@@ -1872,6 +1947,32 @@ def ctc_prefix_beam_search_shallow_fussion(
     LM: Optional[LmScorer] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    neural language model shallow fusion; it also supports contextual
+    biasing with a given grammar.
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability), the shape is (B, T, V)
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling, the shape is (B,)
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      LODR_lm:
+        A low order n-gram LM, whose score will be subtracted during shallow fusion.
+      LODR_lm_scale:
+        The scale of the LODR_lm.
+      LM:
+        A neural net LM, e.g. an RNNLM or transformer LM.
+      context_graph:
+        A ContextGraph instance containing contextual phrases.
+
+    Return:
+      Returns a list of lists of decoded token ids.
+    """
     batch_size, num_frames, vocab_size = ctc_output.shape
     # TODO: using a larger beam for first pass pruning
     topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
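In shallow fusion, when a hypothesis is extended with a token, the neural LM's log prob for that token is added with `lm_scale`, while the low-order n-gram (LODR) score change is applied with a typically negative `LODR_lm_scale`, i.e. subtracted, to offset the internal LM implicitly learned by the CTC model. A sketch of that per-extension bookkeeping (`hyp`, `token`, and the `NgramLmStateCost` methods are assumed from the fields documented above, not copied from the committed code):

```python
# Sketch of the score update when extending `hyp` with `token`.
lm_score = hyp.lm_score.clone()

if LM is not None and hyp.lm_log_probs is not None:
    # Neural LM shallow fusion: add the scaled LM log prob of `token`.
    lm_score += lm_scale * hyp.lm_log_probs[token]

if LODR_lm is not None:
    # LODR correction: apply the n-gram score delta with LODR_lm_scale
    # (assumed NgramLmStateCost API: forward_one_step / lm_score).
    new_state = hyp.LODR_state.forward_one_step(token)
    lm_score += LODR_lm_scale * (new_state.lm_score - hyp.LODR_state.lm_score)
```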
@@ -1926,14 +2027,14 @@ def ctc_prefix_beam_search_shallow_fussion(
             )
         if LM is None:
             continue
-        # update lm_score
+        # update lm_log_probs
         token_list = []  # a list of list
         hs = []
         cs = []
         indexes = []  # (batch_idx, key)
         for batch_idx, hyps in enumerate(B):
             for hyp in hyps:
-                if hyp.lm_log_probs is None:
+                if hyp.lm_log_probs is None:  # those hyps whose prefix changed
                     if LM.lm_type == "rnn":
                         token_list.append([hyp.ys[-1]])
                         # store the LSTM states
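The bookkeeping above collects, across the whole batch, every hypothesis whose prefix changed in this step, so that a single batched LM forward pass can refresh their cached `lm_log_probs`. Schematically (a sketch; `LM.score_token`'s signature is assumed from icefall's `LmScorer` usage elsewhere, and the scatter step is only described):

```python
import torch

# Batch all pending RNNLM queries into one forward pass.
x = torch.tensor(token_list, dtype=torch.int64)        # last tokens, one row per hyp
x_lens = torch.tensor([len(t) for t in token_list], dtype=torch.int64)
state = (torch.cat(hs, dim=1), torch.cat(cs, dim=1))   # stacked LSTM (h, c) states

lm_log_probs, new_state = LM.score_token(x, x_lens, state)

# Each vocabulary-sized row of lm_log_probs is then scattered back to its
# owning hypothesis via the (batch_idx, key) pairs recorded in `indexes`,
# refilling hyp.lm_log_probs and hyp.state for the next step.
```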
@@ -2000,7 +2101,32 @@ def ctc_prefix_beam_search_attention_decoder_rescoring(
     beam: int = 8,
     blank_id: int = 0,
     attention_scale: Optional[float] = None,
     process_pool: Optional[Pool] = None,
 ):
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    attention decoder rescoring.
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability), the shape is (B, T, V)
+      attention_decoder:
+        The attention decoder.
+      encoder_out:
+        The output of the encoder, the shape is (B, T, D)
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling, the shape is (B,)
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      attention_scale:
+        The scale of the attention decoder score; if not provided, it will
+        search in a default list (see the code below).
+      process_pool:
+        The process pool for parallel decoding; if not provided, all your CPU
+        cores are used by default.
+    """
     # List[HypothesisList]
     nbest = ctc_prefix_beam_search(
         ctc_output=ctc_output,
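Attention-decoder rescoring takes the n-best lists returned by `ctc_prefix_beam_search` and re-ranks each hypothesis by a weighted sum of its CTC score and the attention decoder's log-likelihood; when `attention_scale` is not fixed, a small grid is searched. A sketch of that re-ranking (`attn_score` is a hypothetical per-hypothesis attention decoder log-likelihood, and the grid is illustrative only; the committed code batches the decoder calls):

```python
attention_scale_list = (
    [attention_scale]
    if attention_scale is not None
    else [0.1, 0.3, 0.5, 0.8, 1.0, 1.5, 2.0]  # illustrative search grid
)
results = {}
for scale in attention_scale_list:
    results[scale] = [
        # Re-rank each utterance's n-best list by the combined score.
        max(hyps, key=lambda hyp, s=scale: hyp.tot_score + s * hyp.attn_score).ys
        for hyps in nbest  # one HypothesisList per utterance
    ]
```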