Repository: https://github.com/k2-fsa/icefall.git
Commit: 3a40c073e7 (parent: 33fa9e8b00)

    Add documents to ctc prefix beam search
@@ -1513,10 +1513,12 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int] = field(default_factory=list)
 
-    # The log prob of ys.
+    # The log prob of ys that ends with blank token.
     # It contains only one entry.
     log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)
 
+    # The log prob of ys that ends with non blank token.
+    # It contains only one entry.
     log_prob_non_blank: torch.Tensor = torch.tensor(
         [float("-inf")], dtype=torch.float32
     )
@@ -1526,16 +1528,18 @@ class Hypothesis:
     timestamp: List[int] = field(default_factory=list)
 
     # The lm score of ys
+    # May contain external LM score (including LODR score) and contextual biasing score
     # It contains only one entry
     lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)
 
     # the lm log_probs for next token given the history ys
+    # The number of elements should be equal to vocabulary size.
     lm_log_probs: Optional[torch.Tensor] = None
 
     # the RNNLM states (h and c in LSTM)
     state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 
-    # N-gram LM state
+    # LODR (N-gram LM) state
     LODR_state: Optional[NgramLmStateCost] = None
 
     # N-gram LM state
@@ -1544,10 +1548,12 @@ class Hypothesis:
     # Context graph state
     context_state: Optional[ContextState] = None
 
+    # This is the total score of the current path, acoustic plus external LM score.
     @property
     def tot_score(self) -> torch.Tensor:
        return self.log_prob + self.lm_score
 
+    # This is only the probability from model output (i.e. external LM score not included).
     @property
     def log_prob(self) -> torch.Tensor:
         return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
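For intuition, here is a minimal sketch (not part of the commit) of what the `log_prob` property computes: a prefix that is reachable both with a final blank and with a final non-blank token sums the two probabilities in log space. The numbers below are made up.

```python
import torch

# Suppose the same prefix was reached two ways:
# ending in blank with prob 0.4, ending in a non-blank with prob 0.2.
log_prob_blank = torch.tensor([0.4]).log()
log_prob_non_blank = torch.tensor([0.2]).log()

# Total prefix probability: logaddexp(log 0.2, log 0.4) = log 0.6,
# matching the `log_prob` property above.
log_prob = torch.logaddexp(log_prob_non_blank, log_prob_blank)
print(log_prob.exp())  # tensor([0.6000])
```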
@@ -1614,14 +1620,14 @@ class HypothesisList(object):
 
     def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
         """Get the most probable hypothesis, i.e., the one with
-        the largest `log_prob`.
+        the largest `tot_score`.
 
         Args:
           length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
+            If True, the `tot_score` of a hypothesis is normalized by the
             number of tokens in it.
         Returns:
-          Return the hypothesis that has the largest `log_prob`.
+          Return the hypothesis that has the largest `tot_score`.
         """
         if length_norm:
             return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
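A small illustration of what `length_norm` changes when selecting the best hypothesis. The scores and token ids below are hypothetical: without normalization the shorter hypothesis wins on raw `tot_score`; with normalization the per-token score decides.

```python
# Hypothetical raw total scores (tot_score) and token id sequences (ys).
scores = {"a b c d": -4.0, "a b": -3.8}
ys = {"a b c d": [5, 6, 7, 8], "a b": [5, 6]}

best_raw = max(scores, key=lambda k: scores[k])                # "a b" (-3.8)
best_norm = max(scores, key=lambda k: scores[k] / len(ys[k]))  # "a b c d" (-1.0 per token)
```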
@@ -1645,14 +1651,14 @@ class HypothesisList(object):
         del self._data[key]
 
     def filter(self, threshold: torch.Tensor) -> "HypothesisList":
-        """Remove all Hypotheses whose log_prob is less than threshold.
+        """Remove all Hypotheses whose tot_score is less than threshold.
 
         Caution:
           `self` is not modified. Instead, a new HypothesisList is returned.
 
         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
-          with `log_prob` being greater than the given `threshold`.
+          with `tot_score` being greater than the given `threshold`.
         """
         ans = HypothesisList()
         for _, hyp in self._data.items():
@@ -1665,7 +1671,7 @@ class HypothesisList(object):
 
         Args:
           length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
+            If True, the `tot_score` of a hypothesis is normalized by the
             number of tokens in it.
         """
         hyps = list(self._data.items())
@@ -1725,15 +1731,39 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
 
 
 def _step_worker(
-    log_probs,
-    indexes,
-    B,
-    beam,
-    blank_id,
+    log_probs: torch.Tensor,
+    indexes: torch.Tensor,
+    B: HypothesisList,
+    beam: int = 4,
+    blank_id: int = 0,
     lm_scale: float = 0,
     LODR_lm_scale: float = 0,
     context_graph: Optional[ContextGraph] = None,
-):
+) -> HypothesisList:
+    """The worker to decode one step.
+
+    Args:
+      log_probs:
+        The topk log_probs of the current step (i.e. the kept tokens of the
+        first pass pruning); the shape is (beam,).
+      indexes:
+        The indexes of the topk log_probs above; the shape is (beam,).
+      B:
+        An instance of HypothesisList containing the kept hypotheses.
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      lm_scale:
+        The scale of the neural net LM.
+      LODR_lm_scale:
+        The scale of the LODR LM.
+      context_graph:
+        A ContextGraph instance containing contextual phrases.
+
+    Returns:
+      Returns the updated HypothesisList.
+    """
     A = list(B)
     B = HypothesisList()
     for h in range(len(A)):
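The body of `_step_worker` is not shown in this hunk, so the following is a self-contained sketch of the standard one-step CTC prefix beam search recursion it implements, with the LM, LODR, and context-graph terms omitted and plain dicts standing in for `HypothesisList`. Names and structure here are illustrative, not the commit's code.

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")

def log_add(a: float, b: float) -> float:
    # log(exp(a) + exp(b)), computed stably.
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))

def step(prefixes, frame_log_probs, blank_id=0):
    # `prefixes` maps a token tuple to (log_prob_blank, log_prob_non_blank);
    # `frame_log_probs` maps token id -> log prob at this frame.
    new = defaultdict(lambda: (NEG_INF, NEG_INF))
    for ys, (p_b, p_nb) in prefixes.items():
        p_total = log_add(p_b, p_nb)
        for tok, lp in frame_log_probs.items():
            if tok == blank_id:
                # A blank keeps the prefix; its mass now "ends in blank".
                b, nb = new[ys]
                new[ys] = (log_add(b, p_total + lp), nb)
            elif ys and tok == ys[-1]:
                # Repeat of the last token: it merges into ys unless a
                # blank separated the two emissions.
                b, nb = new[ys]
                new[ys] = (b, log_add(nb, p_nb + lp))
                b2, nb2 = new[ys + (tok,)]
                new[ys + (tok,)] = (b2, log_add(nb2, p_b + lp))
            else:
                b2, nb2 = new[ys + (tok,)]
                new[ys + (tok,)] = (b2, log_add(nb2, p_total + lp))
    return dict(new)  # the real worker also keeps only the best `beam` entries
```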
@@ -1812,7 +1842,34 @@ def _step_worker(
     return B
 
 
-def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
+def _sequence_worker(
+    topk_values: torch.Tensor,
+    topk_indexes: torch.Tensor,
+    B: HypothesisList,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 4,
+    blank_id: int = 0,
+) -> HypothesisList:
+    """The worker to decode one sequence.
+
+    Args:
+      topk_values:
+        The topk log_probs of the model output (i.e. the kept tokens of the
+        first pass pruning); the shape is (T, beam).
+      topk_indexes:
+        The indexes of the topk_values above; the shape is (T, beam).
+      B:
+        An instance of HypothesisList containing the kept hypotheses.
+      encoder_out_lens:
+        The length (number of frames) of this sequence after subsampling.
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+
+    Returns:
+      Returns the updated HypothesisList.
+    """
     B.add(Hypothesis())
     for j in range(encoder_out_lens):
         log_probs, indexes = topk_values[j], topk_indexes[j]
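In spirit, `_sequence_worker` folds the step worker over the frames of one sequence, starting from an empty hypothesis, just as `B.add(Hypothesis())` and the loop above suggest. A hedged sketch reusing the helpers from the previous block (the real code also prunes to the top `beam` prefixes each step):

```python
def decode_sequence(topk_values, topk_indexes, num_frames, blank_id=0):
    # Start from the empty prefix with log prob 1.0 for "ends in blank",
    # mirroring B.add(Hypothesis()) above.
    prefixes = {(): (0.0, NEG_INF)}
    for j in range(num_frames):
        frame = {int(t): float(v) for v, t in zip(topk_values[j], topk_indexes[j])}
        prefixes = step(prefixes, frame, blank_id=blank_id)
        # the real worker prunes to the best `beam` prefixes here
    return prefixes
```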
@@ -1828,6 +1885,24 @@ def ctc_prefix_beam_search(
     process_pool: Optional[Pool] = None,
     return_nbest: Optional[bool] = False,
 ) -> Union[List[List[int]], List[HypothesisList]]:
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks".
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability); the shape is (B, T, V).
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling; the shape is (B,).
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      process_pool:
+        The process pool for parallel decoding; if not provided, it will use
+        all your CPU cores by default.
+      return_nbest:
+        If True, return a list of HypothesisList; otherwise return a list of lists of decoded token ids.
+    """
     batch_size, num_frames, vocab_size = ctc_output.shape
 
     # TODO: using a larger beam for first pass pruning
@@ -1850,7 +1925,7 @@ def ctc_prefix_beam_search(
                 blank_id,
             )
         )
-    async_results = pool.starmap_async(_batch_worker, arguments)
+    async_results = pool.starmap_async(_sequence_worker, arguments)
     B = list(async_results.get())
     if process_pool is None:
         pool.close()
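Given the signature documented above, a call might look like the following sketch. The tensors are toy CPU values, the pool size is arbitrary, and on platforms that spawn worker processes the snippet belongs under an `if __name__ == "__main__":` guard.

```python
from multiprocessing import Pool

import torch

# Toy inputs: batch of 2 utterances, 50 frames, vocab of 500.
ctc_output = torch.randn(2, 50, 500).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([50, 38])

with Pool(processes=4) as pool:
    hyps = ctc_prefix_beam_search(
        ctc_output=ctc_output,
        encoder_out_lens=encoder_out_lens,
        beam=4,
        blank_id=0,
        process_pool=pool,
    )
# hyps is a list (one entry per utterance) of decoded token-id lists.
```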
@@ -1872,6 +1947,32 @@ def ctc_prefix_beam_search_shallow_fussion(
     LM: Optional[LmScorer] = None,
     context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    neural language model shallow fusion; it also supports contextual
+    biasing with a given grammar.
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability); the shape is (B, T, V).
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling; the shape is (B,).
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      LODR_lm:
+        A low order n-gram LM, whose score will be subtracted during shallow fusion.
+      LODR_lm_scale:
+        The scale of the LODR LM.
+      LM:
+        A neural net LM, e.g. an RNNLM or transformer LM.
+      context_graph:
+        A ContextGraph instance containing contextual phrases.
+
+    Returns:
+      Returns a list of lists of decoded token ids.
+    """
     batch_size, num_frames, vocab_size = ctc_output.shape
     # TODO: using a larger beam for first pass pruning
     topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
@@ -1926,14 +2027,14 @@ def ctc_prefix_beam_search_shallow_fussion(
             )
         if LM is None:
             continue
-        # update lm_score
+        # update lm_log_probs
         token_list = []  # a list of list
         hs = []
         cs = []
         indexes = []  # (batch_idx, key)
         for batch_idx, hyps in enumerate(B):
             for hyp in hyps:
-                if hyp.lm_log_probs is None:
+                if hyp.lm_log_probs is None:  # hyps whose prefix changed
                     if LM.lm_type == "rnn":
                         token_list.append([hyp.ys[-1]])
                         # store the LSTM states
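How the cached `lm_log_probs` feed into scoring is not visible in this hunk; below is a deliberately hedged sketch of the per-extension bookkeeping implied by the `Hypothesis` fields. The `.lm_score` attribute on `NgramLmStateCost` and the helper name are assumptions; the real code batches the NN LM calls as shown above and subtracts the LODR score per its docstring.

```python
def fused_lm_score(hyp, token, lm_scale, LODR_lm_scale):
    # NN LM log prob of extending `hyp` with `token`
    # (hyp.lm_log_probs is a vocab-sized vector).
    nn_lm = hyp.lm_log_probs[token]
    # n-gram score of the same extension; attribute name is illustrative.
    lodr = hyp.LODR_state.lm_score
    # LODR shallow fusion: add the scaled NN LM score,
    # subtract the scaled low-order n-gram score.
    return hyp.lm_score + lm_scale * nn_lm - LODR_lm_scale * lodr
```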
@@ -2000,7 +2101,32 @@ def ctc_prefix_beam_search_attention_decoder_rescoring(
     beam: int = 8,
     blank_id: int = 0,
     attention_scale: Optional[float] = None,
+    process_pool: Optional[Pool] = None,
 ):
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    attention decoder rescoring.
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability); the shape is (B, T, V).
+      attention_decoder:
+        The attention decoder.
+      encoder_out:
+        The output of the encoder; the shape is (B, T, D).
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling; the shape is (B,).
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      attention_scale:
+        The scale of the attention decoder score; if not provided it will
+        search in a default list (see the code below).
+      process_pool:
+        The process pool for parallel decoding; if not provided, it will use
+        all your CPU cores by default.
+    """
     # List[HypothesisList]
     nbest = ctc_prefix_beam_search(
         ctc_output=ctc_output,
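The rescoring itself happens after this hunk. As a rough sketch of the behavior the docstring describes (searching over a default grid when `attention_scale` is None), with the grid values and score shapes being illustrative assumptions:

```python
import torch

def rescore(nbest_scores: torch.Tensor, attn_scores: torch.Tensor,
            attention_scale=None):
    # nbest_scores: CTC tot_score per n-best hypothesis, shape (num_hyps,).
    # attn_scores: attention decoder score per hypothesis, shape (num_hyps,).
    scales = ([attention_scale] if attention_scale is not None
              else [0.1, 0.3, 0.5, 0.7, 0.9, 1.1, 1.5, 2.0])
    best = {}
    for scale in scales:
        tot = nbest_scores + scale * attn_scores
        best[scale] = int(tot.argmax())
    return best  # best hypothesis index per candidate scale
```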