Mirror of https://github.com/k2-fsa/icefall.git
Add backoff arcs to the start state to handle OOV words.
This commit is contained in:
parent 5af23efa69 · commit adb54aea91
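The idea behind this commit, in brief: during shallow fusion each hypothesis tracks the LG states it may occupy, and an out-of-vocabulary word (or an unseen n-gram) can leave a hypothesis with no matching out-going arc, so all of its states die. Adding an arc that backs off to the start state (state 0) with a fixed penalty guarantees at least one live state. Below is a minimal, self-contained sketch of that fallback using a toy dict-based automaton instead of k2; all names and scores are illustrative, not from the patch, and it simplifies by backing off only when every state has died (the patch always adds a backoff hypothesis for non-start states):

from typing import Dict, Tuple

# Toy n-gram automaton: (state, token) -> (next_state, arc_score)
ARCS: Dict[Tuple[int, int], Tuple[int, float]] = {
    (0, 5): (1, -0.7),  # start state -> state 1 on token 5
    (1, 9): (2, -1.2),  # state 1 -> state 2 on token 9
}

BACKOFF_TO_START = -20.0  # fixed penalty, mirroring the patch's -20 cost


def advance(states: Dict[int, float], token: int) -> Dict[int, float]:
    """Advance every live LG state on `token`, keeping the best score
    per destination state."""
    nxt: Dict[int, float] = {}
    for s, score in states.items():
        if (s, token) in ARCS:
            ns, arc_score = ARCS[(s, token)]
            nxt[ns] = max(nxt.get(ns, float("-inf")), score + arc_score)
    if not nxt:
        # OOV / unseen n-gram: every state died. Back off to the start
        # state with a fixed penalty instead of dropping the hypothesis.
        nxt[0] = max(states.values()) + BACKOFF_TO_START
    return nxt


states = {0: 0.0}
states = advance(states, 5)   # follows a real arc: {1: -0.7}
states = advance(states, 42)  # OOV: falls back to {0: -20.7}
print(states)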
beam_search.py (egs/librispeech/ASR/transducer_stateless; filename inferred from the imports)
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
 from typing import Dict, List, Optional
 
 import k2
 import torch
 from model import Transducer
 from shallow_fusion import shallow_fusion
+from utils import Hypothesis, HypothesisList
 
 
 def greedy_search(
@@ -103,153 +103,6 @@ def greedy_search(
     return hyp
 
 
-@dataclass
-class Hypothesis:
-    # The predicted tokens so far.
-    # Newly predicted tokens are appended to `ys`.
-    ys: List[int]
-
-    # The log prob of ys.
-    # It contains only one entry.
-    log_prob: torch.Tensor
-
-    # Used for shallow fusion
-    # The key of the dict is a state index into LG
-    # while the corresponding value is the LM score
-    # reaching this state.
-    # Note: The value tensor contains only a single entry
-    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = (None,)
-
-    @property
-    def key(self) -> str:
-        """Return a string representation of self.ys"""
-        return "_".join(map(str, self.ys))
-
-
-class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
-        """
-        Args:
-          data:
-            A dict of Hypotheses. Its key is its `value.key`.
-        """
-        if data is None:
-            self._data = {}
-        else:
-            self._data = data
-
-    @property
-    def data(self) -> Dict[str, Hypothesis]:
-        return self._data
-
-    def add(self, hyp: Hypothesis) -> None:
-        """Add a Hypothesis to `self`.
-
-        If `hyp` already exists in `self`, its probability is updated using
-        `log-sum-exp` with the existed one.
-
-        Args:
-          hyp:
-            The hypothesis to be added.
-        """
-        key = hyp.key
-        if key in self:
-            old_hyp = self._data[key]  # shallow copy
-            if True:
-                old_hyp.log_prob = torch.logaddexp(
-                    old_hyp.log_prob, hyp.log_prob
-                )
-            else:
-                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
-
-            if hyp.ngram_state_and_scores is not None:
-                for state, score in hyp.ngram_state_and_scores.items():
-                    if (
-                        state in old_hyp.ngram_state_and_scores
-                        and score > old_hyp.ngram_state_and_scores[state]
-                    ):
-                        old_hyp.ngram_state_and_scores[state] = score
-                    else:
-                        old_hyp.ngram_state_and_scores[state] = score
-        else:
-            self._data[key] = hyp
-
-    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
-        """Get the most probable hypothesis, i.e., the one with
-        the largest `log_prob`.
-
-        Args:
-          length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
-            number of tokens in it.
-        Returns:
-          Return the hypothesis that has the largest `log_prob`.
-        """
-        if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
-        else:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
-
-    def remove(self, hyp: Hypothesis) -> None:
-        """Remove a given hypothesis.
-
-        Caution:
-          `self` is modified **in-place**.
-
-        Args:
-          hyp:
-            The hypothesis to be removed from `self`.
-            Note: It must be contained in `self`. Otherwise,
-            an exception is raised.
-        """
-        key = hyp.key
-        assert key in self, f"{key} does not exist"
-        del self._data[key]
-
-    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
-        """Remove all Hypotheses whose log_prob is less than threshold.
-
-        Caution:
-          `self` is not modified. Instead, a new HypothesisList is returned.
-
-        Returns:
-          Return a new HypothesisList containing all hypotheses from `self`
-          with `log_prob` being greater than the given `threshold`.
-        """
-        ans = HypothesisList()
-        for _, hyp in self._data.items():
-            if hyp.log_prob > threshold:
-                ans.add(hyp)  # shallow copy
-        return ans
-
-    def topk(self, k: int) -> "HypothesisList":
-        """Return the top-k hypothesis."""
-        hyps = list(self._data.items())
-
-        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
-
-        ans = HypothesisList(dict(hyps))
-        return ans
-
-    def __contains__(self, key: str):
-        return key in self._data
-
-    def __iter__(self):
-        return iter(self._data.values())
-
-    def __len__(self) -> int:
-        return len(self._data)
-
-    def __str__(self) -> str:
-        s = []
-        for key in self:
-            s.append(key)
-        return ", ".join(s)
-
-
 def run_decoder(
     ys: List[int],
     model: Transducer,
@@ -341,6 +194,113 @@ def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+) -> List[int]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder. Only N==1 is
+        supported for now.
+      beam:
+        Beam size.
+    Returns:
+      Return the decoded result.
+    """
+
+    assert encoder_out.ndim == 3
+
+    # support only batch_size == 1 for now
+    assert encoder_out.size(0) == 1, encoder_out.size(0)
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    device = model.device
+
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    T = encoder_out.size(1)
+
+    B = HypothesisList()
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for t in range(T):
+        # fmt: off
+        current_encoder_out = encoder_out[:, t:t+1, :]
+        # current_encoder_out is of shape (1, 1, encoder_out_dim)
+        # fmt: on
+        A = list(B)
+        B = HypothesisList()
+
+        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+        # ys_log_probs is of shape (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+        log_probs = logits.log_softmax(dim=-1)
+
+        log_probs.add_(ys_log_probs)
+
+        log_probs = log_probs.reshape(-1)
+        topk_log_probs, topk_indexes = log_probs.topk(beam)
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes = topk_hyp_indexes.tolist()
+        topk_token_indexes = topk_token_indexes.tolist()
+
+        for i in range(len(topk_hyp_indexes)):
+            hyp = A[topk_hyp_indexes[i]]
+            new_ys = hyp.ys[:]
+            new_token = topk_token_indexes[i]
+            if new_token != blank_id:
+                new_ys.append(new_token)
+            new_log_prob = topk_log_probs[i]
+            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            B.add(new_hyp)
+
+    best_hyp = B.get_most_probable(length_norm=True)
+    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
+
+    return ys
+
+
+def modified_beam_search_with_shallow_fusion(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
     LG: Optional[k2.Fsa] = None,
     ngram_lm_scale: float = 0.1,
 ) -> List[int]:
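A note on the index arithmetic in the function body above: the (hypothesis, token) scores are flattened to a single vector before `topk`, so integer division and modulo by the vocabulary size recover, for each of the `beam` best entries, which hypothesis it extends and with which token. A small self-contained illustration (shapes and values are made up):

import torch

num_hyps, vocab_size, beam = 3, 5, 4
log_probs = torch.randn(num_hyps, vocab_size)

flat = log_probs.reshape(-1)  # shape: (num_hyps * vocab_size,)
topk_log_probs, topk_indexes = flat.topk(beam)

topk_hyp_indexes = topk_indexes // vocab_size   # row: which hypothesis in `A`
topk_token_indexes = topk_indexes % vocab_size  # column: which token extends it

# Sanity check: the decomposed indexes address the same entries.
assert torch.equal(
    log_probs[topk_hyp_indexes, topk_token_indexes], topk_log_probs
)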
@@ -408,7 +368,14 @@ def modified_beam_search(
         A = list(B)
         B = HypothesisList()
 
-        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+        # ys_log_probs contains both AM scores and LM scores
+        ys_log_probs = torch.cat(
+            [
+                hyp.log_prob.reshape(1, 1)
+                + ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
+                for hyp in A
+            ]
+        )
         # ys_log_probs is of shape (num_hyps, 1)
 
         decoder_input = torch.tensor(
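The added comment notes that `ys_log_probs` now mixes both score sources: per hypothesis it is the acoustic log-prob plus `ngram_lm_scale` times the best LM score over the hypothesis's surviving LG states. A tiny sketch of that combination with made-up values:

import torch

ngram_lm_scale = 0.1
log_prob = torch.tensor([-5.2])       # acoustic part only
ngram_state_and_scores = {            # LM part, keyed by LG state
    3: torch.tensor([-1.0]),
    7: torch.tensor([-0.4]),
}

total = log_prob + ngram_lm_scale * max(ngram_state_and_scores.values())
print(total)  # tensor([-5.2400])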
@@ -434,62 +401,52 @@ def modified_beam_search(
         # logits is of shape (num_hyps, vocab_size)
         log_probs = logits.log_softmax(dim=-1)
 
-        log_probs.add_(ys_log_probs)
+        tot_log_probs = log_probs + ys_log_probs
 
-        log_probs = log_probs.reshape(-1)
-        topk_log_probs, topk_indexes = log_probs.topk(beam)
+        _, topk_indexes = tot_log_probs.reshape(-1).topk(beam)
+        topk_log_probs = log_probs.reshape(-1)[topk_indexes]
 
         # topk_hyp_indexes are indexes into `A`
         topk_hyp_indexes = topk_indexes // logits.size(-1)
         topk_token_indexes = topk_indexes % logits.size(-1)
 
-        topk_hyp_indexes = topk_hyp_indexes.tolist()
-        topk_token_indexes = topk_token_indexes.tolist()
+        topk_hyp_indexes, indexes = torch.sort(topk_hyp_indexes)
+        topk_token_indexes = topk_token_indexes[indexes]
+        topk_log_probs = topk_log_probs[indexes]
 
-        # import pdb
-        #
-        # pdb.set_trace()
-        for i in range(len(topk_hyp_indexes)):
-            hyp = A[topk_hyp_indexes[i]]
-            new_ys = hyp.ys[:]
-            new_token = topk_token_indexes[i]
-            if new_token != blank_id:
-                new_ys.append(new_token)
-            else:
-                ngram_state_and_scores = hyp.ngram_state_and_scores
-
-            new_log_prob = topk_log_probs[i]
-
-            if enable_shallow_fusion and new_token != blank_id:
-                ngram_state_and_scores = shallow_fusion(
-                    LG,
-                    new_token,
-                    hyp.ngram_state_and_scores,
-                    vocab_size,
-                )
-                if len(ngram_state_and_scores) == 0:
-                    continue
-                max_ngram_score = max(ngram_state_and_scores.values())
-                new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score
-
-            # TODO: Get the maximum scores in ngram_state_and_scores
-            # and add it to new_log_prob
-
-            new_hyp = Hypothesis(
-                ys=new_ys,
-                log_prob=new_log_prob,
-                ngram_state_and_scores=ngram_state_and_scores,
-            )
-            B.add(new_hyp)
-        if len(B) == 0:
-            import logging
-
-            logging.info("\n*****\nEmpty states!\n***\n")
-            for h in A:
-                B.add(h)
-
-        best_hyp = B.get_most_probable(length_norm=True)
+        shape = k2.ragged.create_ragged_shape2(
+            row_ids=topk_hyp_indexes.to(torch.int32),
+            cached_tot_size=topk_hyp_indexes.numel(),
+        )
+        blank_log_probs = log_probs[topk_hyp_indexes, 0]
+
+        row_splits = shape.row_splits(1).tolist()
+        num_rows = len(row_splits) - 1
+        for i in range(num_rows):
+            start = row_splits[i]
+            end = row_splits[i + 1]
+            if start >= end:
+                # Discard A[i] as other hyps have higher log_probs
+                continue
+            tokens = topk_token_indexes[start:end]
+
+            hyps = shallow_fusion(
+                LG,
+                A[i],
+                tokens,
+                topk_log_probs[start:end],
+                vocab_size,
+                blank_log_probs[i],
+            )
+            for h in hyps:
+                B.add(h)
+
+        if len(B) > beam:
+            B = B.topk(beam, ngram_lm_scale=ngram_lm_scale)
+
+    best_hyp = B.get_most_probable(
+        length_norm=True, ngram_lm_scale=ngram_lm_scale
+    )
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
 
     return ys
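The rewritten loop groups the surviving topk entries by the hypothesis they extend: after `torch.sort`, entries of the same hypothesis are contiguous, and the ragged shape's `row_splits` gives each hypothesis its [start, end) slice (an empty slice means the hypothesis received no topk entry and is dropped). The same grouping can be computed without k2; a sketch assuming `row_ids` is already sorted, as it is after the sort above:

import torch

num_rows = 3                          # number of hypotheses in `A`
row_ids = torch.tensor([0, 0, 2, 2])  # sorted hypothesis index per topk entry

# row_splits[i] = number of entries whose row id is < i
boundaries = torch.arange(num_rows + 1)
row_splits = torch.searchsorted(row_ids, boundaries).tolist()
# -> [0, 2, 2, 4]: hyp 0 owns entries [0, 2), hyp 1 owns none,
#    hyp 2 owns entries [2, 4)

for i in range(num_rows):
    start, end = row_splits[i], row_splits[i + 1]
    if start >= end:
        continue  # hypothesis i got no topk entries
    print(f"hyp {i}: entries {list(range(start, end))}")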
decode.py (egs/librispeech/ASR/transducer_stateless; filename inferred)
@@ -47,7 +47,12 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    greedy_search,
+    modified_beam_search,
+    modified_beam_search_with_shallow_fusion,
+)
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -283,23 +288,25 @@ def decode_one_batch(
             beam=params.beam_size,
         )
     elif params.decoding_method == "modified_beam_search":
-        hyp = modified_beam_search(
-            model=model,
-            encoder_out=encoder_out_i,
-            beam=params.beam_size,
-            LG=LG,
-            ngram_lm_scale=params.ngram_lm_scale,
-        )
+        if LG is None:
+            hyp = modified_beam_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
+            )
+        else:
+            hyp = modified_beam_search_with_shallow_fusion(
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
+                LG=LG,
+                ngram_lm_scale=params.ngram_lm_scale,
+            )
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
         )
     hyps.append(sp.decode(hyp).split())
-    s = "\n"
-    for h in hyps:
-        s += " ".join(h)
-        s += "\n"
-    logging.info(s)
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -349,8 +356,6 @@ def decode_dataset(
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
-        if batch_idx > 10:
-            break
        texts = batch["supervisions"]["text"]
 
        hyps_dict = decode_one_batch(
@@ -464,6 +469,9 @@ def main():
         ), "--LG is used only when --decoding_method=modified_beam_search"
         logging.info(f"Loading LG from {params.LG}")
         LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
+        logging.info(
+            f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}"
+        )
         logging.info(f"LG properties: {LG.properties_str}")
         logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
         # If LG is created by local/compile_lg.py, then it should be epsilon
@@ -517,8 +525,6 @@ def main():
     test_dl = [test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
-        if test_set == "test-other":
-            break
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
shallow_fusion.py (egs/librispeech/ASR/transducer_stateless; filename inferred)
@@ -19,39 +19,51 @@ from typing import Dict
 import k2
 import torch
 import copy
+from utils import Hypothesis, HypothesisList
 
 
 def shallow_fusion(
     LG: k2.Fsa,
-    token: int,
-    state_and_scores: Dict[int, torch.Tensor],
+    hyp: Hypothesis,
+    tokens: torch.Tensor,
+    log_probs: torch.Tensor,
     vocab_size: int,
-) -> Dict[int, torch.Tensor]:
+    blank_log_prob: torch.Tensor,
+) -> HypothesisList:
     """
     Args:
      LG:
       An n-gram. It should be arc sorted, deterministic, and epsilon free.
+       It contains disambig IDs and back-off arcs.
-     token:
-       The input token ID.
-     state_and_scores:
-       The keys contain the current state we are in and the
-       values are the LM log_prob for reaching the corresponding
-       states from the start state.
+     hyp:
+       The current hypothesis.
+     tokens:
+       The possible tokens that will be expanded from the given `hyp`.
+       It is a 1-D tensor of dtype torch.int32.
+     log_probs:
+       It contains the acoustic log probabilities of each path that
+       is extended from `hyp.ys` with `tokens`.
+       log_probs.shape == tokens.shape.
     vocab_size:
       Vocabulary size, including the blank symbol. We assume that
      token IDs >= vocab_size are disambig IDs (including the backoff
      symbol #0).
+     blank_log_prob:
+       The log_prob for the blank token at this frame. It is from
+       the output of the joiner.
    Returns:
-      Return a new state_and_scores.
+      Return new hypotheses by extending the given `hyp` with tokens in the
+      given `tokens`.
    """
 
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()
 
-    state_and_scores = copy.deepcopy(state_and_scores)
+    state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores)
 
    current_states = list(state_and_scores.keys())
 
-    # Process out-going arcs with label being disambig tokens and #0
+    # Process out-going arcs with label equal to disambig tokens or #0
    while len(current_states) > 0:
        s = current_states.pop()
        labels_begin = row_splits[s]
@@ -84,7 +96,9 @@ def shallow_fusion(
            )
 
    current_states = list(state_and_scores.keys())
-    ans = dict()
+    ans = HypothesisList()
+
+    device = log_probs.device
    for s in current_states:
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
@@ -93,17 +107,47 @@ def shallow_fusion(
        if labels[-1] == -1:
            labels = labels[:-1]
 
-        pos = torch.searchsorted(labels, token)
-        if pos >= labels.numel() or labels[pos] != token:
-            continue
+        if s != 0:
+            # We add a backoff arc to the start state. Otherwise,
+            # all active states may die due to an out-of-vocabulary word.
+            new_hyp = Hypothesis(
+                ys=hyp.ys[:],
+                log_prob=hyp.log_prob + blank_log_prob,
+                ngram_state_and_scores={
+                    # -20 is the cost on the backoff arc to the start state.
+                    # As LG.scores.min() is about -16.6, we choose -20 here.
+                    # You may need to tune this value.
+                    0: torch.full((1,), -20, dtype=torch.float32, device=device)
+                },
+            )
+            ans.add(new_hyp)
 
-        idx = labels_begin + pos
-        next_state = arcs[idx][1].item()
-        score = LG.scores[idx] + state_and_scores[s]
+        pos = torch.searchsorted(labels, tokens)
+        for i in range(pos.numel()):
+            if tokens[i] == 0:
+                # blank ID
+                new_hyp = Hypothesis(
+                    ys=hyp.ys[:],
+                    log_prob=hyp.log_prob + log_probs[i],
+                    ngram_state_and_scores=hyp.ngram_state_and_scores,
+                )
+                ans.add(new_hyp)
+                continue
+            elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
+                # No out-going arcs from this state have labels
+                # equal to tokens[i]
+                continue
 
-        if next_state not in ans:
-            ans[next_state] = score
-        else:
-            ans[next_state] = max(score, ans[next_state])
+            # Found one arc
+            idx = labels_begin + pos[i]
+            next_state = arcs[idx][1].item()
+            score = LG.scores[idx] + state_and_scores[s]
+            new_hyp = Hypothesis(
+                ys=hyp.ys + [tokens[i].item()],
+                log_prob=hyp.log_prob + log_probs[i],
+                ngram_state_and_scores={next_state: score},
+            )
+            ans.add(new_hyp)
 
    return ans
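Because LG is arc sorted, the labels on a state's out-going arcs are non-decreasing, so `torch.searchsorted` finds a candidate arc for every token in one batched binary search; a token matches only if the label at the returned position equals it. A standalone illustration of that lookup (labels and tokens are made up):

import torch

# Sorted labels on the out-going arcs of one LG state.
labels = torch.tensor([2, 5, 9, 13], dtype=torch.int32)
# Candidate tokens with which to extend the hypothesis.
tokens = torch.tensor([5, 7, 13], dtype=torch.int32)

pos = torch.searchsorted(labels, tokens)  # insertion points: [1, 2, 3]
for i in range(pos.numel()):
    if pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
        print(f"token {tokens[i]}: no matching arc (hypothesis would back off)")
    else:
        print(f"token {tokens[i]}: matches arc at offset {pos[i].item()}")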
egs/librispeech/ASR/transducer_stateless/utils.py (new file, 219 lines)
@@ -0,0 +1,219 @@
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+
+
+@dataclass
+class Hypothesis:
+    # The predicted tokens so far.
+    # Newly predicted tokens are appended to `ys`.
+    ys: List[int]
+
+    # The log prob of ys.
+    # It contains only one entry.
+    # Note: It contains only the acoustic part.
+    log_prob: torch.Tensor
+
+    # Used for shallow fusion.
+    # The key of the dict is a state index into LG
+    # while the corresponding value is the LM score
+    # reaching this state from the start state.
+    # Note: The value tensor contains only a single entry
+    # and it contains only the LM part.
+    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
+
+    @property
+    def key(self) -> str:
+        """Return a string representation of self.ys"""
+        return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self) -> Dict[str, Hypothesis]:
+        return self._data
+
+    def add(self, hyp: Hypothesis) -> None:
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, its probability is merged with
+        the existing one (by default keeping the `max`; see the switch
+        below for `log-sum-exp`).
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]  # shallow copy
+            # Flip to True to merge acoustic scores with log-sum-exp
+            # instead of max.
+            if False:
+                old_hyp.log_prob = torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob
+                )
+            else:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+
+            if hyp.ngram_state_and_scores is not None:
+                for state, score in hyp.ngram_state_and_scores.items():
+                    # Keep the larger LM score for each state.
+                    if state in old_hyp.ngram_state_and_scores:
+                        old_hyp.ngram_state_and_scores[state] = max(
+                            score, old_hyp.ngram_state_and_scores[state]
+                        )
+                    else:
+                        old_hyp.ngram_state_and_scores[state] = score
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(
+        self, length_norm: bool = False, ngram_lm_scale: Optional[float] = None
+    ) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+          ngram_lm_scale:
+            If not None, it specifies the scale applied to the LM score.
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
+        """
+        if length_norm:
+            if ngram_lm_scale is None:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob / len(hyp.ys),
+                )
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: (
+                        hyp.log_prob
+                        + ngram_lm_scale
+                        * max(hyp.ngram_state_and_scores.values())
+                    )
+                    / len(hyp.ys),
+                )
+        else:
+            if ngram_lm_scale is None:
+                return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob
+                    + ngram_lm_scale * max(hyp.ngram_state_and_scores.values()),
+                )
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Caution:
+          `self` is modified **in-place**.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(
+        self, threshold: torch.Tensor, ngram_lm_scale: Optional[float] = None
+    ) -> "HypothesisList":
+        """Remove all Hypotheses whose log_prob is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Args:
+          threshold:
+            Hypotheses with log_prob less than this value are removed.
+          ngram_lm_scale:
+            If not None, it specifies the scale applied to the LM score.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          with `log_prob` being greater than the given `threshold`.
+        """
+        ans = HypothesisList()
+        if ngram_lm_scale is None:
+            for _, hyp in self._data.items():
+                if hyp.log_prob > threshold:
+                    ans.add(hyp)  # shallow copy
+        else:
+            for _, hyp in self._data.items():
+                if (
+                    hyp.log_prob
+                    + ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
+                    > threshold
+                ):
+                    ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(
+        self, k: int, ngram_lm_scale: Optional[float] = None
+    ) -> "HypothesisList":
+        """Return the top-k hypotheses."""
+        hyps = list(self._data.items())
+
+        if ngram_lm_scale is None:
+            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+        else:
+            hyps = sorted(
+                hyps,
+                key=lambda h: h[1].log_prob
+                + ngram_lm_scale * max(h[1].ngram_state_and_scores.values()),
+                reverse=True,
+            )[:k]
+
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: str):
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = []
+        for key in self._data:  # iterate keys, not Hypothesis objects
+            s.append(key)
+        return ", ".join(s)