Merge adb54aea91abe211b19ec75eeb422b15a3867405 into 6a091da0b0543befb0492848d3583700c274d111

Fangjun Kuang 2022-03-23 12:52:43 +08:00 committed by GitHub
commit 1ce4349c17
6 changed files with 840 additions and 133 deletions

View File

@ -0,0 +1,144 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_500.
Returns:
An FST representing LG.
"""
tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
assert "#0" in tokens
first_token_disambig_id = tokens["#0"]
logging.info(f"first token disambig ID: {first_token_disambig_id}")
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
del G.aux_labels
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Composing L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}")
del LG.aux_labels
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(
f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(
f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info(
f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info(f"LG properties: {LG.properties_str}")
# Possible properties are:
# "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible" # noqa
logging.info(
"Caution: LG is deterministic and contains disambig symbols!!!"
)
return LG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
out_filename = lang_dir / "LG.pt"
if out_filename.is_file():
logging.info(f"{out_filename} already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir)
logging.info(f"Saving LG to {out_filename}")
torch.save(LG.as_dict(), out_filename)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
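# A minimal check (not part of this file) for reloading the saved LG,
# using only calls that appear elsewhere in this commit:
#
#   d = torch.load("data/lang_bpe_500/LG.pt")
#   LG = k2.Fsa.from_dict(d)
#   print(LG.properties_str, LG.shape, LG.num_arcs)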

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./local/test_compile_lg.py
"""
import os
from pathlib import Path
import k2
import torch
lang_dir = Path("./data/lang_bpe_500")
corpus = "test_compile_lg_corpus.txt"
arpa = "test_compile_lg_3_gram.arpa"
G_fst_txt = "test_compile_lg_3_gram.fst.txt"
def generate_corpus():
s = """HELLO WORLD
HELLOA WORLDER
HELLOA WORLDER HELLO
HELLOA WORLDER"""
with open(corpus, "w") as f:
f.write(s)
def generate_arpa():
cmd = f"""
./shared/make_kn_lm.py \
-ngram-order 3 \
-text {corpus} \
-lm {arpa}
"""
os.system(cmd)
def generate_G():
cmd = f"""
python3 -m kaldilm \
--read-symbol-table="{lang_dir}/words.txt" \
--disambig-symbol='#0' \
{arpa} > {G_fst_txt}
"""
os.system(cmd)
def main():
generate_corpus()
generate_arpa()
generate_G()
with open(G_fst_txt) as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
del G.aux_labels
G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
G.draw("G.pdf", title="G")
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
L = k2.arc_sort(L)
G = k2.arc_sort(G)
LG = k2.compose(L, G)
del LG.aux_labels
LG = k2.determinize(LG)
LG = k2.connect(LG)
LG = k2.arc_sort(LG)
print(LG.properties_str)
LG.draw("LG.pdf", title="LG")
# You can have a look at G.pdf and LG.pdf to get a feel for
# what they look like.
if __name__ == "__main__":
main()

View File

@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional
import k2
import torch
from model import Transducer
from shallow_fusion import shallow_fusion
from utils import Hypothesis, HypothesisList
def greedy_search(
@ -101,132 +103,6 @@ def greedy_search(
return hyp
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int) -> "HypothesisList":
"""Return the top-k hypothesis."""
hyps = list(self._data.items())
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
def run_decoder(
ys: List[int],
model: Transducer,
@ -421,6 +297,161 @@ def modified_beam_search(
return ys
def modified_beam_search_with_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
LG: Optional[k2.Fsa] = None,
ngram_lm_scale: float = 0.1,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Only N==1 is supported for now.
beam:
Beam size.
LG:
Optional. Used for shallow fusion.
ngram_lm_scale:
Used only when LG is not None. The total score of a path is
am_score + ngram_lm_scale * ngram_lm_score
Returns:
Return the decoded result.
"""
enable_shallow_fusion = LG is not None
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
B = HypothesisList()
if enable_shallow_fusion:
ngram_state_and_scores = {
0: torch.zeros(1, dtype=torch.float32, device=device)
}
else:
ngram_state_and_scores = None
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
ngram_state_and_scores=ngram_state_and_scores,
)
)
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
# ys_log_probs contains both AM scores and LM scores
ys_log_probs = torch.cat(
[
hyp.log_prob.reshape(1, 1)
+ ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
for hyp in A
]
)
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
)
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_out is of shape (num_hyps, 1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, -1
)
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len.expand(decoder_out.size(0)),
decoder_out_len.expand(decoder_out.size(0)),
)
vocab_size = logits.size(-1)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
tot_log_probs = log_probs + ys_log_probs
_, topk_indexes = tot_log_probs.reshape(-1).topk(beam)
topk_log_probs = log_probs.reshape(-1)[topk_indexes]
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes, indexes = torch.sort(topk_hyp_indexes)
topk_token_indexes = topk_token_indexes[indexes]
topk_log_probs = topk_log_probs[indexes]
shape = k2.ragged.create_ragged_shape2(
row_ids=topk_hyp_indexes.to(torch.int32),
cached_tot_size=topk_hyp_indexes.numel(),
)
blank_log_probs = log_probs[topk_hyp_indexes, 0]
row_splits = shape.row_splits(1).tolist()
num_rows = len(row_splits) - 1
for i in range(num_rows):
start = row_splits[i]
end = row_splits[i + 1]
if start >= end:
# Discard A[i] as other hyps have higher log_probs
continue
tokens = topk_token_indexes[start:end]
hyps = shallow_fusion(
LG,
A[i],
tokens,
topk_log_probs[start:end],
vocab_size,
blank_log_probs[i],
)
for h in hyps:
B.add(h)
if len(B) > beam:
B = B.topk(beam, ngram_lm_scale=ngram_lm_scale)
best_hyp = B.get_most_probable(
length_norm=True, ngram_lm_scale=ngram_lm_scale
)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
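# Usage sketch (not part of this diff; it mirrors how decode.py below calls
# this function):
#
#   LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt", map_location=device))
#   hyp = modified_beam_search_with_shallow_fusion(
#       model=model,
#       encoder_out=encoder_out_i,
#       beam=params.beam_size,
#       LG=LG,
#       ngram_lm_scale=params.ngram_lm_scale,
#   )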
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,

View File

@ -49,13 +49,19 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from beam_search import (
beam_search,
greedy_search,
modified_beam_search,
modified_beam_search_with_shallow_fusion,
)
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@ -140,6 +146,22 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--LG",
type=str,
help="""Path to LG.pt for shallow fusion.
Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.1,
help="""Used when only --LG is provided.
The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
""",
)
return parser
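# Example decoding command (a sketch; the script name and any checkpoint
# flags are assumptions, while the two new flags come from this diff):
#
#   ./decode.py \
#     --decoding-method modified_beam_search \
#     --LG data/lang_bpe_500/LG.pt \
#     --ngram-lm-scale 0.1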
@ -212,6 +234,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -234,6 +257,9 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
LG:
Optional. Used for shallow fusion. Used only when params.decoding_method
is modified_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -266,12 +292,25 @@ def decode_one_batch(
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
if LG is None:
hyp = modified_beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
hyp = modified_beam_search_with_shallow_fusion(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
LG=LG,
ngram_lm_scale=params.ngram_lm_scale,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@ -289,6 +328,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -301,6 +341,9 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
LG:
Optional. Used for shallow fusion. Used only when params.decoding_method
is modified_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -329,6 +372,7 @@ def decode_dataset(
model=model,
sp=sp,
batch=batch,
LG=LG,
)
for name, hyps in hyps_dict.items():
@ -428,6 +472,25 @@ def main():
logging.info(f"Device: {device}")
if params.LG is not None:
assert (
params.decoding_method == "modified_beam_search"
), "--LG is used only when --decoding_method=modified_beam_search"
logging.info(f"Loading LG from {params.LG}")
LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
logging.info(
f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}"
)
logging.info(f"LG properties: {LG.properties_str}")
logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
# If LG is created by local/compile_lg.py, then it should be epsilon
# free, deterministic, and arc sorted
assert "ArcSorted" in LG.properties_str
assert "EpsilonFree" in LG.properties_str
assert "Deterministic" in LG.properties_str
else:
LG = None
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
@ -476,6 +539,7 @@ def main():
params=params,
model=model,
sp=sp,
LG=LG,
)
save_results(

View File

@ -0,0 +1,153 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import k2
import torch
import copy
from utils import Hypothesis, HypothesisList
def shallow_fusion(
LG: k2.Fsa,
hyp: Hypothesis,
tokens: torch.Tensor,
log_probs: torch.Tensor,
vocab_size: int,
blank_log_prob: torch.Tensor,
) -> HypothesisList:
"""
Args:
LG:
An n-gram. It should be arc sorted, deterministic, and epsilon free.
It contains disambig IDs and back-off arcs.
hyp:
The current hypothesis.
tokens:
The possible tokens that will be expanded from the given `hyp`.
It is a 1-D tensor of dtype torch.int32.
log_probs:
It contains the acoustic log probabilities of each path that
is extended from `hyp.ys` with `tokens`.
log_probs.shape == tokens.shape.
vocab_size:
Vocabulary size, including the blank symbol. We assume that
token IDs >= vocab_size are disambig IDs (including the backoff
symbol #0).
blank_log_prob:
The log_prob for the blank token at this frame. It is from
the output of the joiner.
Returns:
Return new hypotheses by extending the given `hyp` with tokens in the
given `tokens`.
"""
row_splits = LG.arcs.row_splits(1)
arcs = LG.arcs.values()
state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores)
current_states = list(state_and_scores.keys())
# Process out-going arcs whose labels are disambig tokens or #0
while len(current_states) > 0:
s = current_states.pop()
labels_begin = row_splits[s]
labels_end = row_splits[s + 1]
labels = LG.labels[labels_begin:labels_end].contiguous()
for i in reversed(range(labels.numel())):
lab = labels[i]
if lab == -1:
# Note: When sorting arcs, k2 treats arc labels as
# unsigned types
continue
if lab < vocab_size:
# Since LG is arc sorted, we can exit
# the for loop as soon as we have a label
# with ID less than vocab_size
break
# This is a disambig token or #0
idx = labels_begin + i
next_state = arcs[idx][1].item()
score = LG.scores[idx] + state_and_scores[s]
if next_state not in state_and_scores:
state_and_scores[next_state] = score
current_states.append(next_state)
else:
state_and_scores[next_state] = max(
score, state_and_scores[next_state]
)
current_states = list(state_and_scores.keys())
ans = HypothesisList()
device = log_probs.device
for s in current_states:
labels_begin = row_splits[s]
labels_end = row_splits[s + 1]
labels = LG.labels[labels_begin:labels_end].contiguous()
if labels[-1] == -1:
labels = labels[:-1]
if s != 0:
# We add a backoff arc to the start state. Otherwise,
# all active states may die due to an out-of-vocabulary word.
new_hyp = Hypothesis(
ys=hyp.ys[:],
log_prob=hyp.log_prob + blank_log_prob,
ngram_state_and_scores={
# -20 is the cost on the backoff arc to the start state.
# As LG.scores.min() is about -16.6, we choose -20 here.
# You may need to tune this value.
0: torch.full((1,), -20, dtype=torch.float32, device=device)
},
)
ans.add(new_hyp)
pos = torch.searchsorted(labels, tokens)
for i in range(pos.numel()):
if tokens[i] == 0:
# blank ID
new_hyp = Hypothesis(
ys=hyp.ys[:],
log_prob=hyp.log_prob + log_probs[i],
ngram_state_and_scores=hyp.ngram_state_and_scores,
)
ans.add(new_hyp)
continue
elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
# No out-going arc from this state has a label
# equal to tokens[i]
continue
# Found one arc
idx = labels_begin + pos[i]
next_state = arcs[idx][1].item()
score = LG.scores[idx] + state_and_scores[s]
new_hyp = Hypothesis(
ys=hyp.ys + [tokens[i].item()],
log_prob=hyp.log_prob + log_probs[i],
ngram_state_and_scores={next_state: score},
)
ans.add(new_hyp)
return ans

View File

@ -0,0 +1,219 @@
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys.
# It contains only one entry.
# Note: It contains only the acoustic part.
log_prob: torch.Tensor
# Used for shallow fusion
# The key of the dict is a state index into LG
# while the corresponding value is the LM score
# reaching this state from the start state.
# Note: The value tensor contains only a single entry
# and it contains only the LM part.
ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its log_prob is updated by taking
the maximum with the existing one (see the note in the code below).
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
# Note: we keep the larger of the two log_probs here; a log-sum-exp
# update, torch.logaddexp(old_hyp.log_prob, hyp.log_prob), could be
# used instead to accumulate probability mass.
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
if hyp.ngram_state_and_scores is not None:
for state, score in hyp.ngram_state_and_scores.items():
# Keep the better (larger) LM score for each LG state.
if (
state not in old_hyp.ngram_state_and_scores
or score > old_hyp.ngram_state_and_scores[state]
):
old_hyp.ngram_state_and_scores[state] = score
else:
self._data[key] = hyp
def get_most_probable(
self, length_norm: bool = False, ngram_lm_scale: Optional[float] = None
) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
ngram_lm_scale:
If not None, it specifies the scale applied to the LM score.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
if ngram_lm_scale is None:
return max(
self._data.values(),
key=lambda hyp: hyp.log_prob / len(hyp.ys),
)
else:
return max(
self._data.values(),
key=lambda hyp: (
hyp.log_prob
+ ngram_lm_scale
* max(hyp.ngram_state_and_scores.values())
)
/ len(hyp.ys),
)
else:
if ngram_lm_scale is None:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
else:
return max(
self._data.values(),
key=lambda hyp: hyp.log_prob
+ ngram_lm_scale * max(hyp.ngram_state_and_scores.values()),
)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(
self, threshold: torch.Tensor, ngram_lm_scale: Optional[float] = None
) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Args:
threshold:
Hypotheses with log_prob less than this value are removed.
ngram_lm_scale:
If not None, it specifies the scale applied to the LM score.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
if ngram_lm_scale is None:
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
else:
for _, hyp in self._data.items():
if (
hyp.log_prob
+ ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
> threshold
):
ans.add(hyp) # shallow copy
return ans
def topk(
self, k: int, ngram_lm_scale: Optional[float] = None
) -> "HypothesisList":
"""Return the top-k hypothesis."""
hyps = list(self._data.items())
if ngram_lm_scale is None:
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
else:
hyps = sorted(
hyps,
key=lambda h: h[1].log_prob
+ ngram_lm_scale * max(h[1].ngram_state_and_scores.values()),
reverse=True,
)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)