Add modified beam search.

Fangjun Kuang 2022-01-28 23:51:04 +08:00
parent c3b3123b27
commit 77261bc575
2 changed files with 70 additions and 43 deletions

View File

@@ -17,7 +17,6 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional

-import numpy as np
 import torch

 from model import Transducer
@@ -108,8 +107,11 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int]

-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    # TODO(fangjun): It was a float before. We need to change its usage
+    # in greedy_search and beam_search.
+    log_prob: torch.Tensor

     @property
     def key(self) -> str:
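
For orientation, here is a minimal standalone sketch of the updated Hypothesis. Only `ys` and `log_prob` appear in this hunk; the body of `key` is an assumption based on how HypothesisList merges entries in the next hunk:

import torch
from dataclasses import dataclass
from typing import List


@dataclass
class Hypothesis:
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]
    # The log prob of ys, now a one-entry tensor instead of a float,
    # so it can be updated in place when two hypotheses are merged.
    log_prob: torch.Tensor

    @property
    def key(self) -> str:
        # Assumed: hypotheses are keyed by their token sequence, so two
        # paths that produce the same tokens collapse into one entry.
        return "_".join(map(str, self.ys))
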
@@ -145,8 +147,10 @@ class HypothesisList(object):
         """
         key = hyp.key
         if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
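
The switch from np.logaddexp to torch.logaddexp with out= is what the tensor-valued log_prob buys: merging two copies of the same hypothesis sums their probabilities in log space and writes the result into the stored tensor, with no attribute rebinding. A small sketch of the arithmetic, with made-up values:

import torch

a = torch.tensor([-1.2])  # log prob of the stored hypothesis
b = torch.tensor([-2.3])  # log prob of the incoming duplicate

# log(exp(a) + exp(b)), computed stably and written into `a` in place.
torch.logaddexp(a, b, out=a)
print(a)  # tensor([-0.9127]), approximately
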
@@ -348,47 +352,70 @@ def modified_beam_search(
     T = encoder_out.size(1)

     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )

     encoder_out_len = torch.tensor([1])
     decoder_out_len = torch.tensor([1])

-    decoder_cache: Dict[str, torch.Tensor] = {}
-
     for t in range(T):
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
+        # current_encoder_out is of shape (1, 1, encoder_out_dim)
         # fmt: on
-        A = B
+        A = list(B)
         B = HypothesisList()

-        joint_cache: Dict[str, torch.Tensor] = {}
+        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+        # ys_log_probs is of shape (num_hyps, 1)

-        for hyp in A:
-            decoder_out = run_decoder(
-                ys=hyp.ys, model=model, decoder_cache=decoder_cache
-            )
-
-            key = "_".join(map(str, hyp.ys[-context_size:]))
-            key += f"-t-{t}"
-
-            log_prob = run_joiner(
-                key=key,
-                model=model,
-                encoder_out=current_encoder_out,
-                decoder_out=decoder_out,
-                encoder_out_len=encoder_out_len,
-                decoder_out_len=decoder_out_len,
-                joint_cache=joint_cache,
-            )
-
-            log_prob = log_prob.cpu().tolist()
-
-            for i, v in enumerate(log_prob):
-                if i == blank_id:
-                    # Use [:] to make a copy
-                    new_ys = hyp.ys[:]
-                else:
-                    new_ys = hyp.ys + [i]
-                new_hyp = Hypothesis(ys=new_ys, log_prob=hyp.log_prob + v)
-                B.add(new_hyp)
-        B = B.topk(beam)
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)
+        log_probs.add_(ys_log_probs)
+        log_probs = log_probs.reshape(-1)
+
+        topk_log_probs, topk_indexes = log_probs.topk(beam)
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes = topk_hyp_indexes.tolist()
+        topk_token_indexes = topk_token_indexes.tolist()
+
+        for i in range(len(topk_hyp_indexes)):
+            hyp = A[topk_hyp_indexes[i]]
+            new_ys = hyp.ys[:]
+            new_token = topk_token_indexes[i]
+            if new_token != blank_id:
+                new_ys.append(new_token)
+            new_log_prob = topk_log_probs[i]
+            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            B.add(new_hyp)

     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
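
The core trick in this rewrite is scoring every (hypothesis, token) pair in one batch per frame instead of looping over hypotheses in Python: adding ys_log_probs makes each row carry its hypothesis's accumulated score, so a single topk over the flattened (num_hyps, vocab_size) matrix selects the best beam extensions globally. A standalone sketch of the index arithmetic, with made-up scores:

import torch

vocab_size = 5
beam = 3

# (num_hyps, vocab_size) joint log probs, already offset by each
# hypothesis's accumulated log prob (ys_log_probs in the diff).
log_probs = torch.tensor(
    [
        [-0.1, -2.0, -3.0, -0.5, -4.0],  # hypothesis 0
        [-1.0, -0.2, -2.5, -3.5, -0.9],  # hypothesis 1
    ]
)

topk_log_probs, topk_indexes = log_probs.reshape(-1).topk(beam)

# Undo the flattening: row is the source hypothesis, column the token.
topk_hyp_indexes = topk_indexes // vocab_size
topk_token_indexes = topk_indexes % vocab_size

print(topk_hyp_indexes.tolist())    # [0, 1, 0]
print(topk_token_indexes.tolist())  # [0, 1, 3]

Note that B.add can still merge two of the beam new hypotheses when they extend to the same ys (a blank leaves ys unchanged), which is exactly where the torch.logaddexp path in HypothesisList.add is exercised.
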

View File

@@ -75,24 +75,24 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with blank prepended.
+            A 2-D tensor of shape (N, U).
           need_pad:
             True to left pad the input. Should be True during training.
             False to not pad the input. Should be False during inference.
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
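
Besides fixing the `embeding_out` spelling, this hunk documents the need_pad branch that modified_beam_search relies on when it feeds exactly context_size tokens per hypothesis. A sketch of the padding behaviour; a plain Conv1d stands in for the decoder's convolution here, and the sizes are made up:

import torch
import torch.nn.functional as F

context_size = 2
embedding_dim = 4
conv = torch.nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

# (N, embedding_dim, U), i.e. after the permute in forward().
embedding_out = torch.randn(1, embedding_dim, 3)

# Training path (need_pad=True): left pad with context_size - 1 frames
# so the conv produces one output per input position.
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
print(conv(padded).shape)  # torch.Size([1, 4, 3]), same U as the input

# Inference path (need_pad=False): the caller already passes exactly
# context_size tokens, so the conv emits a single output frame.
print(conv(embedding_out[:, :, -context_size:]).shape)  # torch.Size([1, 4, 1])
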