Add prefix beam search and corresponding decoding methods (#1786)

* Add prefix beam search / shallow fusion / hotwords in librispeech ctc decode

* Add librispeech cr-ctc prefix beam search results
Wei Kang 2024-10-30 10:14:34 +08:00 committed by GitHub
parent 6c7863c2f8
commit d513d456b8
4 changed files with 908 additions and 24 deletions


@ -153,6 +153,7 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-decoding | 2.57 | 5.95 | --epoch 50 --avg 25 |
| ctc-prefix-beam-search | 2.52 | 5.85 | --epoch 50 --avg 25 |
The training command using 2 32G-V100 GPUs is:
```bash
@ -184,7 +185,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 25 \
@ -212,6 +213,7 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-decoding | 2.12 | 4.62 | --epoch 50 --avg 24 |
| ctc-prefix-beam-search | 2.10 | 4.61 | --epoch 50 --avg 24 |
The training command using 4 32G-V100 GPUs is:
```bash
@ -238,7 +240,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 24 \
@ -262,6 +264,7 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-decoding | 2.03 | 4.37 | --epoch 50 --avg 26 |
| ctc-prefix-beam-search | 2.02 | 4.35 | --epoch 50 --avg 26 |
The training command using 2 80G-A100 GPUs is:
```bash
@ -292,7 +295,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 26 \


@ -111,6 +111,7 @@ Usage:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -129,8 +130,14 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.context_graph import ContextGraph, ContextState
from icefall.decode import (
ctc_greedy_search,
ctc_prefix_beam_search,
ctc_prefix_beam_search_attention_decoder_rescoring,
ctc_prefix_beam_search_shallow_fussion,
get_lattice,
nbest_decoding,
nbest_oracle,
@ -140,7 +147,11 @@ from icefall.decode import (
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.ngram_lm import NgramLm, NgramLmStateCost
from icefall.lexicon import Lexicon
from icefall.lm_wrapper import LmScorer
from icefall.utils import (
AttributeDict,
get_texts,
@ -255,6 +266,12 @@ def get_parser():
lattice, rescore them with the attention decoder.
- (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
rescored lattice, rescore them with the attention decoder.
- (10) ctc-prefix-beam-search. Extract n paths with the given beam; the best
of the n paths is the decoding result.
- (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with
the given beam, rescore them with the attention decoder.
- (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fusion during
beam search; LODR and hotwords (contextual biasing) are also supported by this method.
""",
)
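As a quick orientation (not part of the patch), the three new method names correspond to the helpers imported from `icefall.decode` above; a hypothetical dispatch table:

```python
# Sketch only: the real dispatch happens inside decode_one_batch() below,
# with the full argument set (NNLM, LODR_lm, context_graph, etc.).
METHOD_TO_FUNC = {
    "ctc-prefix-beam-search": ctc_prefix_beam_search,
    "ctc-prefix-beam-search-attention-decoder-rescoring": ctc_prefix_beam_search_attention_decoder_rescoring,
    "ctc-prefix-beam-search-shallow-fussion": ctc_prefix_beam_search_shallow_fussion,
}
```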
@ -280,6 +297,23 @@ def get_parser():
""",
)
parser.add_argument(
"--nnlm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--nnlm-scale",
type=float,
default=0,
help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion.
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--hlg-scale",
type=float,
@ -297,11 +331,52 @@ def get_parser():
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="ID of the backoff symbol in the ngram LM",
)
parser.add_argument(
"--lodr-ngram",
type=str,
help="The path to the lodr ngram",
)
parser.add_argument(
"--lodr-lm-scale",
type=float,
default=0,
help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.",
)
parser.add_argument(
"--context-score",
type=float,
default=0,
help="""
The bonus score of each token for the context biasing words/phrases.
0 means don't use contextual biasing.
Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path to the context biasing list, with one word/phrase per line.
Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
""",
)
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets)."""
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
add_model_arguments(parser)
@ -314,11 +389,12 @@ def get_decoding_params() -> AttributeDict:
params = AttributeDict(
{
"frame_shift_ms": 10,
"search_beam": 20,
"output_beam": 8,
"search_beam": 20, # for k2 fsa composition
"output_beam": 8, # for k2 fsa composition
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"beam": 4, # for prefix-beam-search
}
)
return params
@ -333,6 +409,9 @@ def decode_one_batch(
batch: dict,
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
NNLM: Optional[LmScorer] = None,
LODR_lm: Optional[NgramLm] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -377,10 +456,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
device = params.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
@ -411,6 +487,51 @@ def decode_one_batch(
key = "ctc-greedy-search"
return {key: hyps}
if params.decoding_method == "ctc-prefix-beam-search":
token_ids = ctc_prefix_beam_search(
ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search"
return {key: hyps}
if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
ctc_output=ctc_output,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
ans = dict()
for a_scale_str, token_ids in best_path_dict.items():
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[a_scale_str] = hyps
return ans
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
token_ids = ctc_prefix_beam_search_shallow_fussion(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
NNLM=NNLM,
LODR_lm=LODR_lm,
LODR_lm_scale=params.lodr_lm_scale,
context_graph=context_graph,
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search-shallow-fussion"
return {key: hyps}
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
@ -584,6 +705,9 @@ def decode_dataset(
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
NNLM: Optional[LmScorer] = None,
LODR_lm: Optional[NgramLm] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -634,6 +758,9 @@ def decode_dataset(
batch=batch,
word_table=word_table,
G=G,
NNLM=NNLM,
LODR_lm=LODR_lm,
context_graph=context_graph,
)
for name, hyps in hyps_dict.items():
@ -664,9 +791,7 @@ def save_asr_output(
"""
for key, results in results_dict.items():
recogs_filename = (
params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
)
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
@ -680,7 +805,8 @@ def save_wer_results(
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.decoding_method in (
"attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
"attention-decoder-rescoring-with-ngram",
"whole-lattice-rescoring",
):
# Set it to False since there are too many logs.
enable_log = False
@ -721,6 +847,7 @@ def save_wer_results(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
@ -735,8 +862,11 @@ def main():
set_caching_enabled(True) # lhotse
assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding",
"ctc-greedy-search",
"ctc-prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
"1best",
"nbest",
"nbest-rescoring",
@ -762,6 +892,16 @@ def main():
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"
if "prefix-beam-search" in params.decoding_method:
params.suffix += f"_beam-{params.beam}"
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
if params.nnlm_scale != 0:
params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
if params.lodr_lm_scale != 0:
params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
if params.context_score != 0:
params.suffix += f"_context_score-{params.context_score}"
if params.use_averaged_model:
params.suffix += "_use-averaged-model"
@ -771,6 +911,7 @@ def main():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}")
logging.info(params)
@ -786,14 +927,24 @@ def main():
params.sos_id = 1
if params.decoding_method in [
"ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
"ctc-decoding",
"ctc-greedy-search",
"ctc-prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
"attention-decoder-rescoring-no-ngram",
]:
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
H = None
if params.decoding_method in [
"ctc-decoding",
"attention-decoder-rescoring-no-ngram",
]:
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
@ -844,7 +995,8 @@ def main():
G = k2.Fsa.from_dict(d)
if params.decoding_method in [
"whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
"whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
@ -858,6 +1010,51 @@ def main():
else:
G = None
# only load the neural network LM if required
NNLM = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.nnlm_scale != 0
):
NNLM = LmScorer(
lm_type=params.nnlm_type,
params=params,
device=device,
lm_scale=params.nnlm_scale,
)
NNLM.to(device)
NNLM.eval()
LODR_lm = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.lodr_lm_scale != 0
):
assert os.path.exists(
params.lodr_ngram
), f"LODR ngram does not exist, given path: {params.lodr_ngram}"
logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
LODR_lm = NgramLm(
params.lodr_ngram,
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {LODR_lm.lm.num_states}")
context_graph = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.context_score != 0
):
assert os.path.exists(
params.context_file
), f"context_file does not exist, given path: {params.context_file}"
contexts = []
for line in open(params.context_file).readlines():
contexts.append(bpe_model.encode(line.strip()))
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
logging.info("About to create model")
model = get_model(params)
@ -967,6 +1164,9 @@ def main():
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
NNLM=NNLM,
LODR_lm=LODR_lm,
context_graph=context_graph,
)
save_asr_output(


@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -15,11 +16,16 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, field
from multiprocessing.pool import Pool
from typing import Dict, List, Optional, Tuple, Union
import k2
import torch
from icefall.context_graph import ContextGraph, ContextState
from icefall.ngram_lm import NgramLm, NgramLmStateCost
from icefall.lm_wrapper import LmScorer
from icefall.utils import add_eos, add_sos, get_texts
DEFAULT_LM_SCALE = [
@ -1497,3 +1503,667 @@ def ctc_greedy_search(
hyps = [h[h != blank_id].tolist() for h in hyps]
return hyps
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int] = field(default_factory=list)
# The log prob of ys that ends with blank token.
# It contains only one entry.
log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)
# The log prob of ys that ends with non blank token.
# It contains only one entry.
log_prob_non_blank: torch.Tensor = torch.tensor(
[float("-inf")], dtype=torch.float32
)
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
# The lm score of ys
# May contain external LM score (including LODR score) and contextual biasing score
# It contains only one entry
lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)
# the lm log_probs for next token given the history ys
# The number of elements should be equal to vocabulary size.
lm_log_probs: Optional[torch.Tensor] = None
# the RNNLM states (h and c in LSTM)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# LODR (N-gram LM) state
LODR_state: Optional[NgramLmStateCost] = None
# N-gram LM state
Ngram_state: Optional[NgramLmStateCost] = None
# Context graph state
context_state: Optional[ContextState] = None
# This is the total score of current path, acoustic plus external LM score.
@property
def tot_score(self) -> torch.Tensor:
return self.log_prob + self.lm_score
# This is only the probability from the model output (i.e., external LM score not included).
@property
def log_prob(self) -> torch.Tensor:
return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
@property
def key(self) -> tuple:
"""Return a tuple representation of self.ys"""
return tuple(self.ys)
def clone(self) -> "Hypothesis":
return Hypothesis(
ys=self.ys,
log_prob_blank=self.log_prob_blank,
log_prob_non_blank=self.log_prob_non_blank,
timestamp=self.timestamp,
lm_log_probs=self.lm_log_probs,
lm_score=self.lm_score,
state=self.state,
LODR_state=self.LODR_state,
Ngram_state=self.Ngram_state,
context_state=self.context_state,
)
class HypothesisList(object):
def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[tuple, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(
old_hyp.log_prob_blank, hyp.log_prob_blank, out=old_hyp.log_prob_blank
)
torch.logaddexp(
old_hyp.log_prob_non_blank,
hyp.log_prob_non_blank,
out=old_hyp.log_prob_non_blank,
)
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `tot_score`.
Args:
length_norm:
If True, the `tot_score` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `tot_score`.
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.tot_score)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose tot_score is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `tot_score` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.tot_score > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
"""Return the top-k hypothesis.
Args:
length_norm:
If True, the `tot_score` of a hypothesis is normalized by the
number of tokens in it.
"""
hyps = list(self._data.items())
if length_norm:
hyps = sorted(
hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True
)[:k]
else:
hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: tuple):
return key in self._data
def __getitem__(self, key: tuple):
return self._data[key]
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for hyp in self:
s.append(str(hyp.key))
return ", ".join(s)
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypotheses for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans
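For example, with two utterances keeping 2 and 3 hypotheses respectively (a toy check assuming the classes above and an installed k2):

```python
h0, h1 = HypothesisList(), HypothesisList()
h0.add(Hypothesis(ys=[1]))
h0.add(Hypothesis(ys=[2]))
for ys in ([1], [1, 2], [3]):
    h1.add(Hypothesis(ys=ys))

shape = get_hyps_shape([h0, h1])
print(shape.row_splits(1))  # [0, 2, 5]
print(shape.row_ids(1))     # [0, 0, 1, 1, 1] -> maps each hypothesis to its utterance
```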
def _step_worker(
log_probs: torch.Tensor,
indexes: torch.Tensor,
B: HypothesisList,
beam: int = 4,
blank_id: int = 0,
nnlm_scale: float = 0,
LODR_lm_scale: float = 0,
context_graph: Optional[ContextGraph] = None,
) -> HypothesisList:
"""The worker to decode one step.
Args:
log_probs:
topk log_probs of the current step (i.e. the kept tokens after first-pass pruning),
the shape is (beam,)
indexes:
The indexes of the log_probs above, the shape is (beam,)
B:
An instance of HypothesisList containing the kept hypotheses.
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of blank in the vocabulary.
nnlm_scale:
The scale of the neural network LM.
LODR_lm_scale:
The scale of the LODR n-gram LM.
context_graph:
A ContextGraph instance containing contextual phrases.
Return:
Returns the updated HypothesisList.
"""
A = list(B)
B = HypothesisList()
for h in range(len(A)):
hyp = A[h]
for k in range(log_probs.size(0)):
log_prob, index = log_probs[k], indexes[k]
new_token = index.item()
update_prefix = False
new_hyp = hyp.clone()
if new_token == blank_id:
# Case 0: *a + ε => *a
# *aε + ε => *a
# Prefix does not change, update log_prob of blank
new_hyp.log_prob_non_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
new_hyp.log_prob_blank = hyp.log_prob + log_prob
B.add(new_hyp)
elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token:
# Case 1: *a + a => *a
# Prefix does not change, update log_prob of non_blank
new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob
new_hyp.log_prob_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
B.add(new_hyp)
# Case 2: *aε + a => *aa
# Prefix changes, update log_prob of blank
new_hyp = hyp.clone()
# Caution: DO NOT use append, as clone is shallow copy
new_hyp.ys = hyp.ys + [new_token]
new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob
new_hyp.log_prob_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
update_prefix = True
else:
# Case 3: *a + b => *ab, *aε + b => *ab
# Prefix changes, update log_prob of non_blank
# Caution: DO NOT use append, as clone is shallow copy
new_hyp.ys = hyp.ys + [new_token]
new_hyp.log_prob_non_blank = hyp.log_prob + log_prob
new_hyp.log_prob_blank = torch.tensor(
[float("-inf")], dtype=torch.float32
)
update_prefix = True
if update_prefix:
lm_score = hyp.lm_score
if hyp.lm_log_probs is not None:
lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale
new_hyp.lm_log_probs = None
if context_graph is not None and hyp.context_state is not None:
(
context_score,
new_context_state,
matched_state,
) = context_graph.forward_one_step(hyp.context_state, new_token)
lm_score = lm_score + context_score
new_hyp.context_state = new_context_state
if hyp.LODR_state is not None:
state_cost = hyp.LODR_state.forward_one_step(new_token)
# calculate the score of the latest token
current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score
assert current_ngram_score <= 0.0, (
state_cost.lm_score,
hyp.LODR_state.lm_score,
)
lm_score = lm_score + LODR_lm_scale * current_ngram_score
new_hyp.LODR_state = state_cost
new_hyp.lm_score = lm_score
B.add(new_hyp)
B = B.topk(beam)
return B
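To make the four cases above concrete, here is a single-step toy trace (a sketch assuming the definitions above; vocabulary 0=blank, 1='a', 2='b', with made-up numbers):

```python
import torch

B = HypothesisList()
B.add(
    Hypothesis(
        ys=[1],  # prefix 'a'
        log_prob_blank=torch.log(torch.tensor([0.4])),
        log_prob_non_blank=torch.log(torch.tensor([0.6])),
    )
)
# One frame whose top-2 tokens are blank (p=0.7) and 'a' (p=0.3).
log_probs = torch.log(torch.tensor([0.7, 0.3]))
indexes = torch.tensor([0, 1])

B = _step_worker(log_probs, indexes, B, beam=4, blank_id=0)
for hyp in B:
    print(hyp.ys, hyp.log_prob.item())
# Expected: prefix [1]    with log(0.7 * 1.0 + 0.3 * 0.6)  (cases 0 and 1, merged)
#           prefix [1, 1] with log(0.3 * 0.4)              (case 2)
```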
def _sequence_worker(
topk_values: torch.Tensor,
topk_indexes: torch.Tensor,
B: HypothesisList,
encoder_out_lens: torch.Tensor,
beam: int = 4,
blank_id: int = 0,
) -> HypothesisList:
"""The worker to decode one sequence.
Args:
topk_values:
topk log_probs of model output (i.e. the kept tokens of first pass pruning),
the shape is (T, beam)
topk_indexes:
The indexes of the topk_values above, the shape is (T, beam)
B:
An instance of HypothesisList containing the kept hypotheses.
encoder_out_lens:
The number of valid frames (after subsampling) in this sequence.
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of blank in the vocabulary.
Return:
Returns the updated HypothesisList.
"""
B.add(Hypothesis())
for j in range(encoder_out_lens):
log_probs, indexes = topk_values[j], topk_indexes[j]
B = _step_worker(log_probs, indexes, B, beam, blank_id)
return B
def ctc_prefix_beam_search(
ctc_output: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
blank_id: int = 0,
process_pool: Optional[Pool] = None,
return_nbest: Optional[bool] = False,
) -> Union[List[List[int]], List[HypothesisList]]:
"""Implement prefix search decoding in "Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks".
Args:
ctc_output:
The output of ctc head (log probability), the shape is (B, T, V)
encoder_out_lens:
The lengths (frames) of sequences after subsampling, the shape is (B,)
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of blank in the vocabulary.
process_pool:
The process pool for parallel decoding; if not provided, it will use all
your CPU cores by default.
return_nbest:
If True, return a list of HypothesisList; otherwise return a list of lists of
decoded token ids.
"""
batch_size, num_frames, vocab_size = ctc_output.shape
# TODO: using a larger beam for first pass pruning
topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam)
topk_values = topk_values.cpu()
topk_indexes = topk_indexes.cpu()
B = [HypothesisList() for _ in range(batch_size)]
pool = Pool() if process_pool is None else process_pool
arguments = []
for i in range(batch_size):
arguments.append(
(
topk_values[i],
topk_indexes[i],
B[i],
encoder_out_lens[i].item(),
beam,
blank_id,
)
)
async_results = pool.starmap_async(_sequence_worker, arguments)
B = list(async_results.get())
if process_pool is None:
pool.close()
pool.join()
if return_nbest:
return B
else:
best_hyps = [b.get_most_probable() for b in B]
return [hyp.ys for hyp in best_hyps]
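A minimal usage sketch with random scores (real usage passes the model's CTC log-probabilities, as in `ctc_decode.py`; since a process pool is spawned, wrap this in a `if __name__ == "__main__"` guard on platforms that use spawn):

```python
import torch

batch_size, num_frames, vocab_size = 2, 20, 500
ctc_output = torch.randn(batch_size, num_frames, vocab_size).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([20, 16])

hyps = ctc_prefix_beam_search(
    ctc_output=ctc_output,
    encoder_out_lens=encoder_out_lens,
    beam=4,
    blank_id=0,
)
# hyps is a list of `batch_size` token-id lists; in the recipe they are turned
# into text with bpe_model.decode(hyps).
```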
def ctc_prefix_beam_search_shallow_fussion(
ctc_output: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
blank_id: int = 0,
LODR_lm: Optional[NgramLm] = None,
LODR_lm_scale: Optional[float] = 0,
NNLM: Optional[LmScorer] = None,
context_graph: Optional[ContextGraph] = None,
) -> List[List[int]]:
"""Implement prefix search decoding in "Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
nervous language model shallow fussion, it also supports contextual
biasing with a given grammar.
Args:
ctc_output:
The output of ctc head (log probability), the shape is (B, T, V)
encoder_out_lens:
The lengths (frames) of sequences after subsampling, the shape is (B,)
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of blank in the vocabulary.
LODR_lm:
A low-order n-gram LM, whose score will be subtracted during shallow fusion.
LODR_lm_scale:
The scale of the LODR_lm.
NNLM:
A neural network LM, e.g. an RNNLM or transformer LM.
context_graph:
A ContextGraph instance containing contextual phrases.
Return:
Returns a list of list of decoded token ids.
"""
batch_size, num_frames, vocab_size = ctc_output.shape
# TODO: using a larger beam for first pass pruning
topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam)
topk_values = topk_values.cpu()
topk_indexes = topk_indexes.cpu()
encoder_out_lens = encoder_out_lens.tolist()
device = ctc_output.device
nnlm_scale = 0
init_scores = None
init_states = None
if NNLM is not None:
nnlm_scale = NNLM.lm_scale
sos_id = getattr(NNLM, "sos_id", 1)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
lens = torch.tensor([1]).to(device)
init_scores, init_states = NNLM.score_token(sos_token, lens)
init_scores, init_states = init_scores.cpu(), (
init_states[0].cpu(),
init_states[1].cpu(),
)
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
Hypothesis(
ys=[],
log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
log_prob_blank=torch.zeros(1, dtype=torch.float32),
lm_score=torch.zeros(1, dtype=torch.float32),
state=init_states,
lm_log_probs=None if init_scores is None else init_scores.reshape(-1),
LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm),
context_state=None if context_graph is None else context_graph.root,
)
)
for j in range(num_frames):
for i in range(batch_size):
if j < encoder_out_lens[i]:
log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
B[i] = _step_worker(
log_probs=log_probs,
indexes=indexes,
B=B[i],
beam=beam,
blank_id=blank_id,
nnlm_scale=nnlm_scale,
LODR_lm_scale=LODR_lm_scale,
context_graph=context_graph,
)
if NNLM is None:
continue
# update lm_log_probs
token_list = [] # a list of list
hs = []
cs = []
indexes = [] # (batch_idx, key)
for batch_idx, hyps in enumerate(B):
for hyp in hyps:
if hyp.lm_log_probs is None:  # those hyps whose prefix changed
if NNLM.lm_type == "rnn":
token_list.append([hyp.ys[-1]])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
else:
# for transformer LM
token_list.append([sos_id] + hyp.ys[:])
indexes.append((batch_idx, hyp.key))
if len(token_list) != 0:
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if NNLM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
state = (hs, cs)
else:
# for transformer LM
tokens_list = [torch.tensor(tokens) for tokens in token_list]
tokens_to_score = (
torch.nn.utils.rnn.pad_sequence(
tokens_list, batch_first=True, padding_value=0.0
)
.to(device)
.to(torch.int64)
)
state = None
scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state)
scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
for i in range(scores.size(0)):
batch_idx, key = indexes[i]
B[batch_idx][key].lm_log_probs = scores[i]
if NNLM.lm_type == "rnn":
state = (
lm_states[0][:, i, :].unsqueeze(1),
lm_states[1][:, i, :].unsqueeze(1),
)
B[batch_idx][key].state = state
# finalize context_state, if the matched contexts do not reach final state
# we need to add the score on the corresponding backoff arc
if context_graph is not None:
for hyps in B:
for hyp in hyps:
context_score, new_context_state = context_graph.finalize(
hyp.context_state
)
hyp.lm_score += context_score
hyp.context_state = new_context_state
best_hyps = [b.get_most_probable() for b in B]
return [hyp.ys for hyp in best_hyps]
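A hedged usage sketch for the shallow-fusion variant; all external scorers are optional and, when omitted, it degenerates to plain prefix beam search. The constructor calls in the comments mirror how `ctc_decode.py` builds these objects (paths and scales are illustrative):

```python
import torch

ctc_output = torch.randn(1, 30, 500).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([30])

hyps = ctc_prefix_beam_search_shallow_fussion(
    ctc_output=ctc_output,
    encoder_out_lens=encoder_out_lens,
    beam=4,
    blank_id=0,
    NNLM=None,           # e.g. LmScorer(lm_type="rnn", params=params, device=device, lm_scale=0.3)
    LODR_lm=None,        # e.g. NgramLm("lodr.arpa", backoff_id=500, is_binary=False)
    LODR_lm_scale=0.0,   # typically negative when LODR_lm is given
    context_graph=None,  # e.g. a ContextGraph built from BPE-encoded hotword phrases
)
```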
def ctc_prefix_beam_search_attention_decoder_rescoring(
ctc_output: torch.Tensor,
attention_decoder: torch.nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 8,
blank_id: int = 0,
attention_scale: Optional[float] = None,
process_pool: Optional[Pool] = None,
):
"""Implement prefix search decoding in "Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
attention decoder rescoring.
Args:
ctc_output:
The output of ctc head (log probability), the shape is (B, T, V)
attention_decoder:
The attention decoder.
encoder_out:
The output of encoder, the shape is (B, T, D)
encoder_out_lens:
The lengths (frames) of sequences after subsampling, the shape is (B,)
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of blank in the vocabulary.
attention_scale:
The scale of the attention decoder score; if not provided, it will search over
a default list (see the code below).
process_pool:
The process pool for parallel decoding; if not provided, it will use all
your CPU cores by default.
"""
# List[HypothesisList]
nbest = ctc_prefix_beam_search(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
beam=beam,
blank_id=blank_id,
return_nbest=True,
)
device = ctc_output.device
hyp_shape = get_hyps_shape(nbest).to(device)
hyp_to_utt_map = hyp_shape.row_ids(1).to(torch.long)
# the shape of encoder_out is (N, T, C), so we use axis=0 here
expanded_encoder_out = encoder_out.index_select(0, hyp_to_utt_map)
expanded_encoder_out_lens = encoder_out_lens.index_select(0, hyp_to_utt_map)
nbest = [list(x) for x in nbest]
token_ids = []
scores = []
for hyps in nbest:
for hyp in hyps:
token_ids.append(hyp.ys)
scores.append(hyp.log_prob.reshape(1))
scores = torch.cat(scores).to(device)
nll = attention_decoder.nll(
encoder_out=expanded_encoder_out,
encoder_out_lens=expanded_encoder_out_lens,
token_ids=token_ids,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
else:
attention_scale_list = [attention_scale]
ans = dict()
start_indexes = hyp_shape.row_splits(1)[0:-1]
for a_scale in attention_scale_list:
tot_scores = scores + a_scale * attention_scores
ragged_tot_scores = k2.RaggedTensor(hyp_shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
max_indexes = max_indexes - start_indexes
max_indexes = max_indexes.cpu()
best_path = [nbest[i][max_indexes[i]].ys for i in range(len(max_indexes))]
key = f"attention_scale_{a_scale}"
ans[key] = best_path
return ans
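The rescoring helper returns one best path per attention scale; a brief sketch of how a caller (like `decode_one_batch` above) consumes the result, assuming the batch tensors, `model`, and `bpe_model` are in scope:

```python
best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
    ctc_output=ctc_output,
    attention_decoder=model.attention_decoder,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=8,
)
for a_scale_str, token_ids in best_path_dict.items():
    # One WER is reported per attention scale; pick the best scale on a dev set.
    texts = bpe_model.decode(token_ids)
    print(a_scale_str, texts[0])
```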


@ -19,8 +19,10 @@
import argparse
import collections
import json
import logging
import os
import pathlib
import random
import re
import subprocess
@ -180,6 +182,15 @@ class AttributeDict(dict):
return
raise AttributeError(f"No such attribute '{key}'")
def __str__(self, indent: int = 2):
tmp = {}
for k, v in self.items():
# PosixPath and torch.device are not JSON serializable
if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
v = str(v)
tmp[k] = v
return json.dumps(tmp, indent=indent, sort_keys=True)
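For example (a small sketch assuming the class above):

```python
import pathlib

import torch

params = AttributeDict(
    {"exp_dir": pathlib.Path("zipformer/exp"), "device": torch.device("cuda", 0), "beam": 4}
)
print(params)
# {
#   "beam": 4,
#   "device": "cuda:0",
#   "exp_dir": "zipformer/exp"
# }
```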
def encode_supervisions(
supervisions: dict,