Add modified beam search.

Fangjun Kuang 2022-01-28 23:51:04 +08:00
parent c3b3123b27
commit 77261bc575
2 changed files with 70 additions and 43 deletions

@@ -17,7 +17,6 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional

-import numpy as np
 import torch
 from model import Transducer
@@ -108,8 +107,11 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int]

-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    # TODO(fangjun): It was a float before. We need to change its usage
+    # in greedy_search and beam_search.
+    log_prob: torch.Tensor

     @property
     def key(self) -> str:
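
As a side note, `log_prob` is now a one-element `torch.Tensor` rather than a Python float, so a later merge can update it in place. A minimal sketch of the changed dataclass; the body of `key` is an assumption based on the surrounding code, not part of this diff:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Hypothesis:
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]
    # The log prob of `ys`, kept as a one-element tensor so it can be
    # updated in place with torch.logaddexp(..., out=...).
    log_prob: torch.Tensor

    @property
    def key(self) -> str:
        # Assumed definition: a string id used to detect duplicate
        # token sequences inside HypothesisList.
        return "_".join(map(str, self.ys))


hyp = Hypothesis(ys=[0, 0], log_prob=torch.zeros(1))
print(hyp.key, hyp.log_prob.item())  # prints: 0_0 0.0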
@@ -145,8 +147,10 @@ class HypothesisList(object):
         """
         key = hyp.key
         if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
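
The merge in `add` sums probabilities in the log domain, since logaddexp(log p1, log p2) = log(p1 + p2), and `out=old_hyp.log_prob` writes the result into the stored hypothesis in place, which is why `log_prob` had to become a tensor. A small self-contained check of that identity; the probability values are illustrative:

import math

import torch

# Two hypotheses with the same token sequence, with probabilities
# 0.2 and 0.3.
old_log_prob = torch.tensor([math.log(0.2)])
new_log_prob = torch.tensor([math.log(0.3)])

# In-place log-domain sum, as in HypothesisList.add.
torch.logaddexp(old_log_prob, new_log_prob, out=old_log_prob)

print(old_log_prob.exp().item())  # ~0.5, i.e. 0.2 + 0.3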
@@ -348,47 +352,70 @@ def modified_beam_search(
     T = encoder_out.size(1)

     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )

     encoder_out_len = torch.tensor([1])
     decoder_out_len = torch.tensor([1])

-    decoder_cache: Dict[str, torch.Tensor] = {}
-
     for t in range(T):
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
         # current_encoder_out is of shape (1, 1, encoder_out_dim)
         # fmt: on
-        A = B
+        A = list(B)
         B = HypothesisList()

-        joint_cache: Dict[str, torch.Tensor] = {}
+        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+        # ys_log_probs is of shape (num_hyps, 1)

-        for hyp in A:
-            decoder_out = run_decoder(
-                ys=hyp.ys, model=model, decoder_cache=decoder_cache
-            )
-            key = "_".join(map(str, hyp.ys[-context_size:]))
-            key += f"-t-{t}"
-            log_prob = run_joiner(
-                key=key,
-                model=model,
-                encoder_out=current_encoder_out,
-                decoder_out=decoder_out,
-                encoder_out_len=encoder_out_len,
-                decoder_out_len=decoder_out_len,
-                joint_cache=joint_cache,
-            )
-            log_prob = log_prob.cpu().tolist()
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)

-            for i, v in enumerate(log_prob):
-                if i == blank_id:
-                    # Use [:] to make a copy
-                    new_ys = hyp.ys[:]
-                else:
-                    new_ys = hyp.ys + [i]
-                new_hyp = Hypothesis(ys=new_ys, log_prob=hyp.log_prob + v)
-                B.add(new_hyp)
-        B = B.topk(beam)
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)
+        log_probs.add_(ys_log_probs)
+        log_probs = log_probs.reshape(-1)
+        topk_log_probs, topk_indexes = log_probs.topk(beam)
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes = topk_hyp_indexes.tolist()
+        topk_token_indexes = topk_token_indexes.tolist()
+
+        for i in range(len(topk_hyp_indexes)):
+            hyp = A[topk_hyp_indexes[i]]
+            new_ys = hyp.ys[:]
+            new_token = topk_token_indexes[i]
+            if new_token != blank_id:
+                new_ys.append(new_token)
+            new_log_prob = topk_log_probs[i]
+            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            B.add(new_hyp)

     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
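
The rewrite replaces the per-hypothesis Python loop with one batched decoder/joiner call, then flattens the (num_hyps, vocab_size) score matrix and takes a single global topk; dividing a flat index by vocab_size recovers the source hypothesis and the remainder recovers the token. A standalone sketch of that index arithmetic, with illustrative shapes:

import torch

num_hyps, vocab_size, beam = 3, 5, 4

# Stand-in for the joint log-probs, already combined with each
# hypothesis's own score (log_probs.add_(ys_log_probs) above).
log_probs = torch.randn(num_hyps, vocab_size)

topk_log_probs, topk_indexes = log_probs.reshape(-1).topk(beam)

# Recover the 2-D coordinates from the flattened indexes.
topk_hyp_indexes = topk_indexes // vocab_size   # row: which hypothesis in A
topk_token_indexes = topk_indexes % vocab_size  # column: which token

for h, tok, lp in zip(
    topk_hyp_indexes.tolist(), topk_token_indexes.tolist(), topk_log_probs
):
    assert log_probs[h, tok] == lp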

@@ -75,24 +75,24 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with blank prepended.
+            A 2-D tensor of shape (N, U).
           need_pad:
             True to left pad the input. Should be True during training.
             False to not pad the input. Should be False during inference.
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
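
Besides the `embeding` to `embedding` spelling fix, the docstring change drops "with blank prepended", presumably because `modified_beam_search` now feeds the last `context_size` tokens directly. For reference, a minimal sketch of what `need_pad` controls, with a plain `Conv1d` standing in for the decoder's `self.conv` (its exact configuration is not shown in this diff) and illustrative dimensions:

import torch
import torch.nn.functional as F

context_size = 2
N, U, embedding_dim = 1, 4, 8

# (N, U, embedding_dim) -> (N, embedding_dim, U) for Conv1d.
embedding_out = torch.randn(N, U, embedding_dim).permute(0, 2, 1)

conv = torch.nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

# Training (need_pad=True): left-pad the time axis with context_size - 1
# frames so the causal conv still yields U outputs.
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
print(conv(padded).shape)  # torch.Size([1, 8, 4]) -> U outputs

# Inference (need_pad=False): the caller passes exactly context_size
# frames and the conv yields a single output, no padding needed.
print(conv(embedding_out[:, :, :context_size]).shape)  # torch.Size([1, 8, 1])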