add ctc prefix beam search

pkufool 2024-09-26 15:22:29 +08:00
parent 83c36ecc18
commit 0c096a9ab4
2 changed files with 409 additions and 11 deletions

View File

@@ -124,7 +124,7 @@ from asr_datamodule import GigaSpeechAsrDataModule
from gigaspeech_scoring import asr_text_post_processing
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params
from train_cr_aed import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@@ -134,6 +134,7 @@ from icefall.checkpoint import (
)
from icefall.decode import (
ctc_greedy_search,
ctc_prefix_beam_search,
get_lattice,
nbest_decoding,
nbest_oracle,
@@ -327,6 +328,17 @@ def get_decoding_params() -> AttributeDict:
return params
def post_processing(
results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
new_results = []
for key, ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((key, new_ref, new_hyp))
return new_results
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@@ -380,10 +392,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
device = params.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
@@ -414,6 +423,18 @@
key = "ctc-greedy-search"
return {key: hyps}
if params.decoding_method == "prefix-beam-search":
token_ids = ctc_prefix_beam_search(
ctc_output=ctc_output, encoder_out_lens=encoder_out_lens, beam=8
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search"
return {key: hyps}
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
@@ -738,6 +759,7 @@ def main():
assert params.decoding_method in (
"ctc-greedy-search",
"prefix-beam-search",
"ctc-decoding",
"1best",
"nbest",
@@ -773,6 +795,7 @@
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}")
logging.info(params)
@@ -790,14 +813,20 @@
if params.decoding_method in [
"ctc-greedy-search",
"ctc-decoding",
"prefix-beam-search",
"attention-decoder-rescoring-no-ngram",
]:
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
H = None
if params.decoding_method in [
"ctc-decoding",
"attention-decoder-rescoring-no-ngram",
]:
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:

View File

@@ -15,11 +15,18 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import k2
import torch
from multiprocessing.pool import Pool
from icefall.context_graph import ContextGraph, ContextState
from icefall.ngram_lm import NgramLm, NgramLmStateCost
from icefall.lm_wrapper import LmScorer
from icefall.utils import add_eos, add_sos, get_texts
DEFAULT_LM_SCALE = [
@@ -1497,3 +1504,365 @@ def ctc_greedy_search(
hyps = [h[h != blank_id].tolist() for h in hyps]
return hyps
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log probs of ys, kept separately for the prefix ending in a blank
# and ending in a non-blank token. Each tensor contains only one entry.
log_prob_blank: torch.Tensor
log_prob_non_blank: torch.Tensor
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None
# the RNNLM states (h and c in LSTM)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None
# Context graph state
context_state: Optional[ContextState] = None
@property
def log_prob(self) -> torch.Tensor:
return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
@property
def key(self) -> tuple:
"""Return a tuple representation of self.ys"""
return tuple(self.ys)
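# A small illustration (not part of this commit): the total score of a prefix
# is the log-sum-exp of its blank-ending and non-blank-ending scores, and the
# key used to merge duplicate prefixes is the tuple of its token ids. The
# helper name below is made up for demonstration only.
def _demo_hypothesis() -> None:
    hyp = Hypothesis(
        ys=[3, 7],
        log_prob_blank=torch.tensor([-2.0]),
        log_prob_non_blank=torch.tensor([-1.0]),
    )
    # logaddexp(-2.0, -1.0) ≈ -0.6867
    assert torch.allclose(hyp.log_prob, torch.tensor([-0.6867]), atol=1e-3)
    assert hyp.key == (3, 7)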
class HypothesisList(object):
def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[tuple, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(
old_hyp.log_prob_blank, hyp.log_prob_blank, out=old_hyp.log_prob_blank
)
torch.logaddexp(
old_hyp.log_prob_non_blank,
hyp.log_prob_non_blank,
out=old_hyp.log_prob_non_blank,
)
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
"""Return the top-k hypothesis.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
"""
hyps = list(self._data.items())
if length_norm:
hyps = sorted(
hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
)[:k]
else:
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: tuple) -> bool:
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for hyp in self:
s.append(str(hyp.key))
return ", ".join(s)
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans
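# A small illustration (not part of this commit) of the ragged shape produced
# by get_hyps_shape: two utterances holding 2 and 3 hypotheses respectively.
# The helper names below are made up for demonstration only.
def _demo_get_hyps_shape() -> None:
    def make_hyps(prefixes: List[List[int]]) -> HypothesisList:
        hyps = HypothesisList()
        for ys in prefixes:
            hyps.add(
                Hypothesis(
                    ys=ys,
                    log_prob_non_blank=torch.zeros(1),
                    log_prob_blank=torch.tensor([float("-inf")]),
                )
            )
        return hyps

    shape = get_hyps_shape([make_hyps([[1], [2]]), make_hyps([[1], [2], [3]])])
    # row_splits gives the exclusive-sum boundaries, row_ids maps each
    # hypothesis back to its utterance.
    assert shape.row_splits(1).tolist() == [0, 2, 5]
    assert shape.row_ids(1).tolist() == [0, 0, 1, 1, 1]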
def _step_worker(log_probs, indexes, B, beam, blank_id):
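"""Process one frame of one utterance.
Args:
log_probs:
The top-k log-probs of the current frame, of shape (beam,).
indexes:
The token ids corresponding to `log_probs`.
B:
The hypotheses surviving the previous frame.
beam:
The number of hypotheses to keep after pruning.
blank_id:
The id of the blank symbol.
Returns:
Return the pruned HypothesisList after consuming this frame.
"""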
A = list(B)
B = HypothesisList()
for h in range(len(A)):
hyp = A[h]
for k in range(log_probs.size(0)):
log_prob, index = log_probs[k], indexes[k]
if index == blank_id:
# Case 0: *a + ε => *a
# *aε + ε => *a
# Prefix does not change, update log_prob of blank
new_hyp = Hypothesis(
ys=hyp.ys[:],
log_prob_non_blank=torch.tensor(
[float("-inf")], dtype=torch.float32
),
log_prob_blank=hyp.log_prob + log_prob,
)
B.add(new_hyp)
elif len(hyp.ys) > 0 and hyp.ys[-1] == index:
# Case 1: *a + a => *a
# Prefix does not change, update log_prob of non_blank
new_hyp = Hypothesis(
ys=hyp.ys[:],
log_prob_non_blank=hyp.log_prob_non_blank + log_prob,
log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
)
B.add(new_hyp)
# Case 2: *aε + a => *aa
# Prefix changes; the new non_blank prob comes from the old blank-ending prob
new_hyp = Hypothesis(
ys=hyp.ys[:] + [index.item()],
log_prob_non_blank=hyp.log_prob_blank + log_prob,
log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
)
B.add(new_hyp)
else:
# Case 3: *a + b => *ab, *aε + b => *ab
# Prefix changes, update log_prob of non_blank
new_hyp = Hypothesis(
ys=hyp.ys[:] + [index.item()],
log_prob_non_blank=hyp.log_prob + log_prob,
log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
)
B.add(new_hyp)
B = B.topk(beam)
return B
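# A small illustration (not part of this commit) of a single-frame update with
# _step_worker, assuming blank_id=0; the helper name and toy scores below are
# made up for demonstration only.
def _demo_step_worker() -> None:
    B = HypothesisList()
    B.add(
        Hypothesis(
            ys=[],
            log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
            log_prob_blank=torch.zeros(1, dtype=torch.float32),
        )
    )
    # Top-2 log-probs for one frame: token 1 with -0.1 and blank (id 0) with -2.3.
    log_probs = torch.tensor([-0.1, -2.3])
    indexes = torch.tensor([1, 0])
    B = _step_worker(log_probs, indexes, B, beam=2, blank_id=0)
    # The empty prefix survives via case 0 and the prefix (1,) is created via
    # case 3, each keeping separate blank / non-blank ending scores.
    assert sorted(hyp.key for hyp in B) == [(), (1,)]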
def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
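"""Run prefix beam search over all frames of one utterance.
Args:
topk_values:
The top-k log-probs of each frame, of shape (num_frames, beam).
topk_indexes:
The token ids corresponding to `topk_values`.
B:
An empty HypothesisList to start from.
encoder_out_lens:
The number of valid frames in this utterance.
beam:
The number of hypotheses kept after each frame.
blank_id:
The id of the blank symbol.
"""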
B.add(
Hypothesis(
ys=[],
log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
log_prob_blank=torch.zeros(1, dtype=torch.float32),
)
)
for j in range(encoder_out_lens):
log_probs, indexes = topk_values[j], topk_indexes[j]
B = _step_worker(log_probs, indexes, B, beam, blank_id)
return B
def ctc_prefix_beam_search(
ctc_output: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
blank_id: int = 0,
context_graph: Optional[ContextGraph] = None,
process_pool: Optional[Pool] = None,
return_nbest: Optional[bool] = False,
) -> Union[List[List[int]], List[HypothesisList]]:
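"""CTC prefix beam search, run frame-synchronously and per utterance.
Args:
ctc_output:
The CTC log-probs of shape (batch_size, num_frames, vocab_size).
encoder_out_lens:
A 1-D tensor containing the number of valid frames per utterance.
beam:
The number of hypotheses kept after each frame.
blank_id:
The id of the blank symbol.
context_graph:
Reserved for contextual biasing; not used in this function body.
process_pool:
An optional multiprocessing pool; a temporary one is created (and
closed) if None is given.
return_nbest:
If True, return the full HypothesisList of each utterance instead of
only the best token sequence.
Returns:
Return a list of token-id lists (one per utterance), or a list of
HypothesisList if `return_nbest` is True.
"""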
batch_size, num_frames, vocab_size = ctc_output.shape
# TODO: use a larger beam for first-pass pruning
topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam)
topk_values = topk_values.cpu()
topk_indexes = topk_indexes.cpu()
B = [HypothesisList() for _ in range(batch_size)]
pool = Pool() if process_pool is None else process_pool
arguments = []
for i in range(batch_size):
arguments.append(
(
topk_values[i],
topk_indexes[i],
B[i],
encoder_out_lens[i].item(),
beam,
blank_id,
)
)
async_results = pool.starmap_async(_batch_worker, arguments)
B = list(async_results.get())
if process_pool is None:
pool.close()
pool.join()
if return_nbest:
return B
else:
best_hyps = [b.get_most_probable() for b in B]
return [hyp.ys for hyp in best_hyps]
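# A minimal usage sketch (not part of this commit). It assumes `ctc_output`
# holds log-probabilities (e.g. the log-softmax output of a CTC head) of shape
# (batch, num_frames, vocab_size). Because a multiprocessing pool is spawned,
# call this from under `if __name__ == "__main__":` on spawn-based platforms.
# The helper name and toy tensors below are made up for demonstration only.
def _demo_ctc_prefix_beam_search() -> None:
    batch, num_frames, vocab_size = 2, 10, 5
    toy_ctc_output = torch.randn(batch, num_frames, vocab_size).log_softmax(dim=-1)
    toy_lens = torch.tensor([10, 7])
    hyps = ctc_prefix_beam_search(
        ctc_output=toy_ctc_output,
        encoder_out_lens=toy_lens,
        beam=4,
    )
    # hyps is a list of token-id lists, one per utterance, e.g. [[3, 1], [2]].
    assert len(hyps) == batch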
def ctc_prefix_beam_search_attention_decoder_rescoring(
ctc_output: torch.Tensor,
attention_decoder: torch.nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 8,
blank_id: int = 0,
attention_scale: Optional[float] = None,
):
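"""Rescore the n-best output of CTC prefix beam search with an attention decoder.
The CTC score of each hypothesis is combined with the attention-decoder
log-likelihood under a list of scales; for every scale the best path of each
utterance is returned in a dict keyed by "attention_scale_<scale>".
"""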
# List[HypothesisList]
nbest = ctc_prefix_beam_search(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
beam=beam,
blank_id=blank_id,
return_nbest=True,
)
device = ctc_output.device
hyp_shape = get_hyps_shape(nbest).to(device)
hyp_to_utt_map = hyp_shape.row_ids(1).to(torch.long)
# the shape of encoder_out is (N, T, C), so we use axis=0 here
expanded_encoder_out = encoder_out.index_select(0, hyp_to_utt_map)
expanded_encoder_out_lens = encoder_out_lens.index_select(0, hyp_to_utt_map)
nbest = [list(x) for x in nbest]
token_ids = []
scores = []
for hyps in nbest:
for hyp in hyps:
token_ids.append(hyp.ys)
scores.append(hyp.log_prob.reshape(1))
scores = torch.cat(scores).to(device)
nll = attention_decoder.nll(
encoder_out=expanded_encoder_out,
encoder_out_lens=expanded_encoder_out_lens,
token_ids=token_ids,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
attention_scale_list += [6.0, 7.0, 8.0, 9.0]
else:
attention_scale_list = [attention_scale]
ans = dict()
start_indexes = hyp_shape.row_splits(1)[0:-1]
for a_scale in attention_scale_list:
tot_scores = scores + a_scale * attention_scores
ragged_tot_scores = k2.RaggedTensor(hyp_shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
max_indexes = max_indexes - start_indexes
max_indexes = max_indexes.cpu()
best_path = [nbest[i][max_indexes[i]].ys for i in range(len(max_indexes))]
key = f"attention_scale_{a_scale}"
ans[key] = best_path
return ans
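# A minimal usage sketch (not part of this commit). `_ToyAttentionDecoder` is a
# hypothetical stand-in for a real attention decoder: it only mimics the
# `nll()` interface used above and returns zeros of the expected shape. As with
# the search itself, run this under `if __name__ == "__main__":` on
# spawn-based platforms, since a multiprocessing pool is created internally.
class _ToyAttentionDecoder(torch.nn.Module):
    def nll(self, encoder_out, encoder_out_lens, token_ids):
        # One column more than the longest hypothesis, filled with zeros.
        max_len = max(len(t) for t in token_ids) + 1
        return torch.zeros(len(token_ids), max_len, device=encoder_out.device)


def _demo_attention_decoder_rescoring() -> None:
    batch, num_frames, vocab_size, encoder_dim = 2, 10, 5, 8
    toy_ctc_output = torch.randn(batch, num_frames, vocab_size).log_softmax(dim=-1)
    toy_encoder_out = torch.randn(batch, num_frames, encoder_dim)
    toy_lens = torch.tensor([10, 7])
    ans = ctc_prefix_beam_search_attention_decoder_rescoring(
        ctc_output=toy_ctc_output,
        attention_decoder=_ToyAttentionDecoder(),
        encoder_out=toy_encoder_out,
        encoder_out_lens=toy_lens,
        beam=4,
        attention_scale=0.5,
    )
    # ans maps "attention_scale_0.5" to one best token-id sequence per
    # utterance; with a real decoder one would pick the scale that gives the
    # lowest WER on a dev set.
    assert list(ans.keys()) == ["attention_scale_0.5"]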