Add modified beam search for pruned rnn-t.

Fangjun Kuang 2022-03-12 10:42:25 +08:00
parent ad62981765
commit bd033de8bc
2 changed files with 152 additions and 21 deletions

pruned_transducer_stateless/beam_search.py

@@ -48,7 +48,7 @@ def greedy_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
[blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
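Note: the explicit `dtype=torch.int64` is defensive rather than cosmetic. The decoder embeds these token ids, and embedding lookups in PyTorch accept only integer (Long) indices. A minimal sketch of the failure mode, with `nn.Embedding` standing in for the stateless decoder's embedding layer:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=500, embedding_dim=8)
ids = torch.tensor([0, 0], dtype=torch.int64).reshape(1, 2)
print(emb(ids).shape)  # torch.Size([1, 2, 8])
# Passing float indices instead raises a RuntimeError, e.g.:
# emb(torch.tensor([[0.0, 0.0]]))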
@@ -103,8 +103,9 @@ class Hypothesis:
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys
log_prob: float
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
@property
def key(self) -> str:
@@ -113,7 +114,7 @@ class Hypothesis:
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
@@ -125,10 +126,10 @@ class HypothesisList(object):
self._data = data
@property
def data(self):
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis):
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@@ -140,8 +141,10 @@ class HypothesisList(object):
"""
key = hyp.key
if key in self:
old_hyp = self._data[key]
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
old_hyp = self._data[key] # a reference, so the in-place update below persists
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else:
self._data[key] = hyp
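The move from `np.logaddexp` to `torch.logaddexp(..., out=...)` keeps the merged score as a tensor on the model's device and updates it in place. A minimal sketch of how two duplicate hypotheses combine (the values are made up):

import torch

a = torch.tensor([-2.3])  # log prob of the existing hypothesis
b = torch.tensor([-2.3])  # log prob of the incoming duplicate
torch.logaddexp(a, b, out=a)
print(a)  # -2.3 + log(2), i.e. tensor([-1.6069])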
@@ -153,7 +156,8 @@ class HypothesisList(object):
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(
@@ -165,6 +169,9 @@ class HypothesisList(object):
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
@@ -175,7 +182,7 @@ class HypothesisList(object):
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: float) -> "HypothesisList":
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
@@ -183,10 +190,10 @@ class HypothesisList(object):
Returns:
Return a new HypothesisList containing all hypotheses from `self`
that have `log_prob` being greater than the given `threshold`.
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for key, hyp in self._data.items():
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
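A sketch of how `filter` is typically used to prune a beam; the surrounding loop and the variable names here are illustrative, not part of this diff. Note that `best_log_prob - beam` is a tensor, matching the new `threshold: torch.Tensor` signature:

# keep only hypotheses within `beam` (in log space) of the current best
best_log_prob = B.get_most_probable(length_norm=False).log_prob
A = B.filter(best_log_prob - beam)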
@@ -216,6 +223,106 @@ class HypothesisList(object):
return ", ".join(s)
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Only N==1 is supported for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
B = HypothesisList()
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
dtype=torch.int64,
)
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, 1, -1
)
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
# now logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
log_probs.add_(ys_log_probs)
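# ys_log_probs has shape (num_hyps, 1); broadcasting adds each
# hypothesis's running score to every token's log prob in its row,
# so entry (i, v) is the total log prob of extending A[i] with token v.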
log_probs = log_probs.reshape(-1)
topk_log_probs, topk_indexes = log_probs.topk(beam)
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
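# E.g., with vocab_size == 500, flattened index 1023 maps to
# hypothesis 1023 // 500 == 2 and token 1023 % 500 == 23.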
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # strip the initial blanks used as decoder context
return ys
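A hedged usage sketch; `model` is a trained `Transducer` and `encoder_out` the (1, T, C) encoder output for a single utterance, both assumed to exist as in decode.py below:

with torch.no_grad():
    hyp_tokens = modified_beam_search(
        model=model, encoder_out=encoder_out, beam=4
    )
# hyp_tokens is a list of token ids; convert to text with, e.g.,
# a SentencePiece model: sp.decode(hyp_tokens)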
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
@@ -246,7 +353,9 @@ def beam_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
[blank_id] * context_size,
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -283,7 +392,9 @@ def beam_search(
if cached_key not in decoder_cache:
decoder_input = torch.tensor(
[y_star.ys[-context_size:]], device=device
[y_star.ys[-context_size:]],
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -297,7 +408,7 @@ def beam_search(
current_encoder_out, decoder_out.unsqueeze(1)
)
# TODO(fangjun): Cache the blank posterior
# TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
@@ -309,7 +420,7 @@ def beam_search(
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
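Dropping `.item()` is a consequence of `Hypothesis.log_prob` now being a one-element tensor: the blank update stays as tensor arithmetic on the model's device instead of forcing a host copy (a synchronization point on CUDA) every frame. A small sketch with made-up values:

import torch

log_prob = torch.tensor([-0.5])  # stand-in for y_star.log_prob
skip = torch.tensor(-1.2)        # stand-in for log_prob[blank_id]
new_log_prob = log_prob + skip   # no .item(), no device sync
print(new_log_prob)              # tensor([-1.7000])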

pruned_transducer_stateless/decode.py

@@ -33,6 +33,15 @@ Usage:
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
"""
@@ -46,7 +55,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@@ -104,6 +113,7 @@ def get_parser():
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
@@ -111,7 +121,8 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search",
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
)
parser.add_argument(
@@ -125,7 +136,8 @@ def get_parser():
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
help="""Maximum number of symbols per frame.
Used only when --decoding-method is greedy_search""",
)
return parser
@@ -258,6 +270,10 @@ def decode_one_batch(
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@@ -391,11 +407,15 @@ def main():
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"