Implement beam search.

Fangjun Kuang 2021-12-22 15:15:58 +08:00
parent afec6b6cae
commit fbc1bc3a6b
3 changed files with 212 additions and 91 deletions

View File

@@ -1,3 +1,17 @@
# Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech.html>
for how to run models in this recipe.
# Transducers
There are various folders in this directory whose names contain `transducer`.
The following table lists the differences among them.
| | Encoder | Decoder |
|------------------------|-----------|--------------------|
| `transducer` | Conformer | LSTM |
| `transducer_stateless` | Conformer | Conv1d + Embedding |
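
For orientation, the `Conv1d + Embedding` decoder listed above (a "stateless" decoder) can be sketched as below. This is a hedged illustration; the class and parameter names here are ours, not the exact ones used in the recipe.

```python
import torch
import torch.nn as nn


class StatelessDecoderSketch(nn.Module):
    """Embedding + Conv1d decoder: instead of carrying an LSTM state,
    it conditions only on the last `context_size` predicted tokens."""

    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # The kernel spans the whole context window, so the output
        # length is 1: one decoder frame per context.
        self.conv = nn.Conv1d(
            embedding_dim, embedding_dim, kernel_size=context_size
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) token IDs
        embedding_out = self.embedding(y).permute(0, 2, 1)  # (N, C, context_size)
        conv_out = self.conv(embedding_out)                 # (N, C, 1)
        return conv_out.permute(0, 2, 1)                    # (N, 1, C)
```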

View File

@@ -3,8 +3,9 @@
### LibriSpeech BPE training results (RNN-T)
#### 2021-12-17
Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
RNN-T + Conformer encoder
RNN-T + Conformer encoder.
The best WER is
@@ -12,7 +13,7 @@ The best WER is
|-----|------------|------------|
| WER | 3.16 | 7.71 |
using `--epoch 26 --avg 12` during decoding with greedy search.
using `--epoch 26 --avg 12` with **greedy search**.
The training command to reproduce the above WER is:

View File

@@ -15,8 +15,9 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import numpy as np
import torch
from model import Transducer
@@ -35,25 +36,35 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
sos = torch.tensor([blank_id] * context_size, device=device).reshape(
1, context_size
)
decoder_out = model.decoder(sos, need_pad=False)
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
t = 0
hyp = [blank_id] * context_size
sym_per_frame = 0
sym_per_utt = 0
# Maximum symbols per utterance.
max_sym_per_utt = 1000
# If more than this number of symbols are decoded at frame t,
# we move on to frame t+1
max_sym_per_frame = 3
# symbols per frame
sym_per_frame = 0
# symbols per utterance decoded so far
sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
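
The remainder of the loop body is elided by this hunk; below is a hedged sketch of the control flow it implements, reusing the variables defined above:

```python
while t < T and sym_per_utt < max_sym_per_utt:
    current_encoder_out = encoder_out[:, t:t + 1, :]  # (1, 1, encoder_dim)
    logits = model.joiner(current_encoder_out, decoder_out)
    y = logits.argmax().item()
    if y != blank_id and sym_per_frame < max_sym_per_frame:
        hyp.append(y)
        sym_per_utt += 1
        sym_per_frame += 1
        # Re-run the decoder on the updated context window.
        decoder_input = torch.tensor(
            [hyp[-context_size:]], device=device
        ).reshape(1, context_size)
        decoder_out = model.decoder(decoder_input, need_pad=False)
    else:
        # Emitted blank, or hit the per-frame cap: move on to frame t+1.
        sym_per_frame = 0
        t += 1
hyp = hyp[context_size:]  # strip the leading blanks used as context
```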
@@ -83,18 +94,125 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
@dataclass
class Hypothesis:
ys: List[int] # the predicted sequences so far
log_prob: float # The log prob of ys
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# Optional decoder state. We assume it is LSTM for now,
# so the state is a tuple (h, c)
decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# The log prob of ys
log_prob: float
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
If None, an empty dict is used.
"""
# Avoid the mutable-default-argument pitfall: create a fresh dict.
self._data = {} if data is None else data
@property
def data(self):
return self._data
def add(self, hyp: Hypothesis):
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key]
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
else:
self._data[key] = hyp
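
The merge relies on the identity `logaddexp(a, b) = log(exp(a) + exp(b))`, i.e., the probabilities of duplicate hypotheses are summed in log space without underflow. A tiny numeric check:

```python
import numpy as np

# Two distinct paths reach the same token sequence with probabilities
# 0.3 and 0.6; in log space they merge to log(0.3 + 0.6) = log(0.9).
merged = np.logaddexp(np.log(0.3), np.log(0.6))
assert abs(merged - np.log(0.9)) < 1e-12
```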
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
"""
if length_norm:
return max(
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: float) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
whose `log_prob` is greater than the given `threshold`.
"""
ans = HypothesisList()
for key, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int) -> "HypothesisList":
"""Return the top-k hypothesis."""
hyps = list(self._data.items())
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str) -> bool:
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 5,
beam: int = 4,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -116,110 +234,98 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
sos_id = model.decoder.sos_id
context_size = model.decoder.context_size
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
t = 0
B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
max_u = 20000 # terminate after this number of steps
u = 0
cache: Dict[
str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
] = {}
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
while t < T and u < max_u:
max_sym_per_utt = 20000
sym_per_utt = 0
decoder_cache: Dict[str, torch.Tensor] = {}
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
A = B
B = []
# for hyp in A:
# for h in A:
# if h.ys == hyp.ys[:-1]:
# # update the score of hyp
# decoder_input = torch.tensor(
# [h.ys[-1]], device=device
# ).reshape(1, 1)
# decoder_out, _ = model.decoder(
# decoder_input, h.decoder_state
# )
# logits = model.joiner(current_encoder_out, decoder_out)
# log_prob = logits.log_softmax(dim=-1)
# log_prob = log_prob.squeeze()
# hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
B = HypothesisList()
while u < max_u:
y_star = max(A, key=lambda hyp: hyp.log_prob)
joint_cache: Dict[str, torch.Tensor] = {}
# TODO(fangjun): Implement prefix search to update the `log_prob`
# of hypotheses in A
while True:
y_star = A.get_most_probable()
A.remove(y_star)
# Note: y_star.ys is unhashable, i.e., cannot be used
# as a key into a dict
cached_key = "_".join(map(str, y_star.ys))
cached_key = y_star.key
if cached_key not in cache:
if cached_key not in decoder_cache:
decoder_input = torch.tensor(
[y_star.ys[-1]], device=device
).reshape(1, 1)
[y_star.ys[-context_size:]], device=device
).reshape(1, context_size)
decoder_out, decoder_state = model.decoder(
decoder_input,
y_star.decoder_state,
)
cache[cached_key] = (decoder_out, decoder_state)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_cache[cached_key] = decoder_out
else:
decoder_out, decoder_state = cache[cached_key]
decoder_out = decoder_cache[cached_key]
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(current_encoder_out, decoder_out)
# TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
joint_cache[cached_key] = log_prob
else:
log_prob = joint_cache[cached_key]
# If we choose blank here, add the new hypothesis to B.
# Otherwise, add the new hypothesis to A
# First, choose blank
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
# ys[:] returns a copy of ys
new_y_star = Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
# Caution: Use y_star.decoder_state here
decoder_state=y_star.decoder_state,
)
B.append(new_y_star)
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
# Second, choose other labels
for i, v in enumerate(log_prob.tolist()):
if i in (blank_id, sos_id):
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i == blank_id:
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
decoder_state=decoder_state,
)
A.append(new_hyp)
u += 1
# check whether B contains more than "beam" elements more probable
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
B = sorted(
[hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
key=lambda hyp: hyp.log_prob,
reverse=True,
)
if len(B) >= beam:
B = B[:beam]
A_most_probable = A.get_most_probable()
kept_B = B.filter(A_most_probable.log_prob)
if len(kept_B) >= beam:
B = kept_B.topk(beam)
break
t += 1
best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
ys = best_hyp.ys[1:] # [1:] to remove the blank
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
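
For context, a hedged sketch of how `beam_search` would be invoked from a decode script. `feature`, `feature_lens`, and `sp` (a SentencePiece processor) are assumptions for illustration, not part of this diff:

```python
# feature: (1, T, num_features) fbank features of one utterance
# feature_lens: (1,) number of valid frames
encoder_out, _ = model.encoder(feature, feature_lens)  # (1, T', encoder_dim)
hyp_tokens = beam_search(model=model, encoder_out=encoder_out, beam=4)
text = sp.decode(hyp_tokens)  # map BPE token IDs back to words
```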