Implement beam search.

Fangjun Kuang 2021-12-22 15:15:58 +08:00
parent afec6b6cae
commit fbc1bc3a6b
3 changed files with 212 additions and 91 deletions

View File

@ -1,3 +1,17 @@
# Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech.html>
for how to run models in this recipe.
# Transducers
This folder contains several sub-folders whose names include `transducer`.
The following table lists the differences among them.
| | Encoder | Decoder |
|------------------------|-----------|--------------------|
| `transducer` | Conformer | LSTM |
| `transducer_stateless` | Conformer | Conv1d + Embedding |
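
For readers unfamiliar with the stateless decoder, the sketch below illustrates the `Conv1d + Embedding` idea: the "decoder" looks only at a fixed number of previously predicted tokens instead of carrying an RNN state. All names and sizes here are illustrative, not the recipe's actual implementation:

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Illustrative stateless decoder: Embedding + Conv1d over a fixed
    left context (`context_size` previous tokens), no recurrent state."""

    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # A 1-D conv whose kernel spans exactly `context_size` tokens,
        # so the output depends only on a bounded history.
        self.conv = nn.Conv1d(
            embedding_dim, embedding_dim, kernel_size=context_size
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) token IDs -> (N, 1, embedding_dim)
        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, context_size)
        return torch.relu(self.conv(emb)).permute(0, 2, 1)
```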

View File

@ -3,8 +3,9 @@
### LibriSpeech BPE training results (RNN-T)

#### 2021-12-17

Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.

RNN-T + Conformer encoder.

The best WER is
@ -12,7 +13,7 @@ The best WER is
|-----|------------|------------|
| WER | 3.16 | 7.71 |

using `--epoch 26 --avg 12` with **greedy search**.

The training command to reproduce the above WER is:

View File

@ -15,8 +15,9 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import torch

from model import Transducer
@ -35,25 +36,35 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device

    decoder_input = torch.tensor(
        [blank_id] * context_size, device=device
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)

    T = encoder_out.size(1)
    t = 0
    hyp = [blank_id] * context_size

    # Maximum symbols per utterance.
    max_sym_per_utt = 1000

    # If at frame t we decode more than this number of symbols,
    # we move on to frame t+1.
    max_sym_per_frame = 3

    # symbols decoded at the current frame
    sym_per_frame = 0

    # symbols decoded for the utterance so far
    sym_per_utt = 0

    while t < T and sym_per_utt < max_sym_per_utt:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
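
The diff elides the body of this loop. Below is a hedged sketch of what one decoding step typically does in this stateless setup, consistent with the setup code above; it is not the committed code:

```python
# Hedged sketch of the elided greedy loop body (not the committed code):
while t < T and sym_per_utt < max_sym_per_utt:
    # fmt: off
    current_encoder_out = encoder_out[:, t:t+1, :]
    # fmt: on
    logits = model.joiner(current_encoder_out, decoder_out)
    # logits is (1, 1, 1, vocab_size)
    y = logits.argmax().item()
    if y != blank_id and sym_per_frame < max_sym_per_frame:
        hyp.append(y)
        # Re-run the decoder on the last `context_size` tokens only.
        decoder_input = torch.tensor([hyp[-context_size:]], device=device)
        decoder_out = model.decoder(decoder_input, need_pad=False)
        sym_per_utt += 1
        sym_per_frame += 1
    else:
        # blank, or the per-frame cap was hit: advance to the next frame
        sym_per_frame = 0
        t += 1
return hyp[context_size:]  # strip the leading blanks used as context
```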
@ -83,18 +94,125 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
@dataclass
class Hypothesis:
    # The predicted tokens so far.
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]

    # The log prob of ys
    log_prob: float

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
        return "_".join(map(str, self.ys))
class HypothesisList(object):
    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
        """
        Args:
          data:
            A dict of Hypotheses. Its key is its `value.key`.
        """
        # Avoid a mutable default argument, which would be shared
        # across instances.
        if data is None:
            self._data = {}
        else:
            self._data = data

    @property
    def data(self) -> Dict[str, Hypothesis]:
        return self._data
    def add(self, hyp: Hypothesis) -> None:
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated using
        `log-sum-exp` with the existing one.

        Args:
          hyp:
            The hypothesis to be added.
        """
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]
            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
        else:
            self._data[key] = hyp
    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
        """Get the most probable hypothesis, i.e., the one with
        the largest `log_prob`.

        Args:
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
        """
        if length_norm:
            return max(
                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
            )
        else:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.

        Args:
          hyp:
            The hypothesis to be removed from `self`.

            Note: It must be contained in `self`. Otherwise,
            an exception is raised.
        """
        key = hyp.key
        assert key in self, f"{key} does not exist"
        del self._data[key]
    def filter(self, threshold: float) -> "HypothesisList":
        """Remove all Hypotheses whose log_prob is less than threshold.

        Caution:
          `self` is not modified. Instead, a new HypothesisList is returned.

        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
          whose `log_prob` is greater than the given `threshold`.
        """
        ans = HypothesisList()
        for hyp in self._data.values():
            if hyp.log_prob > threshold:
                ans.add(hyp)  # shallow copy
        return ans
    def topk(self, k: int) -> "HypothesisList":
        """Return the top-k hypotheses, i.e., the k with the largest log_prob."""
        hyps = list(self._data.items())

        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
        ans = HypothesisList(dict(hyps))
        return ans
    def __contains__(self, key: str) -> bool:
        return key in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self) -> int:
        return len(self._data)

    def __str__(self) -> str:
        s = []
        # Iteration yields Hypothesis objects, so collect their keys.
        for hyp in self:
            s.append(hyp.key)
        return ", ".join(s)
def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[int]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -116,110 +234,98 @@ def beam_search(
    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device

    decoder_input = torch.tensor(
        [blank_id] * context_size, device=device
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)

    T = encoder_out.size(1)
    t = 0

    B = HypothesisList()
    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))

    max_sym_per_utt = 20000
    sym_per_utt = 0

    decoder_cache: Dict[str, torch.Tensor] = {}

    while t < T and sym_per_utt < max_sym_per_utt:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # fmt: on
        A = B
        B = HypothesisList()
        joint_cache: Dict[str, torch.Tensor] = {}

        # TODO(fangjun): Implement prefix search to update the `log_prob`
        # of hypotheses in A

        while True:
            y_star = A.get_most_probable()
            A.remove(y_star)
            cached_key = y_star.key

            if cached_key not in decoder_cache:
                decoder_input = torch.tensor(
                    [y_star.ys[-context_size:]], device=device
                ).reshape(1, context_size)

                decoder_out = model.decoder(decoder_input, need_pad=False)
                decoder_cache[cached_key] = decoder_out
            else:
                decoder_out = decoder_cache[cached_key]

            cached_key += f"-t-{t}"
            if cached_key not in joint_cache:
                logits = model.joiner(current_encoder_out, decoder_out)

                # TODO(fangjun): Scale the blank posterior

                log_prob = logits.log_softmax(dim=-1)
                # log_prob is (1, 1, 1, vocab_size)
                log_prob = log_prob.squeeze()
                # Now log_prob is (vocab_size,)
                joint_cache[cached_key] = log_prob
            else:
                log_prob = joint_cache[cached_key]
            # First, process the blank symbol
            skip_log_prob = log_prob[blank_id]
            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()

            # ys[:] returns a copy of ys
            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

            # Second, process other non-blank labels
            values, indices = log_prob.topk(beam + 1)
            for i, v in zip(indices.tolist(), values.tolist()):
                if i == blank_id:
                    continue
                new_ys = y_star.ys + [i]
                new_log_prob = y_star.log_prob + v
                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
            # Check whether B contains more than `beam` elements that are
            # more probable than the most probable hypothesis in A
            A_most_probable = A.get_most_probable()
            kept_B = B.filter(A_most_probable.log_prob)
            if len(kept_B) >= beam:
                B = kept_B.topk(beam)
                break

        t += 1

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
    return ys
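
For context, a hedged sketch of how these functions are typically invoked; the encoder call signature and variable names here are assumptions, not part of this commit:

```python
# Hedged usage sketch: decode one utterance (batch_size must be 1).
with torch.no_grad():
    encoder_out, _ = model.encoder(features, feature_lens)  # (1, T, C)
    hyp_greedy = greedy_search(model=model, encoder_out=encoder_out)
    hyp_beam = beam_search(model=model, encoder_out=encoder_out, beam=4)
# Both return a list of token IDs; convert back to text with the BPE
# model, e.g. sp.decode(hyp_beam) if `sp` is a sentencepiece processor.
```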