Add librispeech prefix-beam-search

This commit is contained in:
pkufool 2024-09-27 19:31:54 +08:00
parent cef16574f7
commit 02e00ff504
3 changed files with 366 additions and 61 deletions

View File

@ -123,6 +123,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params
from icefall.context_graph import ContextGraph, ContextState
from icefall.ngram_lm import NgramLm, NgramLmStateCost
from icefall.lm_wrapper import LmScorer
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -131,6 +135,9 @@ from icefall.checkpoint import (
)
from icefall.decode import (
ctc_greedy_search,
ctc_prefix_beam_search,
ctc_prefix_beam_search_attention_decoder_rescoring,
ctc_prefix_beam_search_shallow_fussion,
get_lattice,
nbest_decoding,
nbest_oracle,
@ -280,6 +287,23 @@ def get_parser():
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--hlg-scale",
type=float,
@ -301,7 +325,7 @@ def get_parser():
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets)."""
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
add_model_arguments(parser)
@ -314,8 +338,9 @@ def get_decoding_params() -> AttributeDict:
params = AttributeDict(
{
"frame_shift_ms": 10,
"search_beam": 20,
"output_beam": 8,
"search_beam": 20, # for k2 fsa composition
"output_beam": 8, # for k2 fsa composition
"beam": 4, # for prefix-beam-search
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
@ -333,6 +358,7 @@ def decode_one_batch(
batch: dict,
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -377,10 +403,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
device = params.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
@ -411,6 +434,48 @@ def decode_one_batch(
key = "ctc-greedy-search"
return {key: hyps}
if params.decoding_method == "prefix-beam-search":
token_ids = ctc_prefix_beam_search(
ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search"
return {key: hyps}
if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
ctc_output=ctc_output,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
ans = dict()
for a_scale_str, token_ids in best_path_dict.items():
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[a_scale_str] = hyps
return ans
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
token_ids = ctc_prefix_beam_search_shallow_fussion(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
LM=LM,
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search-shallow-fussion"
return {key: hyps}
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
@ -584,6 +649,7 @@ def decode_dataset(
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -634,6 +700,7 @@ def decode_dataset(
batch=batch,
word_table=word_table,
G=G,
LM=LM,
)
for name, hyps in hyps_dict.items():
@ -664,9 +731,7 @@ def save_asr_output(
"""
for key, results in results_dict.items():
recogs_filename = (
params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
)
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
@ -680,7 +745,8 @@ def save_wer_results(
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.decoding_method in (
"attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
"attention-decoder-rescoring-with-ngram",
"whole-lattice-rescoring",
):
# Set it to False since there are too many logs.
enable_log = False
@ -721,6 +787,7 @@ def save_wer_results(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
@ -735,8 +802,11 @@ def main():
set_caching_enabled(True) # lhotse
assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding",
"ctc-greedy-search",
"prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
"1best",
"nbest",
"nbest-rescoring",
@ -762,6 +832,11 @@ def main():
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"
if "prefix-beam-search" in params.decoding_method:
params.suffix += f"_beam-{params.beam}"
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
params.suffix += f"_lm-scale-{params.lm_scale}"
if params.use_averaged_model:
params.suffix += "_use-averaged-model"
@ -772,6 +847,8 @@ def main():
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}")
logging.info(params)
@ -786,14 +863,24 @@ def main():
params.sos_id = 1
if params.decoding_method in [
"ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
"ctc-greedy-search",
"ctc-decoding",
"attention-decoder-rescoring-no-ngram",
"prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
]:
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
H = None
if params.decoding_method in [
"ctc-decoding",
"attention-decoder-rescoring-no-ngram",
]:
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
@ -844,7 +931,8 @@ def main():
G = k2.Fsa.from_dict(d)
if params.decoding_method in [
"whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
"whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
@ -858,6 +946,19 @@ def main():
else:
G = None
# only load the neural network LM if required
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
logging.info("About to create model")
model = get_model(params)
@ -967,6 +1068,7 @@ def main():
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
LM=LM,
)
save_asr_output(

View File

@ -1511,30 +1511,43 @@ def ctc_greedy_search(
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
ys: List[int] = field(default_factory=list)
# The log prob of ys.
# It contains only one entry.
log_prob_blank: torch.Tensor
log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)
log_prob_non_blank: torch.Tensor
log_prob_non_blank: torch.Tensor = torch.tensor(
[float("-inf")], dtype=torch.float32
)
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None
# The lm score of ys
# It contains only one entry
lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)
# the lm log_probs for next token given the history ys
lm_log_probs: Optional[torch.Tensor] = None
# the RNNLM states (h and c in LSTM)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None
LODR_state: Optional[NgramLmStateCost] = None
# N-gram LM state
Ngram_state: Optional[NgramLmStateCost] = None
# Context graph state
context_state: Optional[ContextState] = None
@property
def tot_score(self) -> torch.Tensor:
return self.log_prob + self.lm_score
@property
def log_prob(self) -> torch.Tensor:
return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
@ -1544,6 +1557,20 @@ class Hypothesis:
"""Return a tuple representation of self.ys"""
return tuple(self.ys)
def clone(self) -> "Hypothesis":
return Hypothesis(
ys=self.ys,
log_prob_blank=self.log_prob_blank,
log_prob_non_blank=self.log_prob_non_blank,
timestamp=self.timestamp,
lm_log_probs=self.lm_log_probs,
lm_score=self.lm_score,
state=self.state,
LODR_state=self.LODR_state,
Ngram_state=self.Ngram_state,
context_state=self.context_state,
)
class HypothesisList(object):
def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None:
@ -1597,9 +1624,9 @@ class HypothesisList(object):
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
return max(self._data.values(), key=lambda hyp: hyp.tot_score)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
@ -1629,7 +1656,7 @@ class HypothesisList(object):
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
if hyp.tot_score > threshold:
ans.add(hyp) # shallow copy
return ans
@ -1645,17 +1672,20 @@ class HypothesisList(object):
if length_norm:
hyps = sorted(
hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True
)[:k]
else:
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
def __contains__(self, key: tuple):
return key in self._data
def __getitem__(self, key: tuple):
return self._data[key]
def __iter__(self):
return iter(self._data.values())
@ -1694,64 +1724,96 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
return ans
def _step_worker(log_probs, indexes, B, beam, blank_id):
def _step_worker(
log_probs,
indexes,
B,
beam,
blank_id,
lm_scale: float = 0,
LODR_lm_scale: float = 0,
context_graph: Optional[ContextGraph] = None,
):
A = list(B)
B = HypothesisList()
for h in range(len(A)):
hyp = A[h]
for k in range(log_probs.size(0)):
log_prob, index = log_probs[k], indexes[k]
if index == blank_id:
new_token = index.item()
update_prefix = False
new_hyp = hyp.clone()
if new_token == blank_id:
# Case 0: *a + ε => *a
# *aε + ε => *a
# Prefix does not change, update log_prob of blank
new_hyp = Hypothesis(
ys=hyp.ys[:],
log_prob_non_blank=torch.tensor(
[float("-inf")], dtype=torch.float32
),
log_prob_blank=hyp.log_prob + log_prob,
new_hyp.log_prob_non_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
new_hyp.log_prob_blank = hyp.log_prob + log_prob
B.add(new_hyp)
elif len(hyp.ys) > 0 and hyp.ys[-1] == index:
elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token:
# Case 1: *a + a => *a
# Prefix does not change, update log_prob of non_blank
new_hyp = Hypothesis(
ys=hyp.ys[:],
log_prob_non_blank=hyp.log_prob_non_blank + log_prob,
log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob
new_hyp.log_prob_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
B.add(new_hyp)
# Case 2: *aε + a => *aa
# Prefix changes, update log_prob of blank
new_hyp = Hypothesis(
ys=hyp.ys[:] + [index.item()],
log_prob_non_blank=hyp.log_prob_blank + log_prob,
log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
new_hyp = hyp.clone()
# Caution: DO NOT use append, as clone is shallow copy
new_hyp.ys = hyp.ys + [new_token]
new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob
new_hyp.log_prob_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
B.add(new_hyp)
update_prefix = True
else:
# Case 3: *a + b => *ab, *aε + b => *ab
# Prefix changes, update log_prob of non_blank
new_hyp = Hypothesis(
ys=hyp.ys[:] + [index.item()],
log_prob_non_blank=hyp.log_prob + log_prob,
log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
# Caution: DO NOT use append, as clone is shallow copy
new_hyp.ys = hyp.ys + [new_token]
new_hyp.log_prob_non_blank = hyp.log_prob + log_prob
new_hyp.log_prob_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
update_prefix = True
if update_prefix:
lm_score = hyp.lm_score
if hyp.lm_log_probs is not None:
lm_score += hyp.lm_log_probs[new_token] * lm_scale
new_hyp.lm_log_probs = None
if context_graph is not None and hyp.context_state is not None:
context_score, new_context_state = context_graph.forward_one_step(
hyp.context_state, new_token
)
lm_score += context_score
new_hyp.context_state = new_context_state
if hyp.LODR_state is not None:
state_cost = hyp.LODR_state.forward_one_step(new_token)
# calculate the score of the latest token
current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score
assert current_ngram_score <= 0.0, (
state_cost.lm_score,
hyp.LODR_state.lm_score,
)
lm_score += LODR_lm_scale * current_ngram_score
new_hyp.LODR_state = state_cost
new_hyp.lm_score = lm_score
B.add(new_hyp)
B = B.topk(beam)
return B
def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
B.add(
Hypothesis(
ys=[],
log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
log_prob_blank=torch.zeros(1, dtype=torch.float32),
)
)
B.add(Hypothesis())
for j in range(encoder_out_lens):
log_probs, indexes = topk_values[j], topk_indexes[j]
B = _step_worker(log_probs, indexes, B, beam, blank_id)
@ -1763,11 +1825,11 @@ def ctc_prefix_beam_search(
encoder_out_lens: torch.Tensor,
beam: int = 4,
blank_id: int = 0,
context_graph: Optional[ContextGraph] = None,
process_pool: Optional[Pool] = None,
return_nbest: Optional[bool] = False,
) -> Union[List[List[int]], List[HypothesisList]]:
batch_size, num_frames, vocab_size = ctc_output.shape
# TODO: using a larger beam for first pass pruning
topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam)
topk_values = topk_values.cpu()
@ -1800,6 +1862,136 @@ def ctc_prefix_beam_search(
return [hyp.ys for hyp in best_hyps]
def ctc_prefix_beam_search_shallow_fussion(
ctc_output: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
blank_id: int = 0,
LODR_lm: Optional[NgramLm] = None,
LODR_lm_scale: Optional[float] = 0,
LM: Optional[LmScorer] = None,
context_graph: Optional[ContextGraph] = None,
) -> List[List[int]]:
batch_size, num_frames, vocab_size = ctc_output.shape
# TODO: using a larger beam for first pass pruning
topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam)
topk_values = topk_values.cpu()
topk_indexes = topk_indexes.cpu()
encoder_out_lens = encoder_out_lens.tolist()
device = ctc_output.device
lm_scale = 0
init_scores = None
init_states = None
if LM is not None:
lm_scale = LM.lm_scale
sos_id = getattr(LM, "sos_id", 1)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
lens = torch.tensor([1]).to(device)
init_scores, init_states = LM.score_token(sos_token, lens)
init_scores, init_states = init_scores.cpu(), (
init_states[0].cpu(),
init_states[1].cpu(),
)
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
Hypothesis(
ys=[],
log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
log_prob_blank=torch.zeros(1, dtype=torch.float32),
lm_score=torch.zeros(1, dtype=torch.float32),
state=init_states,
lm_log_probs=None if init_scores is None else init_scores.reshape(-1),
LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm),
context_state=None if context_graph is None else context_graph.root,
)
)
for j in range(num_frames):
for i in range(batch_size):
if j < encoder_out_lens[i]:
log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
B[i] = _step_worker(
log_probs,
indexes,
B[i],
beam,
blank_id,
lm_scale=lm_scale,
LODR_lm_scale=LODR_lm_scale,
context_graph=context_graph,
)
if LM is None:
continue
# update lm_score
token_list = [] # a list of list
hs = []
cs = []
indexes = [] # (batch_idx, key)
for batch_idx, hyps in enumerate(B):
for hyp in hyps:
if hyp.lm_log_probs is None:
if LM.lm_type == "rnn":
token_list.append([hyp.ys[-1]])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
else:
# for transformer LM
token_list.append([sos_id] + hyp.ys[:])
indexes.append((batch_idx, hyp.key))
if len(token_list) != 0:
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if LM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
state = (hs, cs)
else:
# for transformer LM
tokens_list = [torch.tensor(tokens) for tokens in token_list]
tokens_to_score = (
torch.nn.utils.rnn.pad_sequence(
tokens_list, batch_first=True, padding_value=0.0
)
.to(device)
.to(torch.int64)
)
state = None
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
for i in range(scores.size(0)):
batch_idx, key = indexes[i]
B[batch_idx][key].lm_log_probs = scores[i]
if LM.lm_type == "rnn":
state = (
lm_states[0][:, i, :].unsqueeze(1),
lm_states[1][:, i, :].unsqueeze(1),
)
B[batch_idx][key].state = state
# finalize context_state, if the matched contexts do not reach final state
# we need to add the score on the corresponding backoff arc
if context_graph is not None:
for hyps in B:
for hyp in hyps:
context_score, new_context_state = context_graph.finalize(
hyp.context_state
)
hyp.lm_score += context_score
hyp.context_state = new_context_state
best_hyps = [b.get_most_probable() for b in B]
return [hyp.ys for hyp in best_hyps]
def ctc_prefix_beam_search_attention_decoder_rescoring(
ctc_output: torch.Tensor,
attention_decoder: torch.nn.Module,

View File

@ -19,8 +19,10 @@
import argparse
import collections
import json
import logging
import os
import pathlib
import re
import subprocess
from collections import defaultdict
@ -178,6 +180,15 @@ class AttributeDict(dict):
return
raise AttributeError(f"No such attribute '{key}'")
def __str__(self, indent: int = 2):
tmp = {}
for k, v in self.items():
# PosixPath is ont JSON serializable
if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
v = str(v)
tmp[k] = v
return json.dumps(tmp, indent=indent, sort_keys=True)
def encode_supervisions(
supervisions: dict,