mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

Add beam search.

This commit is contained in:
parent cbda811a10
commit 3174bebf07

.gitignore (vendored)
@@ -8,3 +8,4 @@ exp*/
 download
 *.bak
 *-bak
+*bak.py
transducer/beam_search.py

@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from transducer.model import Transducer
@@ -50,9 +51,10 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
         logits = model.joiner(current_encoder_out, decoder_out)
+        # logits is (1, 1, 1, vocab_size)
 
         log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (N, 1, 1)
+        # log_prob is (1, 1, 1, vocab_size)
         # TODO: Use logits.argmax()
         y = log_prob.argmax()
         if y != blank_id:
@@ -64,3 +66,147 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         t += 1
 
     return hyp
+
+
+@dataclass
+class Hypothesis:
+    ys: List[int]  # the predicted sequence so far
+    log_prob: float  # the log prob of ys
+
+    # Optional decoder state. We assume it is LSTM for now,
+    # so the state is a tuple (h, c)
+    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
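+# For illustration (assuming blank_id == 0): Hypothesis(ys=[0, 25, 7],
+# log_prob=-3.2) represents the emitted labels [25, 7]; the leading blank
+# is stripped via ys[1:] at the end of beam_search below.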
+def beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 5,
+) -> List[int]:
+    """
+    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
+
+    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder. Only N==1 is supported
+        for now.
+      beam:
+        Beam size.
+    Returns:
+      Return the decoded result.
+    """
+    assert encoder_out.ndim == 3
+
+    # support only batch_size == 1 for now
+    assert encoder_out.size(0) == 1, encoder_out.size(0)
+    blank_id = model.decoder.blank_id
+    sos_id = model.decoder.sos_id
+    device = model.device
+
+    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    decoder_out, (h, c) = model.decoder(sos)
+    T = encoder_out.size(1)
+    t = 0
+    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
+    max_u = 20000  # terminate after this number of steps
+    u = 0
+
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
+
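+    # Following Algorithm 1 of the paper cited above: B holds hypotheses
+    # whose latest step chose blank (i.e., they are done with frame t and
+    # ready for t + 1), while A holds candidates still being extended at t.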
+    while t < T and u < max_u:
+        # fmt: off
+        current_encoder_out = encoder_out[:, t:t+1, :]
+        # fmt: on
+        A = B
+        B = []
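+        # The commented-out block below sketches the prefix-merging step of
+        # Algorithm 1 (folding the probability of a hypothesis into any
+        # hypothesis that extends it by one label); it is not enabled here.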
+        #  for hyp in A:
+        #      for h in A:
+        #          if h.ys == hyp.ys[:-1]:
+        #              # update the score of hyp
+        #              decoder_input = torch.tensor(
+        #                  [h.ys[-1]], device=device
+        #              ).reshape(1, 1)
+        #              decoder_out, _ = model.decoder(
+        #                  decoder_input, h.decoder_state
+        #              )
+        #              logits = model.joiner(current_encoder_out, decoder_out)
+        #              log_prob = logits.log_softmax(dim=-1)
+        #              log_prob = log_prob.squeeze()
+        #              hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
+
+        while u < max_u:
+            y_star = max(A, key=lambda hyp: hyp.log_prob)
+            A.remove(y_star)
+
+            # Note: y_star.ys is unhashable, i.e., cannot be used
+            # as a key into a dict
+            cached_key = "_".join(map(str, y_star.ys))
+
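+            # Reuse the decoder output and LSTM state computed earlier for
+            # this prefix, if available, instead of rerunning the decoder.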
+            if cached_key not in cache:
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
+
+                decoder_out, decoder_state = model.decoder(
+                    decoder_input,
+                    y_star.decoder_state,
+                )
+                cache[cached_key] = (decoder_out, decoder_state)
+            else:
+                decoder_out, decoder_state = cache[cached_key]
+
+            logits = model.joiner(current_encoder_out, decoder_out)
+            log_prob = logits.log_softmax(dim=-1)
+            # log_prob is (1, 1, 1, vocab_size)
+            log_prob = log_prob.squeeze()
+            # Now log_prob is (vocab_size,)
+
+            # If we choose blank here, add the new hypothesis to B.
+            # Otherwise, add the new hypothesis to A
+
+            # First, choose blank
+            skip_log_prob = log_prob[blank_id]
+            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+
+            # ys[:] returns a copy of ys
+            new_y_star = Hypothesis(
+                ys=y_star.ys[:],
+                log_prob=new_y_star_log_prob,
+                # Caution: Use y_star.decoder_state here
+                decoder_state=y_star.decoder_state,
+            )
+            B.append(new_y_star)
+
+            # Second, choose other labels
+            for i, v in enumerate(log_prob.tolist()):
+                if i in (blank_id, sos_id):
+                    continue
+                new_ys = y_star.ys + [i]
+                new_log_prob = y_star.log_prob + v
+                new_hyp = Hypothesis(
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    decoder_state=decoder_state,
+                )
+                A.append(new_hyp)
+            u += 1
+            # Check whether B contains at least "beam" elements that are
+            # more probable than the most probable hypothesis in A
+            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
+            B = sorted(
+                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
+                key=lambda hyp: hyp.log_prob,
+                reverse=True,
+            )
+            if len(B) >= beam:
+                B = B[:beam]
+                break
+        t += 1
+
+    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
+    ys = best_hyp.ys[1:]  # [1:] to remove the leading blank
+    return ys
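For reference, here is a minimal self-contained sketch (hypothetical
hypotheses and scores; all numbers are made up) of the two selection rules
used above: B is pruned to the hypotheses that outscore the best remaining
candidate in A, and the final answer maximizes log_prob normalized by the
number of emitted labels, so a longer hypothesis can win over a shorter one
with a higher raw score.

from dataclasses import dataclass
from typing import List


@dataclass
class SketchHyp:
    ys: List[int]  # token IDs; ys[0] is the leading blank
    log_prob: float


beam = 2
A = [SketchHyp(ys=[0, 3], log_prob=-4.0)]
B = [
    SketchHyp(ys=[0, 5], log_prob=-1.2),
    SketchHyp(ys=[0, 7, 2], log_prob=-2.2),
    SketchHyp(ys=[0, 9], log_prob=-5.0),  # below the best of A; pruned
]

# Keep only hypotheses in B that beat the best candidate left in A.
A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
B = sorted(
    [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
    key=lambda hyp: hyp.log_prob,
    reverse=True,
)[:beam]

# Length-normalized selection: -2.2 / 2 = -1.1 beats -1.2 / 1 = -1.2,
# so the two-label hypothesis wins despite its lower raw log prob.
best = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
print(best.ys[1:])  # -> [7, 2]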
transducer/decode.py

@@ -15,6 +15,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Usage:
+(1) greedy search
+./transducer/decode.py \
+        --epoch 14 \
+        --avg 7 \
+        --exp-dir ./transducer/exp \
+        --max-duration 100 \
+        --decoding-method greedy_search
+(2) beam search
+
+./transducer/decode.py \
+        --epoch 14 \
+        --avg 7 \
+        --exp-dir ./transducer/exp \
+        --max-duration 100 \
+        --decoding-method beam_search \
+        --beam-size 8
+"""
 
 
 import argparse
@@ -27,7 +46,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from transducer.beam_search import greedy_search
+from transducer.beam_search import beam_search, greedy_search
 from transducer.conformer import Conformer
 from transducer.decoder import Decoder
 from transducer.joiner import Joiner
@@ -78,6 +97,23 @@ def get_parser():
         help="Path to the BPE model",
     )
 
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=5,
+        help="Used only when --decoding-method is beam_search",
+    )
+
     return parser
 
 
@@ -205,11 +241,22 @@ def decode_one_batch(
         # fmt: off
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
-        hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        if params.decoding_method == "greedy_search":
+            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        elif params.decoding_method == "beam_search":
+            hyp = beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
+        else:
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
-    return {"greedy_search": hyps}
-    # TODO: Implement beam search
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    else:
+        return {f"beam_{params.beam_size}": hyps}
 
 
 def decode_dataset(
@@ -243,6 +290,11 @@
     except TypeError:
         num_batches = "?"
 
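+    # Beam search decodes far more slowly than greedy search, so log
+    # progress more frequently when it is selected.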
+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -265,7 +317,7 @@
 
         num_cuts += len(texts)
 
-        if batch_idx % 100 == 0:
+        if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
             logging.info(
@@ -327,8 +379,13 @@ def main():
 
     params = get_params()
     params.update(vars(args))
-    params.res_dir = params.exp_dir / "greedy_search"
+
+    assert params.decoding_method in ("greedy_search", "beam_search")
+    params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.decoding_method == "beam_search":
+        params.suffix += f"-beam-{params.beam_size}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")