Add modified beam search for pruned rnn-t.

This commit is contained in:
Fangjun Kuang 2022-03-12 10:42:25 +08:00
parent ad62981765
commit bd033de8bc
2 changed files with 152 additions and 21 deletions

View File

@ -48,7 +48,7 @@ def greedy_search(
device = model.device device = model.device
decoder_input = torch.tensor( decoder_input = torch.tensor(
[blank_id] * context_size, device=device [blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size) ).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
@ -103,8 +103,9 @@ class Hypothesis:
# Newly predicted tokens are appended to `ys`. # Newly predicted tokens are appended to `ys`.
ys: List[int] ys: List[int]
# The log prob of ys # The log prob of ys.
log_prob: float # It contains only one entry.
log_prob: torch.Tensor
@property @property
def key(self) -> str: def key(self) -> str:
@ -113,7 +114,7 @@ class Hypothesis:
class HypothesisList(object): class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None): def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
""" """
Args: Args:
data: data:
@ -125,10 +126,10 @@ class HypothesisList(object):
self._data = data self._data = data
@property @property
def data(self): def data(self) -> Dict[str, Hypothesis]:
return self._data return self._data
def add(self, hyp: Hypothesis): def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`. """Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using If `hyp` already exists in `self`, its probability is updated using
@ -140,8 +141,10 @@ class HypothesisList(object):
""" """
key = hyp.key key = hyp.key
if key in self: if key in self:
old_hyp = self._data[key] old_hyp = self._data[key] # shallow copy
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else: else:
self._data[key] = hyp self._data[key] = hyp
@ -153,7 +156,8 @@ class HypothesisList(object):
length_norm: length_norm:
If True, the `log_prob` of a hypothesis is normalized by the If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it. number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
""" """
if length_norm: if length_norm:
return max( return max(
@ -165,6 +169,9 @@ class HypothesisList(object):
def remove(self, hyp: Hypothesis) -> None: def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis. """Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args: Args:
hyp: hyp:
The hypothesis to be removed from `self`. The hypothesis to be removed from `self`.
@ -175,7 +182,7 @@ class HypothesisList(object):
assert key in self, f"{key} does not exist" assert key in self, f"{key} does not exist"
del self._data[key] del self._data[key]
def filter(self, threshold: float) -> "HypothesisList": def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold. """Remove all Hypotheses whose log_prob is less than threshold.
Caution: Caution:
@ -183,10 +190,10 @@ class HypothesisList(object):
Returns: Returns:
Return a new HypothesisList containing all hypotheses from `self` Return a new HypothesisList containing all hypotheses from `self`
that have `log_prob` being greater than the given `threshold`. with `log_prob` being greater than the given `threshold`.
""" """
ans = HypothesisList() ans = HypothesisList()
for key, hyp in self._data.items(): for _, hyp in self._data.items():
if hyp.log_prob > threshold: if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy ans.add(hyp) # shallow copy
return ans return ans
@ -216,6 +223,106 @@ class HypothesisList(object):
return ", ".join(s) return ", ".join(s)
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
B = HypothesisList()
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
dtype=torch.int64,
)
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_output is of shape (num_hyps, 1,1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, 1, -1
)
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
# now logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
log_probs.add_(ys_log_probs)
log_probs = log_probs.reshape(-1)
topk_log_probs, topk_indexes = log_probs.topk(beam)
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
def beam_search( def beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
@ -246,7 +353,9 @@ def beam_search(
device = model.device device = model.device
decoder_input = torch.tensor( decoder_input = torch.tensor(
[blank_id] * context_size, device=device [blank_id] * context_size,
device=device,
dtype=torch.int64,
).reshape(1, context_size) ).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
@ -283,7 +392,9 @@ def beam_search(
if cached_key not in decoder_cache: if cached_key not in decoder_cache:
decoder_input = torch.tensor( decoder_input = torch.tensor(
[y_star.ys[-context_size:]], device=device [y_star.ys[-context_size:]],
device=device,
dtype=torch.int64,
).reshape(1, context_size) ).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
@ -297,7 +408,7 @@ def beam_search(
current_encoder_out, decoder_out.unsqueeze(1) current_encoder_out, decoder_out.unsqueeze(1)
) )
# TODO(fangjun): Cache the blank posterior # TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1) log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size) # log_prob is (1, 1, 1, vocab_size)
@ -309,7 +420,7 @@ def beam_search(
# First, process the blank symbol # First, process the blank symbol
skip_log_prob = log_prob[blank_id] skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys # ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

View File

@ -33,6 +33,15 @@ Usage:
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
""" """
@ -46,7 +55,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
@ -104,6 +113,7 @@ def get_parser():
help="""Possible values are: help="""Possible values are:
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search
""", """,
) )
@ -111,7 +121,8 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="Used only when --decoding-method is beam_search", help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -125,7 +136,8 @@ def get_parser():
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=3, default=3,
help="Maximum number of symbols per frame", help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
) )
return parser return parser
@ -258,6 +270,10 @@ def decode_one_batch(
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size
) )
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
@ -391,11 +407,15 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search") assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search": if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"