Add beam search.

Fangjun Kuang 2021-12-15 18:50:29 +08:00
parent cbda811a10
commit 3174bebf07
3 changed files with 212 additions and 8 deletions
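The new decoder implements Algorithm 1 of the RNN-T paper (https://arxiv.org/pdf/1211.3711.pdf): for each encoder frame, the most probable hypothesis is repeatedly expanded; its blank extension moves to the finished set B, its non-blank extensions stay in the active set A, and the final answer is the hypothesis in B with the best length-normalized log-probability. Below is a toy sketch of that final selection rule only; the Hypothesis fields mirror the dataclass added in this commit, while the hypotheses and scores are made up for illustration.

from dataclasses import dataclass
from typing import List


@dataclass
class Hypothesis:
    ys: List[int]  # token IDs; ys[0] is the leading blank
    log_prob: float  # total log-probability of ys


# Made-up finalists: after dividing by the number of emitted tokens,
# a longer hypothesis with a worse raw score can still win.
B = [
    Hypothesis(ys=[0, 7], log_prob=-3.0),        # -3.0 / 1 = -3.0
    Hypothesis(ys=[0, 7, 3, 9], log_prob=-6.0),  # -6.0 / 3 = -2.0
]
best = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
print(best.ys[1:])  # [7, 3, 9]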

.gitignore vendored

@@ -8,3 +8,4 @@ exp*/
 download
 *.bak
 *-bak
+*bak.py

transducer/beam_search.py

@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple

 import torch

 from transducer.model import Transducer
@@ -50,9 +51,10 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
         logits = model.joiner(current_encoder_out, decoder_out)
+        # logits is (1, 1, 1, vocab_size)
         log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (N, 1, 1)
+        # log_prob is (1, 1, 1, vocab_size)

         # TODO: Use logits.argmax()
         y = log_prob.argmax()
         if y != blank_id:
@@ -64,3 +66,147 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         t += 1

     return hyp
+
+
+@dataclass
+class Hypothesis:
+    ys: List[int]  # the predicted sequence so far
+    log_prob: float  # the log prob of ys
+
+    # Optional decoder state. We assume it is LSTM for now,
+    # so the state is a tuple (h, c)
+    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
+def beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 5,
+) -> List[int]:
+    """
+    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
+
+    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder. Only N == 1 is
+        supported for now.
+      beam:
+        Beam size.
+    Returns:
+      Return the decoded result.
+    """
+    assert encoder_out.ndim == 3
+
+    # support only batch_size == 1 for now
+    assert encoder_out.size(0) == 1, encoder_out.size(0)
+    blank_id = model.decoder.blank_id
+    sos_id = model.decoder.sos_id
+    device = model.device
+
+    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    decoder_out, (h, c) = model.decoder(sos)
+    T = encoder_out.size(1)
+    t = 0
+    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
+    max_u = 20000  # terminate after this number of steps
+    u = 0
+
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
+
+    while t < T and u < max_u:
+        # fmt: off
+        current_encoder_out = encoder_out[:, t:t+1, :]
+        # fmt: on
+        A = B
+        B = []
+        #  for hyp in A:
+        #      for h in A:
+        #          if h.ys == hyp.ys[:-1]:
+        #              # update the score of hyp
+        #              decoder_input = torch.tensor(
+        #                  [h.ys[-1]], device=device
+        #              ).reshape(1, 1)
+        #              decoder_out, _ = model.decoder(
+        #                  decoder_input, h.decoder_state
+        #              )
+        #              logits = model.joiner(current_encoder_out, decoder_out)
+        #              log_prob = logits.log_softmax(dim=-1)
+        #              log_prob = log_prob.squeeze()
+        #              hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
+
+        while u < max_u:
+            y_star = max(A, key=lambda hyp: hyp.log_prob)
+            A.remove(y_star)
+
+            # Note: y_star.ys is unhashable, i.e., it cannot be used
+            # as a key into a dict
+            cached_key = "_".join(map(str, y_star.ys))
+
+            if cached_key not in cache:
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
+
+                decoder_out, decoder_state = model.decoder(
+                    decoder_input,
+                    y_star.decoder_state,
+                )
+                cache[cached_key] = (decoder_out, decoder_state)
+            else:
+                decoder_out, decoder_state = cache[cached_key]
+
+            logits = model.joiner(current_encoder_out, decoder_out)
+            log_prob = logits.log_softmax(dim=-1)
+            # log_prob is (1, 1, 1, vocab_size)
+            log_prob = log_prob.squeeze()
+            # now log_prob is (vocab_size,)
+
+            # If we choose blank here, add the new hypothesis to B;
+            # otherwise, add the new hypothesis to A.
+
+            # First, choose blank
+            skip_log_prob = log_prob[blank_id]
+            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+
+            # ys[:] returns a copy of ys
+            new_y_star = Hypothesis(
+                ys=y_star.ys[:],
+                log_prob=new_y_star_log_prob,
+                # Caution: Use y_star.decoder_state here
+                decoder_state=y_star.decoder_state,
+            )
+            B.append(new_y_star)
+
+            # Second, choose other labels
+            for i, v in enumerate(log_prob.tolist()):
+                if i in (blank_id, sos_id):
+                    continue
+                new_ys = y_star.ys + [i]
+                new_log_prob = y_star.log_prob + v
+                new_hyp = Hypothesis(
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    decoder_state=decoder_state,
+                )
+                A.append(new_hyp)
+            u += 1
+
+            # Check whether B contains more than `beam` elements that are
+            # more probable than the most probable hypothesis in A
+            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
+            B = sorted(
+                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
+                key=lambda hyp: hyp.log_prob,
+                reverse=True,
+            )
+            if len(B) >= beam:
+                B = B[:beam]
+                break
+        t += 1
+
+    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
+    ys = best_hyp.ys[1:]  # [1:] to remove the leading blank
+    return ys
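For orientation, here is a minimal sketch of how the new function is meant to be driven. Apart from the beam_search call itself, everything is an assumption: the trained model variable and the stand-in encoder output (with an assumed feature dimension of 512) are illustrative only. The decode.py changes below wire this up properly per utterance.

import torch

from transducer.beam_search import beam_search

# Assumption: `model` is a trained Transducer already on the right device.
model.eval()
with torch.no_grad():
    # Stand-in for real encoder output of shape (N, T, C); N must be 1.
    encoder_out = torch.randn(1, 50, 512)
    hyp = beam_search(model=model, encoder_out=encoder_out, beam=5)
# `hyp` is a list of token IDs with the leading blank already removed;
# decode.py passes it to sp.decode(hyp) to obtain words.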

transducer/decode.py

@@ -15,6 +15,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+"""
+Usage:
+
+(1) greedy search
+./transducer/decode.py \
+        --epoch 14 \
+        --avg 7 \
+        --exp-dir ./transducer/exp \
+        --max-duration 100 \
+        --decoding-method greedy_search
+
+(2) beam search
+./transducer/decode.py \
+        --epoch 14 \
+        --avg 7 \
+        --exp-dir ./transducer/exp \
+        --max-duration 100 \
+        --decoding-method beam_search \
+        --beam-size 8
+"""

 import argparse
@@ -27,7 +46,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from transducer.beam_search import greedy_search
+from transducer.beam_search import beam_search, greedy_search
 from transducer.conformer import Conformer
 from transducer.decoder import Decoder
 from transducer.joiner import Joiner
@@ -78,6 +97,23 @@ def get_parser():
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=5,
+        help="Used only when --decoding-method is beam_search",
+    )
+
     return parser
@@ -205,11 +241,22 @@ def decode_one_batch(
         # fmt: off
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
-        hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        if params.decoding_method == "greedy_search":
+            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        elif params.decoding_method == "beam_search":
+            hyp = beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
+        else:
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())

-    # TODO: Implement beam search
-    return {"greedy_search": hyps}
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    else:
+        return {f"beam_{params.beam_size}": hyps}


 def decode_dataset(
@@ -243,6 +290,11 @@ def decode_dataset(
     except TypeError:
         num_batches = "?"

+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -265,7 +317,7 @@ def decode_dataset(
         num_cuts += len(texts)

-        if batch_idx % 100 == 0:
+        if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

             logging.info(
@@ -327,8 +379,13 @@ def main():
     params = get_params()
     params.update(vars(args))

-    params.res_dir = params.exp_dir / "greedy_search"
+    assert params.decoding_method in ("greedy_search", "beam_search")
+    params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.decoding_method == "beam_search":
+        params.suffix += f"-beam-{params.beam_size}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")