Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
Add prefix beam search and corresponding decoding methods (#1786)
* Add prefix beam search / shallow fusion / hotwords in librispeech ctc decode
* Add librispeech cr-ctc prefix beam search results
This commit is contained in:
parent: 6c7863c2f8
commit: d513d456b8
@@ -153,6 +153,7 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
 | decoding method                      | test-clean | test-other | comment             |
 |--------------------------------------|------------|------------|---------------------|
 | ctc-greedy-decoding                  | 2.57       | 5.95       | --epoch 50 --avg 25 |
+| ctc-prefix-beam-search               | 2.52       | 5.85       | --epoch 50 --avg 25 |

 The training command using 2 32G-V100 GPUs is:

 ```bash

@@ -184,7 +185,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
 The decoding command is:

 ```bash
 export CUDA_VISIBLE_DEVICES="0"
-for m in ctc-greedy-search; do
+for m in ctc-greedy-search ctc-prefix-beam-search; do
   ./zipformer/ctc_decode.py \
     --epoch 50 \
     --avg 25 \
@@ -212,6 +213,7 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
 | decoding method                      | test-clean | test-other | comment             |
 |--------------------------------------|------------|------------|---------------------|
 | ctc-greedy-decoding                  | 2.12       | 4.62       | --epoch 50 --avg 24 |
+| ctc-prefix-beam-search               | 2.1        | 4.61       | --epoch 50 --avg 24 |

 The training command using 4 32G-V100 GPUs is:

 ```bash

@@ -238,7 +240,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 The decoding command is:

 ```bash
 export CUDA_VISIBLE_DEVICES="0"
-for m in ctc-greedy-search; do
+for m in ctc-greedy-search ctc-prefix-beam-search; do
   ./zipformer/ctc_decode.py \
     --epoch 50 \
     --avg 24 \
@@ -262,6 +264,7 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
 | decoding method                      | test-clean | test-other | comment             |
 |--------------------------------------|------------|------------|---------------------|
 | ctc-greedy-decoding                  | 2.03       | 4.37       | --epoch 50 --avg 26 |
+| ctc-prefix-beam-search               | 2.02       | 4.35       | --epoch 50 --avg 26 |

 The training command using 2 80G-A100 GPUs is:

 ```bash

@@ -292,7 +295,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
 The decoding command is:

 ```bash
 export CUDA_VISIBLE_DEVICES="0"
-for m in ctc-greedy-search; do
+for m in ctc-greedy-search ctc-prefix-beam-search; do
   ./zipformer/ctc_decode.py \
     --epoch 50 \
     --avg 26 \
@@ -111,6 +111,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

@@ -129,8 +130,14 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.context_graph import ContextGraph, ContextState
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     nbest_decoding,
     nbest_oracle,

@@ -140,7 +147,11 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
 from icefall.lexicon import Lexicon
+from icefall.lm_wrapper import LmScorer
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -255,6 +266,12 @@ def get_parser():
           lattice, rescore them with the attention decoder.
         - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
           rescored lattice, rescore them with the attention decoder.
+        - (10) ctc-prefix-beam-search. Extract n paths with the given beam; the best
+          path of the n paths is the decoding result.
+        - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with
+          the given beam, rescore them with the attention decoder.
+        - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fusion during
+          beam search; LODR and hotwords are also supported in this decoding method.
         """,
     )
@@ -280,6 +297,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--nnlm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN LM",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--nnlm-scale",
+        type=float,
+        default=0,
+        help="""The scale of the neural network LM; 0 means don't use NNLM shallow fusion.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--hlg-scale",
         type=float,
@@ -297,11 +331,52 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the LODR ngram",
+    )
+
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="The scale of the LODR ngram; it should be less than 0. 0 means don't use LODR.",
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        0 means don't use contextual biasing.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )

     add_model_arguments(parser)
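Taken together, the new options gate three optional knowledge sources during prefix beam search. A minimal standalone sketch (hypothetical example values, not the actual script) of how the switches combine:

```python
# Illustrative only: the flag names mirror the diff above; the values are made up.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--nnlm-scale", type=float, default=0)     # != 0 -> NNLM shallow fusion
parser.add_argument("--lodr-lm-scale", type=float, default=0)  # != 0 -> LODR (use a negative scale)
parser.add_argument("--context-score", type=float, default=0)  # != 0 -> hotword / contextual biasing
parser.add_argument("--context-file", type=str, default="")    # one word/phrase per line

args = parser.parse_args(
    ["--nnlm-scale", "0.5", "--lodr-lm-scale", "-0.16", "--context-score", "2.0"]
)
print(args)  # Namespace(context_file='', context_score=2.0, lodr_lm_scale=-0.16, nnlm_scale=0.5)
```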
@@ -314,11 +389,12 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "beam": 4,  # for prefix-beam-search
         }
     )
     return params
@@ -333,6 +409,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -377,10 +456,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = params.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -411,6 +487,51 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}

+    if params.decoding_method == "ctc-prefix-beam-search":
+        token_ids = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, token_ids in best_path_dict.items():
+            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+            hyps = bpe_model.decode(token_ids)
+            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+            hyps = [s.split() for s in hyps]
+            ans[a_scale_str] = hyps
+        return ans
+
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        token_ids = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -584,6 +705,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -634,6 +758,9 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )

         for name, hyps in hyps_dict.items():
@@ -664,9 +791,7 @@ def save_asr_output(
     """
     for key, results in results_dict.items():
-        recogs_filename = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"

         results = sorted(results)
         store_transcripts(filename=recogs_filename, texts=results)

@@ -680,7 +805,8 @@ def save_wer_results(
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     if params.decoding_method in (
-        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
+        "attention-decoder-rescoring-with-ngram",
+        "whole-lattice-rescoring",
     ):
         # Set it to False since there are too many logs.
         enable_log = False
@@ -721,6 +847,7 @@
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)

@@ -735,8 +862,11 @@ def main():
     set_caching_enabled(True)  # lhotse

     assert params.decoding_method in (
-        "ctc-greedy-search",
         "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "1best",
         "nbest",
         "nbest-rescoring",
@@ -762,6 +892,16 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"

+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context_score-{params.context_score}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"

@@ -771,6 +911,7 @@ def main():
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)
+    params.device = device

     logging.info(f"Device: {device}")
     logging.info(params)
@@ -786,14 +927,24 @@ def main():
     params.sos_id = 1

     if params.decoding_method in [
-        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+        "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
+        "attention-decoder-rescoring-no-ngram",
     ]:
         HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-            device=device,
-        )
+        H = None
+        if params.decoding_method in [
+            "ctc-decoding",
+            "attention-decoder-rescoring-no-ngram",
+        ]:
+            H = k2.ctc_topo(
+                max_token=max_token_id,
+                modified=False,
+                device=device,
+            )
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(str(params.lang_dir / "bpe.model"))
     else:

@@ -844,7 +995,8 @@ def main():
         G = k2.Fsa.from_dict(d)

         if params.decoding_method in [
-            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
+            "whole-lattice-rescoring",
+            "attention-decoder-rescoring-with-ngram",
         ]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
@@ -858,6 +1010,51 @@ def main():
         else:
             G = None

+    # only load the neural network LM if required
+    NNLM = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.nnlm_scale != 0
+    ):
+        NNLM = LmScorer(
+            lm_type=params.nnlm_type,
+            params=params,
+            device=device,
+            lm_scale=params.nnlm_scale,
+        )
+        NNLM.to(device)
+        NNLM.eval()
+
+    LODR_lm = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.lodr_lm_scale != 0
+    ):
+        assert os.path.exists(
+            params.lodr_ngram
+        ), f"LODR ngram does not exist, given path: {params.lodr_ngram}"
+        logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
+        LODR_lm = NgramLm(
+            params.lodr_ngram,
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {LODR_lm.lm.num_states}")
+
+    context_graph = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.context_score != 0
+    ):
+        assert os.path.exists(
+            params.context_file
+        ), f"context_file does not exist, given path: {params.context_file}"
+        contexts = []
+        for line in open(params.context_file).readlines():
+            contexts.append(bpe_model.encode(line.strip()))
+        context_graph = ContextGraph(params.context_score)
+        context_graph.build(contexts)
+
     logging.info("About to create model")
     model = get_model(params)
@@ -967,6 +1164,9 @@ def main():
             bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )

         save_asr_output(
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                        Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -15,11 +16,16 @@
 # limitations under the License.

 import logging
-from typing import Dict, List, Optional, Union
+from dataclasses import dataclass, field
+from multiprocessing.pool import Pool
+from typing import Dict, List, Optional, Tuple, Union

 import k2
 import torch

+from icefall.context_graph import ContextGraph, ContextState
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
+from icefall.lm_wrapper import LmScorer
 from icefall.utils import add_eos, add_sos, get_texts

 DEFAULT_LM_SCALE = [
@@ -1497,3 +1503,667 @@ def ctc_greedy_search(
     hyps = [h[h != blank_id].tolist() for h in hyps]

     return hyps
+
+
+@dataclass
+class Hypothesis:
+    # The predicted tokens so far.
+    # Newly predicted tokens are appended to `ys`.
+    ys: List[int] = field(default_factory=list)
+
+    # The log prob of ys that ends with blank token.
+    # It contains only one entry.
+    log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)
+
+    # The log prob of ys that ends with non blank token.
+    # It contains only one entry.
+    log_prob_non_blank: torch.Tensor = torch.tensor(
+        [float("-inf")], dtype=torch.float32
+    )
+
+    # timestamp[i] is the frame index after subsampling
+    # on which ys[i] is decoded
+    timestamp: List[int] = field(default_factory=list)
+
+    # The lm score of ys
+    # May contain external LM score (including LODR score) and contextual biasing score
+    # It contains only one entry
+    lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)
+
+    # the lm log_probs for next token given the history ys
+    # The number of elements should be equal to vocabulary size.
+    lm_log_probs: Optional[torch.Tensor] = None
+
+    # the RNNLM states (h and c in LSTM)
+    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+    # LODR (N-gram LM) state
+    LODR_state: Optional[NgramLmStateCost] = None
+
+    # N-gram LM state
+    Ngram_state: Optional[NgramLmStateCost] = None
+
+    # Context graph state
+    context_state: Optional[ContextState] = None
+
+    # This is the total score of current path, acoustic plus external LM score.
+    @property
+    def tot_score(self) -> torch.Tensor:
+        return self.log_prob + self.lm_score
+
+    # This is only the probability from model output (i.e. external LM score not included).
+    @property
+    def log_prob(self) -> torch.Tensor:
+        return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
+
+    @property
+    def key(self) -> tuple:
+        """Return a tuple representation of self.ys"""
+        return tuple(self.ys)
+
+    def clone(self) -> "Hypothesis":
+        return Hypothesis(
+            ys=self.ys,
+            log_prob_blank=self.log_prob_blank,
+            log_prob_non_blank=self.log_prob_non_blank,
+            timestamp=self.timestamp,
+            lm_log_probs=self.lm_log_probs,
+            lm_score=self.lm_score,
+            state=self.state,
+            LODR_state=self.LODR_state,
+            Ngram_state=self.Ngram_state,
+            context_state=self.context_state,
+        )
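As an illustration (not from the diff), a tiny numeric sketch of why `log_prob` is a `logaddexp`: the total probability of a prefix is the sum of the probabilities of its blank-ending and non-blank-ending paths.

```python
import torch

log_p_blank = torch.log(torch.tensor([0.6]))      # paths for this prefix that end in blank
log_p_non_blank = torch.log(torch.tensor([0.3]))  # paths that end in a non-blank token
total = torch.logaddexp(log_p_blank, log_p_non_blank)
print(total.exp())  # tensor([0.9000]) -- what Hypothesis.log_prob reports for this prefix
```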
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None:
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self) -> Dict[tuple, Hypothesis]:
+        return self._data
+
+    def add(self, hyp: Hypothesis) -> None:
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, its probability is updated using
+        `log-sum-exp` with the existing one.
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob_blank, hyp.log_prob_blank, out=old_hyp.log_prob_blank
+            )
+            torch.logaddexp(
+                old_hyp.log_prob_non_blank,
+                hyp.log_prob_non_blank,
+                out=old_hyp.log_prob_non_blank,
+            )
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `tot_score`.
+
+        Args:
+          length_norm:
+            If True, the `tot_score` of a hypothesis is normalized by the
+            number of tokens in it.
+        Returns:
+          Return the hypothesis that has the largest `tot_score`.
+        """
+        if length_norm:
+            return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
+        else:
+            return max(self._data.values(), key=lambda hyp: hyp.tot_score)
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Caution:
+          `self` is modified **in-place**.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
+        """Remove all Hypotheses whose tot_score is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          with `tot_score` being greater than the given `threshold`.
+        """
+        ans = HypothesisList()
+        for _, hyp in self._data.items():
+            if hyp.tot_score > threshold:
+                ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
+        """Return the top-k hypotheses.
+
+        Args:
+          length_norm:
+            If True, the `tot_score` of a hypothesis is normalized by the
+            number of tokens in it.
+        """
+        hyps = list(self._data.items())
+
+        if length_norm:
+            hyps = sorted(
+                hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True
+            )[:k]
+        else:
+            hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k]
+
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: tuple):
+        return key in self._data
+
+    def __getitem__(self, key: tuple):
+        return self._data[key]
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = []
+        for key in self:
+            s.append(key)
+        return ", ".join(str(s))
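A small usage sketch (assuming icefall with this change is importable) showing that adding the same prefix twice merges scores with log-sum-exp instead of duplicating the entry, and that `topk` prunes to the beam:

```python
import torch
from icefall.decode import Hypothesis, HypothesisList

hyps = HypothesisList()
hyps.add(Hypothesis(ys=[3, 7], log_prob_blank=torch.tensor([-1.0])))
hyps.add(Hypothesis(ys=[3, 7], log_prob_blank=torch.tensor([-2.0])))  # merged, not duplicated
hyps.add(Hypothesis(ys=[3, 9], log_prob_blank=torch.tensor([-5.0])))
print(len(hyps))                    # 2 distinct prefixes
print(hyps.get_most_probable().ys)  # [3, 7]
print(len(hyps.topk(1)))            # 1 hypothesis kept after pruning
```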
+
+
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypotheses for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+    # to get exclusive sum later.
+    num_hyps.insert(0, 0)
+
+    num_hyps = torch.tensor(num_hyps)
+    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+    ans = k2.ragged.create_ragged_shape2(
+        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+    )
+    return ans
+
+
+def _step_worker(
+    log_probs: torch.Tensor,
+    indexes: torch.Tensor,
+    B: HypothesisList,
+    beam: int = 4,
+    blank_id: int = 0,
+    nnlm_scale: float = 0,
+    LODR_lm_scale: float = 0,
+    context_graph: Optional[ContextGraph] = None,
+) -> HypothesisList:
+    """The worker to decode one step.
+    Args:
+      log_probs:
+        topk log_probs of current step (i.e. the kept tokens of first pass pruning),
+        the shape is (beam,)
+      indexes:
+        The indexes of the topk values above, the shape is (beam,)
+      B:
+        An instance of HypothesisList containing the kept hypotheses.
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      nnlm_scale:
+        The scale of the neural network LM.
+      LODR_lm_scale:
+        The scale of the LODR_lm.
+      context_graph:
+        A ContextGraph instance containing contextual phrases.
+    Return:
+      Returns the updated HypothesisList.
+    """
+    A = list(B)
+    B = HypothesisList()
+    for h in range(len(A)):
+        hyp = A[h]
+        for k in range(log_probs.size(0)):
+            log_prob, index = log_probs[k], indexes[k]
+            new_token = index.item()
+            update_prefix = False
+            new_hyp = hyp.clone()
+            if new_token == blank_id:
+                # Case 0: *a + ε => *a
+                #         *aε + ε => *a
+                # Prefix does not change, update log_prob of blank
+                new_hyp.log_prob_non_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
+                new_hyp.log_prob_blank = hyp.log_prob + log_prob
+                B.add(new_hyp)
+            elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token:
+                # Case 1: *a + a => *a
+                # Prefix does not change, update log_prob of non_blank
+                new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
+                B.add(new_hyp)
+
+                # Case 2: *aε + a => *aa
+                # Prefix changes, update log_prob of non_blank from the blank-ending paths
+                new_hyp = hyp.clone()
+                # Caution: DO NOT use append, as clone is shallow copy
+                new_hyp.ys = hyp.ys + [new_token]
+                new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
+                update_prefix = True
+            else:
+                # Case 3: *a + b => *ab, *aε + b => *ab
+                # Prefix changes, update log_prob of non_blank
+                # Caution: DO NOT use append, as clone is shallow copy
+                new_hyp.ys = hyp.ys + [new_token]
+                new_hyp.log_prob_non_blank = hyp.log_prob + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
+                update_prefix = True
+
+            if update_prefix:
+                lm_score = hyp.lm_score
+                if hyp.lm_log_probs is not None:
+                    lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale
+                    new_hyp.lm_log_probs = None
+
+                if context_graph is not None and hyp.context_state is not None:
+                    (
+                        context_score,
+                        new_context_state,
+                        matched_state,
+                    ) = context_graph.forward_one_step(hyp.context_state, new_token)
+                    lm_score = lm_score + context_score
+                    new_hyp.context_state = new_context_state
+
+                if hyp.LODR_state is not None:
+                    state_cost = hyp.LODR_state.forward_one_step(new_token)
+                    # calculate the score of the latest token
+                    current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score
+                    assert current_ngram_score <= 0.0, (
+                        state_cost.lm_score,
+                        hyp.LODR_state.lm_score,
+                    )
+                    lm_score = lm_score + LODR_lm_scale * current_ngram_score
+                    new_hyp.LODR_state = state_cost
+
+                new_hyp.lm_score = lm_score
+                B.add(new_hyp)
+    B = B.topk(beam)
+    return B
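A toy single-step sketch (assuming icefall with this change is importable) over a 3-symbol vocabulary, illustrating the case analysis above: blank keeps the prefix, a non-blank token extends it, and the beam prunes the rest.

```python
import torch
from icefall.decode import Hypothesis, HypothesisList, _step_worker

# One frame of (already top-k-pruned) log-probs for {0: blank, 1: "a", 2: "b"}.
log_probs = torch.log(torch.tensor([0.6, 0.3, 0.1]))
indexes = torch.tensor([0, 1, 2])

B = HypothesisList()
B.add(Hypothesis())  # start from the empty prefix
B = _step_worker(log_probs, indexes, B, beam=2, blank_id=0)
for hyp in B:
    print(hyp.ys, hyp.log_prob.exp().item())
# [] keeps probability 0.6 (blank), [1] gets 0.3, and [2] (0.1) is pruned by the beam of 2.
```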
+
+
+def _sequence_worker(
+    topk_values: torch.Tensor,
+    topk_indexes: torch.Tensor,
+    B: HypothesisList,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 4,
+    blank_id: int = 0,
+) -> HypothesisList:
+    """The worker to decode one sequence.
+    Args:
+      topk_values:
+        topk log_probs of model output (i.e. the kept tokens of first pass pruning),
+        the shape is (T, beam)
+      topk_indexes:
+        The indexes of the topk_values above, the shape is (T, beam)
+      B:
+        An instance of HypothesisList containing the kept hypotheses.
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling, the shape is (B,)
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+    Return:
+      Returns the updated HypothesisList.
+    """
+    B.add(Hypothesis())
+    for j in range(encoder_out_lens):
+        log_probs, indexes = topk_values[j], topk_indexes[j]
+        B = _step_worker(log_probs, indexes, B, beam, blank_id)
+    return B
+
+
+def ctc_prefix_beam_search(
+    ctc_output: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 4,
+    blank_id: int = 0,
+    process_pool: Optional[Pool] = None,
+    return_nbest: Optional[bool] = False,
+) -> Union[List[List[int]], List[HypothesisList]]:
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks".
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability), the shape is (B, T, V)
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling, the shape is (B,)
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      process_pool:
+        The process pool for parallel decoding; if not provided, it will use all
+        your cpu cores by default.
+      return_nbest:
+        If true, return a list of HypothesisList; otherwise return a list of lists
+        of decoded token ids.
+    """
+    batch_size, num_frames, vocab_size = ctc_output.shape
+
+    # TODO: using a larger beam for first pass pruning
+    topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
+    topk_values = topk_values.cpu()
+    topk_indexes = topk_indexes.cpu()
+
+    B = [HypothesisList() for _ in range(batch_size)]
+
+    pool = Pool() if process_pool is None else process_pool
+    arguments = []
+    for i in range(batch_size):
+        arguments.append(
+            (
+                topk_values[i],
+                topk_indexes[i],
+                B[i],
+                encoder_out_lens[i].item(),
+                beam,
+                blank_id,
+            )
+        )
+    async_results = pool.starmap_async(_sequence_worker, arguments)
+    B = list(async_results.get())
+    if process_pool is None:
+        pool.close()
+        pool.join()
+    if return_nbest:
+        return B
+    else:
+        best_hyps = [b.get_most_probable() for b in B]
+        return [hyp.ys for hyp in best_hyps]
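A minimal usage sketch with random scores (assuming icefall with this change is importable). In the recipe, `ctc_output` is the CTC head's log-probabilities of shape (batch, frames, vocab); the `__main__` guard matters because the default path spawns a multiprocessing pool.

```python
import torch
from icefall.decode import ctc_prefix_beam_search

if __name__ == "__main__":
    batch, frames, vocab = 2, 20, 500
    ctc_output = torch.randn(batch, frames, vocab).log_softmax(dim=-1)  # dummy scores
    encoder_out_lens = torch.tensor([20, 17])

    hyps = ctc_prefix_beam_search(
        ctc_output=ctc_output,
        encoder_out_lens=encoder_out_lens,
        beam=4,
        blank_id=0,
    )
    print(hyps)  # one list of token ids per utterance
```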
+
+
+def ctc_prefix_beam_search_shallow_fussion(
+    ctc_output: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 4,
+    blank_id: int = 0,
+    LODR_lm: Optional[NgramLm] = None,
+    LODR_lm_scale: Optional[float] = 0,
+    NNLM: Optional[LmScorer] = None,
+    context_graph: Optional[ContextGraph] = None,
+) -> List[List[int]]:
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    neural language model shallow fusion; it also supports contextual
+    biasing with a given grammar.
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability), the shape is (B, T, V)
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling, the shape is (B,)
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      LODR_lm:
+        A low order n-gram LM, whose score will be subtracted during shallow fusion
+      LODR_lm_scale:
+        The scale of the LODR_lm
+      NNLM:
+        A neural net LM, e.g. an RNNLM or transformer LM
+      context_graph:
+        A ContextGraph instance containing contextual phrases.
+    Return:
+      Returns a list of lists of decoded token ids.
+    """
+    batch_size, num_frames, vocab_size = ctc_output.shape
+    # TODO: using a larger beam for first pass pruning
+    topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
+    topk_values = topk_values.cpu()
+    topk_indexes = topk_indexes.cpu()
+    encoder_out_lens = encoder_out_lens.tolist()
+    device = ctc_output.device
+
+    nnlm_scale = 0
+    init_scores = None
+    init_states = None
+    if NNLM is not None:
+        nnlm_scale = NNLM.lm_scale
+        sos_id = getattr(NNLM, "sos_id", 1)
+        # get initial lm score and lm state by scoring the "sos" token
+        sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+        lens = torch.tensor([1]).to(device)
+        init_scores, init_states = NNLM.score_token(sos_token, lens)
+        init_scores, init_states = init_scores.cpu(), (
+            init_states[0].cpu(),
+            init_states[1].cpu(),
+        )
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[],
+                log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
+                log_prob_blank=torch.zeros(1, dtype=torch.float32),
+                lm_score=torch.zeros(1, dtype=torch.float32),
+                state=init_states,
+                lm_log_probs=None if init_scores is None else init_scores.reshape(-1),
+                LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm),
+                context_state=None if context_graph is None else context_graph.root,
+            )
+        )
+    for j in range(num_frames):
+        for i in range(batch_size):
+            if j < encoder_out_lens[i]:
+                log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
+                B[i] = _step_worker(
+                    log_probs=log_probs,
+                    indexes=indexes,
+                    B=B[i],
+                    beam=beam,
+                    blank_id=blank_id,
+                    nnlm_scale=nnlm_scale,
+                    LODR_lm_scale=LODR_lm_scale,
+                    context_graph=context_graph,
+                )
+        if NNLM is None:
+            continue
+        # update lm_log_probs
+        token_list = []  # a list of list
+        hs = []
+        cs = []
+        indexes = []  # (batch_idx, key)
+        for batch_idx, hyps in enumerate(B):
+            for hyp in hyps:
+                if hyp.lm_log_probs is None:  # those hyps whose prefix changed
+                    if NNLM.lm_type == "rnn":
+                        token_list.append([hyp.ys[-1]])
+                        # store the LSTM states
+                        hs.append(hyp.state[0])
+                        cs.append(hyp.state[1])
+                    else:
+                        # for transformer LM
+                        token_list.append([sos_id] + hyp.ys[:])
+                    indexes.append((batch_idx, hyp.key))
+        if len(token_list) != 0:
+            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
+            if NNLM.lm_type == "rnn":
+                tokens_to_score = (
+                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
+                )
+                hs = torch.cat(hs, dim=1).to(device)
+                cs = torch.cat(cs, dim=1).to(device)
+                state = (hs, cs)
+            else:
+                # for transformer LM
+                tokens_list = [torch.tensor(tokens) for tokens in token_list]
+                tokens_to_score = (
+                    torch.nn.utils.rnn.pad_sequence(
+                        tokens_list, batch_first=True, padding_value=0.0
+                    )
+                    .to(device)
+                    .to(torch.int64)
+                )
+                state = None
+
+            scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state)
+            scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
+            assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
+            for i in range(scores.size(0)):
+                batch_idx, key = indexes[i]
+                B[batch_idx][key].lm_log_probs = scores[i]
+                if NNLM.lm_type == "rnn":
+                    state = (
+                        lm_states[0][:, i, :].unsqueeze(1),
+                        lm_states[1][:, i, :].unsqueeze(1),
+                    )
+                    B[batch_idx][key].state = state
+
+    # finalize context_state; if the matched contexts do not reach a final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        for hyps in B:
+            for hyp in hyps:
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                hyp.lm_score += context_score
+                hyp.context_state = new_context_state
+
+    best_hyps = [b.get_most_probable() for b in B]
+    return [hyp.ys for hyp in best_hyps]
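A hedged, self-contained sketch of hotword biasing (assuming icefall with this change is importable): the token ids below are made up and would normally come from `bpe_model.encode` of each phrase, exactly as `ctc_decode.py` does above.

```python
import torch
from icefall.context_graph import ContextGraph
from icefall.decode import ctc_prefix_beam_search_shallow_fussion

hotword_token_ids = [[25, 7, 13], [512, 9]]  # hypothetical BPE ids of two hotword phrases
context_graph = ContextGraph(2.0)            # bonus score per matched token
context_graph.build(hotword_token_ids)

ctc_output = torch.randn(1, 30, 600).log_softmax(dim=-1)  # dummy (B, T, V) log-probs
encoder_out_lens = torch.tensor([30])

hyps = ctc_prefix_beam_search_shallow_fussion(
    ctc_output=ctc_output,
    encoder_out_lens=encoder_out_lens,
    context_graph=context_graph,
)
print(hyps[0])  # token ids of the best (biased) hypothesis for the single utterance
```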
+
+
+def ctc_prefix_beam_search_attention_decoder_rescoring(
+    ctc_output: torch.Tensor,
+    attention_decoder: torch.nn.Module,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 8,
+    blank_id: int = 0,
+    attention_scale: Optional[float] = None,
+    process_pool: Optional[Pool] = None,
+):
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    attention decoder rescoring.
+
+    Args:
+      ctc_output:
+        The output of the ctc head (log probability), the shape is (B, T, V)
+      attention_decoder:
+        The attention decoder.
+      encoder_out:
+        The output of the encoder, the shape is (B, T, D)
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling, the shape is (B,)
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      attention_scale:
+        The scale of the attention decoder score; if not provided it will search in
+        a default list (see the code below).
+      process_pool:
+        The process pool for parallel decoding; if not provided, it will use all
+        your cpu cores by default.
+    """
+    # List[HypothesisList]
+    nbest = ctc_prefix_beam_search(
+        ctc_output=ctc_output,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        blank_id=blank_id,
+        return_nbest=True,
+    )
+
+    device = ctc_output.device
+
+    hyp_shape = get_hyps_shape(nbest).to(device)
+    hyp_to_utt_map = hyp_shape.row_ids(1).to(torch.long)
+    # the shape of encoder_out is (N, T, C), so we use axis=0 here
+    expanded_encoder_out = encoder_out.index_select(0, hyp_to_utt_map)
+    expanded_encoder_out_lens = encoder_out_lens.index_select(0, hyp_to_utt_map)
+
+    nbest = [list(x) for x in nbest]
+    token_ids = []
+    scores = []
+    for hyps in nbest:
+        for hyp in hyps:
+            token_ids.append(hyp.ys)
+            scores.append(hyp.log_prob.reshape(1))
+    scores = torch.cat(scores).to(device)
+
+    nll = attention_decoder.nll(
+        encoder_out=expanded_encoder_out,
+        encoder_out_lens=expanded_encoder_out_lens,
+        token_ids=token_ids,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+        attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+
+    start_indexes = hyp_shape.row_splits(1)[0:-1]
+    for a_scale in attention_scale_list:
+        tot_scores = scores + a_scale * attention_scores
+        ragged_tot_scores = k2.RaggedTensor(hyp_shape, tot_scores)
+        max_indexes = ragged_tot_scores.argmax()
+        max_indexes = max_indexes - start_indexes
+        max_indexes = max_indexes.cpu()
+        best_path = [nbest[i][max_indexes[i]].ys for i in range(len(max_indexes))]
+        key = f"attention_scale_{a_scale}"
+        ans[key] = best_path
+    return ans
@@ -19,8 +19,10 @@

 import argparse
 import collections
+import json
 import logging
 import os
+import pathlib
 import random
 import re
 import subprocess
@@ -180,6 +182,15 @@ class AttributeDict(dict):
             return
         raise AttributeError(f"No such attribute '{key}'")

+    def __str__(self, indent: int = 2):
+        tmp = {}
+        for k, v in self.items():
+            # PosixPath is not JSON serializable
+            if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
+                v = str(v)
+            tmp[k] = v
+        return json.dumps(tmp, indent=indent, sort_keys=True)
+
+
 def encode_supervisions(
     supervisions: dict,
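A short sketch (assuming icefall is importable) of what the new `__str__` buys: params now print as sorted JSON, with `Path` and `torch.device` values stringified first.

```python
import pathlib
import torch
from icefall.utils import AttributeDict

params = AttributeDict(
    {"exp_dir": pathlib.Path("zipformer/exp"), "device": torch.device("cpu"), "beam": 4}
)
print(params)
# {
#   "beam": 4,
#   "device": "cpu",
#   "exp_dir": "zipformer/exp"
# }
```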