Add prefix search.

Fangjun Kuang 2022-01-18 13:23:45 +08:00
parent 273e5fb2f3
commit 3372149493

@@ -118,7 +118,7 @@ class Hypothesis:
 class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
         """
         Args:
           data:
@@ -130,11 +130,10 @@ class HypothesisList(object):
         self._data = data

     @property
-    def data(self):
+    def data(self) -> Dict[str, Hypothesis]:
         return self._data

-    # def add(self, ys: List[int], log_prob: float):
-    def add(self, hyp: Hypothesis):
+    def add(self, hyp: Hypothesis) -> None:
         """Add a Hypothesis to `self`.

         If `hyp` already exists in `self`, its probability is updated using
@@ -159,7 +158,8 @@ class HypothesisList(object):
           length_norm:
             If True, the `log_prob` of a hypothesis is normalized by the
             number of tokens in it.
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
             return max(
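
For illustration (not part of the commit): the effect of `length_norm` in the hunk above can be reproduced with a tiny standalone sketch using toy numbers. Without normalization the shorter hypothesis wins on raw `log_prob`; dividing by the number of tokens lets the longer hypothesis win:

    import math

    # Toy hypotheses with hypothetical scores, for illustration only.
    short = {"ys": [3, 5], "log_prob": math.log(0.10)}
    long_ = {"ys": [3, 5, 7, 9], "log_prob": math.log(0.06)}

    best_raw = max([short, long_], key=lambda h: h["log_prob"])
    best_norm = max([short, long_], key=lambda h: h["log_prob"] / len(h["ys"]))

    print(best_raw["ys"])   # [3, 5]        raw log_prob favors the short one
    print(best_norm["ys"])  # [3, 5, 7, 9]  normalized score favors the long one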
@@ -171,6 +171,9 @@ class HypothesisList(object):
     def remove(self, hyp: Hypothesis) -> None:
         """Remove a given hypothesis.

+        Caution:
+          `self` is modified **in-place**.
+
         Args:
           hyp:
             The hypothesis to be removed from `self`.
@@ -189,10 +192,10 @@ class HypothesisList(object):
         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
-          that have `log_prob` being greater than the given `threshold`.
+          with `log_prob` being greater than the given `threshold`.
         """
         ans = HypothesisList()
-        for key, hyp in self._data.items():
+        for _, hyp in self._data.items():
             if hyp.log_prob > threshold:
                 ans.add(hyp)  # shallow copy
         return ans
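
For illustration (not part of the commit): the `# shallow copy` comment above means the filtered `HypothesisList` shares its `Hypothesis` objects with the original rather than copying them, which matters later because `prefix_search` updates `log_prob` in place. The same aliasing, sketched with a plain dataclass standing in for `Hypothesis`:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Hyp:  # hypothetical stand-in for the commit's Hypothesis
        ys: List[int]
        log_prob: float

    original = [Hyp([3, 5], -1.2), Hyp([3, 5, 7], -4.0)]
    # The comprehension keeps references to the same objects, not copies.
    filtered = [h for h in original if h.log_prob > -2.0]

    filtered[0].log_prob = -0.5
    print(original[0].log_prob)  # -0.5: the same object was updated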
@@ -222,6 +225,171 @@ class HypothesisList(object):
         return ", ".join(s)
+
+
+def run_decoder(
+    ys: List[int],
+    model: Transducer,
+    decoder_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Run the neural decoder model for a given hypothesis.
+
+    Args:
+      ys:
+        The current hypothesis.
+      model:
+        The transducer model.
+      decoder_cache:
+        Cache to save computations.
+    Returns:
+      Return a tensor of shape (1, 1, decoder_out_dim) containing
+      the output of `model.decoder`.
+    """
+    context_size = model.decoder.context_size
+    key = "_".join(map(str, ys[-context_size:]))
+    if key in decoder_cache:
+        return decoder_cache[key]
+
+    device = model.device
+
+    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
+        1, context_size
+    )
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    decoder_cache[key] = decoder_out
+    return decoder_out
+
+
+def run_joiner(
+    key: str,
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    decoder_out: torch.Tensor,
+    encoder_out_len: torch.Tensor,
+    decoder_out_len: torch.Tensor,
+    joint_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Run the joint network given outputs from the encoder and decoder.
+
+    Args:
+      key:
+        A key into the `joint_cache`.
+      model:
+        The transducer model.
+      encoder_out:
+        A tensor of shape (1, 1, encoder_out_dim).
+      decoder_out:
+        A tensor of shape (1, 1, decoder_out_dim).
+      encoder_out_len:
+        A tensor with value [1].
+      decoder_out_len:
+        A tensor with value [1].
+      joint_cache:
+        A dict to save computations.
+    Returns:
+      Return a tensor from the output of log-softmax.
+      Its shape is (vocab_size,).
+    """
+    if key in joint_cache:
+        return joint_cache[key]
+
+    logits = model.joiner(
+        encoder_out,
+        decoder_out,
+        encoder_out_len,
+        decoder_out_len,
+    )
+
+    # TODO(fangjun): Scale the blank posterior
+    log_prob = logits.log_softmax(dim=-1)
+    # log_prob is (1, 1, 1, vocab_size)
+    log_prob = log_prob.squeeze()
+    # Now log_prob is (vocab_size,)
+
+    joint_cache[key] = log_prob
+    return log_prob
+
+
+def start_with(a: List[int], b: List[int]) -> bool:
+    """Check whether `a` starts with `b`, i.e., whether a[:len(b)] == b."""
+    a_len = len(a)
+    b_len = len(b)
+    if b_len > a_len:
+        return False
+
+    for i in range(b_len):
+        if a[i] != b[i]:
+            return False
+    return True
+
+
+# The implementation uses
+# espnet/nets/beam_search_transducer.py#L168
+# as a reference.
+def prefix_search(
+    hyp_list: HypothesisList,
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    decoder_cache: Dict[str, torch.Tensor],
+    joint_cache: Dict[str, torch.Tensor],
+    t: int,
+) -> HypothesisList:
+    hyps = list(hyp_list)
+
+    # Sort hyps by number of tokens in descending order.
+    hyps = sorted(hyps, key=lambda h: len(h.ys), reverse=True)
+
+    prefix_alpha = 1
+    context_size = model.decoder.context_size
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for i, cur_hyp in enumerate(hyps[:-1]):
+        cur_hyp_len = len(cur_hyp.ys)
+        for next_hyp in hyps[i + 1 :]:  # noqa
+            if not start_with(cur_hyp.ys, next_hyp.ys):
+                continue
+            next_hyp_len = len(next_hyp.ys)
+
+            # At this point, next_hyp.ys is a prefix of cur_hyp.ys.
+            len_diff = cur_hyp_len - next_hyp_len
+            if len_diff > prefix_alpha:
+                continue
+
+            offset = next_hyp_len
+            total_log_prob = next_hyp.log_prob
+            for k in range(len_diff):
+                pos = offset + k
+                ys = cur_hyp.ys[:pos]
+                decoder_out = run_decoder(
+                    ys=ys, model=model, decoder_cache=decoder_cache
+                )
+
+                key = "_".join(map(str, ys[-context_size:]))
+                key += f"-t-{t}"
+
+                log_prob = run_joiner(
+                    key=key,
+                    model=model,
+                    encoder_out=encoder_out,
+                    decoder_out=decoder_out,
+                    encoder_out_len=encoder_out_len,
+                    decoder_out_len=decoder_out_len,
+                    joint_cache=joint_cache,
+                )
+                total_log_prob += log_prob[cur_hyp.ys[pos]].item()
+
+            cur_hyp.log_prob = np.logaddexp(total_log_prob, cur_hyp.log_prob)
+
+    ans = {hyp.key: hyp for hyp in hyps}
+    return HypothesisList(ans)
+
+
 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
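
For illustration (not part of the commit): the score update performed by `prefix_search` can be checked in isolation. A minimal sketch with hypothetical numbers, using the same `np.logaddexp` call as the new code: if [3, 5] is kept alongside [3, 5, 7], then [3, 5] can also reach [3, 5, 7] by emitting token 7 at the current frame, so the two paths are merged in log space:

    import numpy as np

    # Hypothetical log-probabilities, for illustration only.
    log_p_prefix = np.log(0.20)  # log_prob of the shorter hypothesis ys=[3, 5]
    log_p_longer = np.log(0.05)  # log_prob of the longer hypothesis ys=[3, 5, 7]
    log_p_emit_7 = np.log(0.30)  # joiner log-prob of emitting 7 after [3, 5] at frame t

    total = log_p_prefix + log_p_emit_7
    merged = np.logaddexp(total, log_p_longer)

    print(merged)  # log(0.20 * 0.30 + 0.05) = log(0.11)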
@@ -281,43 +449,34 @@ def beam_search(
         joint_cache: Dict[str, torch.Tensor] = {}

-        # TODO(fangjun): Implement prefix search to update the `log_prob`
-        # of hypotheses in A
+        A = prefix_search(
+            hyp_list=A,
+            model=model,
+            encoder_out=current_encoder_out,
+            decoder_cache=decoder_cache,
+            joint_cache=joint_cache,
+            t=t,
+        )

         while True:
             y_star = A.get_most_probable()
             A.remove(y_star)

-            cached_key = y_star.key
-
-            if cached_key not in decoder_cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-context_size:]], device=device
-                ).reshape(1, context_size)
-
-                decoder_out = model.decoder(decoder_input, need_pad=False)
-                decoder_cache[cached_key] = decoder_out
-            else:
-                decoder_out = decoder_cache[cached_key]
-
-            cached_key += f"-t-{t}"
-            if cached_key not in joint_cache:
-                logits = model.joiner(
-                    current_encoder_out,
-                    decoder_out,
-                    encoder_out_len,
-                    decoder_out_len,
-                )
-
-                # TODO(fangjun): Ccale the blank posterior
-                log_prob = logits.log_softmax(dim=-1)
-                # log_prob is (1, 1, 1, vocab_size)
-                log_prob = log_prob.squeeze()
-                # Now log_prob is (vocab_size,)
-                joint_cache[cached_key] = log_prob
-            else:
-                log_prob = joint_cache[cached_key]
+            decoder_out = run_decoder(
+                ys=y_star.ys, model=model, decoder_cache=decoder_cache
+            )
+
+            key = "_".join(map(str, y_star.ys[-context_size:]))
+            key += f"-t-{t}"
+            log_prob = run_joiner(
+                key=key,
+                model=model,
+                encoder_out=current_encoder_out,
+                decoder_out=decoder_out,
+                encoder_out_len=encoder_out_len,
+                decoder_out_len=decoder_out_len,
+                joint_cache=joint_cache,
+            )

             # First, process the blank symbol
             skip_log_prob = log_prob[blank_id]
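
For illustration (not part of the commit): the cache keys used by `run_decoder` and `run_joiner` can be sketched standalone; the helper names below are hypothetical and `context_size = 2` is just an example value. The decoder is stateless, so its output depends only on the last `context_size` tokens and its cache key omits the frame index; the joiner also consumes the encoder frame at time `t`, so its key appends `-t-{t}`:

    from typing import List

    context_size = 2  # example value for model.decoder.context_size

    def decoder_key(ys: List[int]) -> str:
        # The stateless decoder sees only the last `context_size` tokens.
        return "_".join(map(str, ys[-context_size:]))

    def joiner_key(ys: List[int], t: int) -> str:
        # The joiner output also depends on the encoder frame at time t.
        return decoder_key(ys) + f"-t-{t}"

    print(decoder_key([0, 3, 5, 7]))      # "5_7"
    print(joiner_key([0, 3, 5, 7], t=4))  # "5_7-t-4"
    print(decoder_key([9, 5, 7]))         # "5_7", shares the cache entry above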