mirror of https://github.com/k2-fsa/icefall.git
Add prefix search.
parent 273e5fb2f3
commit 3372149493
@@ -118,7 +118,7 @@ class Hypothesis:


 class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
         """
         Args:
           data:
@@ -130,11 +130,10 @@ class HypothesisList(object):
         self._data = data

     @property
-    def data(self):
+    def data(self) -> Dict[str, Hypothesis]:
         return self._data

-    # def add(self, ys: List[int], log_prob: float):
-    def add(self, hyp: Hypothesis):
+    def add(self, hyp: Hypothesis) -> None:
         """Add a Hypothesis to `self`.

         If `hyp` already exists in `self`, its probability is updated using
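
The docstring is cut off at the hunk boundary, but the update it refers to is a log-space sum of probabilities. Below is a minimal sketch of that idea, assuming the update is the stable log-add-exp merge used elsewhere in this diff; the `Hyp` class and `merge_add` helper are illustrative stand-ins, not this file's actual definitions:

    import numpy as np

    class Hyp:
        """Illustrative stand-in for the Hypothesis dataclass."""

        def __init__(self, ys, log_prob):
            self.ys = ys
            self.log_prob = log_prob

        @property
        def key(self):
            return "_".join(map(str, self.ys))

    def merge_add(data, hyp):
        # Same token sequence already present: merge in log space.
        # np.logaddexp(a, b) == log(exp(a) + exp(b)), computed stably.
        if hyp.key in data:
            data[hyp.key].log_prob = np.logaddexp(
                data[hyp.key].log_prob, hyp.log_prob
            )
        else:
            data[hyp.key] = hyp
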
@@ -159,7 +158,8 @@ class HypothesisList(object):
           length_norm:
             If True, the `log_prob` of a hypothesis is normalized by the
             number of tokens in it.
-
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
             return max(
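
To see what `length_norm` buys, compare raw and normalized selection over plain (ys, log_prob) pairs; longer hypotheses accumulate more negative log-probability, and dividing by length compensates for that:

    hyps = [([3, 7], -4.0), ([3, 7, 9, 2], -6.0)]

    best_raw = max(hyps, key=lambda h: h[1])
    # ([3, 7], -4.0): the shorter hypothesis wins on raw score

    best_norm = max(hyps, key=lambda h: h[1] / len(h[0]))
    # ([3, 7, 9, 2], -6.0): -6/4 = -1.5 beats -4/2 = -2.0
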
@@ -171,6 +171,9 @@ class HypothesisList(object):
     def remove(self, hyp: Hypothesis) -> None:
         """Remove a given hypothesis.

+        Caution:
+          `self` is modified **in-place**.
+
         Args:
           hyp:
             The hypothesis to be removed from `self`.
@@ -189,10 +192,10 @@ class HypothesisList(object):

         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
-          that have `log_prob` being greater than the given `threshold`.
+          with `log_prob` greater than the given `threshold`.
         """
         ans = HypothesisList()
-        for key, hyp in self._data.items():
+        for _, hyp in self._data.items():
             if hyp.log_prob > threshold:
                 ans.add(hyp)  # shallow copy
         return ans
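
A plausible use of `filter` (not shown in this diff) is beam pruning: keep only hypotheses whose score is within `beam` of the current best. The names below mirror this file's API; the pruning policy itself is an assumption:

    best = hyp_list.get_most_probable()
    pruned = hyp_list.filter(best.log_prob - beam)
    # `pruned` is a new HypothesisList; `hyp_list` is left unchanged.
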
@@ -222,6 +225,171 @@ class HypothesisList(object):
         return ", ".join(s)


+def run_decoder(
+    ys: List[int],
+    model: Transducer,
+    decoder_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Run the neural decoder model for a given hypothesis.
+
+    Args:
+      ys:
+        The current hypothesis.
+      model:
+        The transducer model.
+      decoder_cache:
+        Cache to save computations.
+    Returns:
+      Return a 1-D tensor of shape (decoder_out_dim,) containing
+      the output of `model.decoder`.
+    """
+    context_size = model.decoder.context_size
+    key = "_".join(map(str, ys[-context_size:]))
+    if key in decoder_cache:
+        return decoder_cache[key]
+
+    device = model.device
+
+    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
+        1, context_size
+    )
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    decoder_cache[key] = decoder_out
+
+    return decoder_out
+
+
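
The decoder cache works because `model.decoder` looks only at the last `context_size` tokens, so any two hypotheses that share that suffix share one forward pass. A self-contained illustration of the key scheme (`context_size = 2` is an assumed value here, and a string stands in for the output tensor):

    context_size = 2
    decoder_cache = {}

    for ys in ([5, 9, 3, 7], [2, 3, 7]):  # both end in ..., 3, 7
        key = "_".join(map(str, ys[-context_size:]))  # "3_7" for both
        if key not in decoder_cache:
            decoder_cache[key] = f"decoder_out for {key}"

    print(len(decoder_cache))  # 1: the second hypothesis hit the cache
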
+def run_joiner(
+    key: str,
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    decoder_out: torch.Tensor,
+    encoder_out_len: torch.Tensor,
+    decoder_out_len: torch.Tensor,
+    joint_cache: Dict[str, torch.Tensor],
+):
+    """Run the joint network given outputs from the encoder and decoder.
+
+    Args:
+      key:
+        A key into the `joint_cache`.
+      model:
+        The transducer model.
+      encoder_out:
+        A tensor of shape (1, 1, encoder_out_dim).
+      decoder_out:
+        A tensor of shape (1, 1, decoder_out_dim).
+      encoder_out_len:
+        A tensor with value [1].
+      decoder_out_len:
+        A tensor with value [1].
+      joint_cache:
+        A dict to save computations.
+    Returns:
+      Return a tensor from the output of log-softmax.
+      Its shape is (vocab_size,).
+    """
+    if key in joint_cache:
+        return joint_cache[key]
+
+    logits = model.joiner(
+        encoder_out,
+        decoder_out,
+        encoder_out_len,
+        decoder_out_len,
+    )
+
+    # TODO(fangjun): Scale the blank posterior
+    log_prob = logits.log_softmax(dim=-1)
+    # log_prob is (1, 1, 1, vocab_size)
+
+    log_prob = log_prob.squeeze()
+    # Now log_prob is (vocab_size,)
+
+    joint_cache[key] = log_prob
+
+    return log_prob
+
+
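
Unlike the decoder cache, the joiner cache key must encode two things: the decoder context and the encoder frame index `t` (appended as `-t-{t}` by the callers below), since the same context at a different frame produces different logits. A tiny illustration of how such keys behave:

    def joint_key(ys, context_size, t):
        return "_".join(map(str, ys[-context_size:])) + f"-t-{t}"

    assert joint_key([3, 7], 2, 0) != joint_key([3, 7], 2, 1)     # same context, new frame
    assert joint_key([5, 3, 7], 2, 4) == joint_key([3, 7], 2, 4)  # same suffix, same frame
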
+def start_with(a: List[int], b: List[int]) -> bool:
+    """Check whether a starts with b, i.e., whether a[:len(b)] == b."""
+    a_len = len(a)
+    b_len = len(b)
+    if b_len > a_len:
+        return False
+
+    for i in range(b_len):
+        if a[i] != b[i]:
+            return False
+    return True
+
+
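
Behavior at a glance, plus an equivalent slicing one-liner that is handy as a cross-check:

    assert start_with([1, 2, 3, 4], [1, 2]) is True
    assert start_with([1, 2, 3], [1, 3]) is False
    assert start_with([1], [1, 2]) is False  # b longer than a

    def start_with_alt(a, b):
        return a[: len(b)] == b
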
+# The implementation uses
+# espnet/nets/beam_search_transducer.py#L168
+# as a reference
+def prefix_search(
+    hyp_list: HypothesisList,
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    decoder_cache: Dict[str, torch.Tensor],
+    joint_cache: Dict[str, torch.Tensor],
+    t: int,
+):
+    hyps = list(hyp_list)
+
+    # sort hyps by number of tokens in descending order
+    hyps = sorted(hyps, key=lambda h: len(h.ys), reverse=True)
+
+    prefix_alpha = 1
+
+    device = model.device
+    context_size = model.decoder.context_size
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for i, cur_hyp in enumerate(hyps[:-1]):
+        cur_hyp_len = len(cur_hyp.ys)
+        for next_hyp in hyps[i + 1 :]:  # noqa
+            if not start_with(cur_hyp.ys, next_hyp.ys):
+                continue
+
+            next_hyp_len = len(next_hyp.ys)
+
+            # at this point, next_hyp.ys is a prefix of cur_hyp.ys
+            len_diff = cur_hyp_len - next_hyp_len
+            if len_diff > prefix_alpha:
+                continue
+            offset = next_hyp_len
+
+            total_log_prob = next_hyp.log_prob
+            # use `j` here so the outer enumerate index `i` is not shadowed
+            for j in range(len_diff):
+                pos = offset + j
+                ys = cur_hyp.ys[:pos]
+
+                decoder_out = run_decoder(
+                    ys=ys, model=model, decoder_cache=decoder_cache
+                )
+
+                key = "_".join(map(str, ys[-context_size:]))
+                key += f"-t-{t}"
+                log_prob = run_joiner(
+                    key=key,
+                    model=model,
+                    encoder_out=encoder_out,
+                    decoder_out=decoder_out,
+                    encoder_out_len=encoder_out_len,
+                    decoder_out_len=decoder_out_len,
+                    joint_cache=joint_cache,
+                )
+                total_log_prob += log_prob[cur_hyp.ys[pos]].item()
+            cur_hyp.log_prob = np.logaddexp(total_log_prob, cur_hyp.log_prob)
+
+    ans = {hyp.key: hyp for hyp in hyps}
+    return HypothesisList(ans)
+
+
 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
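
The merge at the end of `prefix_search`'s inner loop adds, in log space, the probability of reaching `cur_hyp` by extending its prefix `next_hyp` token by token at frame `t`. A worked numeric example with assumed scores (suppose cur_hyp.ys = [3, 7, 9] and next_hyp.ys = [3, 7], so len_diff = 1):

    import numpy as np

    next_log_prob = -2.0  # log P(next_hyp)
    step_log_prob = -1.0  # log P(emit 9 | [3, 7], frame t), as run_joiner would return
    total_log_prob = next_log_prob + step_log_prob  # -3.0: reach cur_hyp via the prefix

    cur_log_prob = -2.5  # cur_hyp's existing score
    merged = np.logaddexp(total_log_prob, cur_log_prob)
    print(round(merged, 2))  # -2.03: the two paths' probabilities add up
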
@@ -281,43 +449,34 @@ def beam_search(

     joint_cache: Dict[str, torch.Tensor] = {}

-    # TODO(fangjun): Implement prefix search to update the `log_prob`
-    # of hypotheses in A
+    A = prefix_search(
+        hyp_list=A,
+        model=model,
+        encoder_out=current_encoder_out,
+        decoder_cache=decoder_cache,
+        joint_cache=joint_cache,
+        t=t,
+    )
+
     while True:
         y_star = A.get_most_probable()
         A.remove(y_star)

-        cached_key = y_star.key
+        decoder_out = run_decoder(
+            ys=y_star.ys, model=model, decoder_cache=decoder_cache
+        )

-        if cached_key not in decoder_cache:
-            decoder_input = torch.tensor(
-                [y_star.ys[-context_size:]], device=device
-            ).reshape(1, context_size)
-
-            decoder_out = model.decoder(decoder_input, need_pad=False)
-            decoder_cache[cached_key] = decoder_out
-        else:
-            decoder_out = decoder_cache[cached_key]
-
-        cached_key += f"-t-{t}"
-        if cached_key not in joint_cache:
-            logits = model.joiner(
-                current_encoder_out,
-                decoder_out,
-                encoder_out_len,
-                decoder_out_len,
-            )
-
-            # TODO(fangjun): Scale the blank posterior
-            log_prob = logits.log_softmax(dim=-1)
-            # log_prob is (1, 1, 1, vocab_size)
-            log_prob = log_prob.squeeze()
-            # Now log_prob is (vocab_size,)
-            joint_cache[cached_key] = log_prob
-        else:
-            log_prob = joint_cache[cached_key]
+        key = "_".join(map(str, y_star.ys[-context_size:]))
+        key += f"-t-{t}"
+        log_prob = run_joiner(
+            key=key,
+            model=model,
+            encoder_out=current_encoder_out,
+            decoder_out=decoder_out,
+            encoder_out_len=encoder_out_len,
+            decoder_out_len=decoder_out_len,
+            joint_cache=joint_cache,
+        )

         # First, process the blank symbol
         skip_log_prob = log_prob[blank_id]
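
For orientation, a hedged sketch of the surrounding control flow in `beam_search` (not part of this diff; the slicing of `encoder_out` is an assumption): at each frame `t`, the hypothesis set `A` is rescored with `prefix_search` before the expansion loop runs.

    for t in range(T):  # T encoder frames
        current_encoder_out = encoder_out[:, t : t + 1, :]
        A = prefix_search(
            hyp_list=A,
            model=model,
            encoder_out=current_encoder_out,
            decoder_cache=decoder_cache,
            joint_cache=joint_cache,
            t=t,
        )
        # ... the `while True:` expansion loop from the hunk above follows here.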