mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Add modified beam search for pruned rnn-t.
This commit is contained in:
parent
ad62981765
commit
bd033de8bc
@ -48,7 +48,7 @@ def greedy_search(
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
@ -103,8 +103,9 @@ class Hypothesis:
|
||||
# Newly predicted tokens are appended to `ys`.
|
||||
ys: List[int]
|
||||
|
||||
# The log prob of ys
|
||||
log_prob: float
|
||||
# The log prob of ys.
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
@ -113,7 +114,7 @@ class Hypothesis:
|
||||
|
||||
|
||||
class HypothesisList(object):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
data:
|
||||
@ -125,10 +126,10 @@ class HypothesisList(object):
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> Dict[str, Hypothesis]:
|
||||
return self._data
|
||||
|
||||
def add(self, hyp: Hypothesis):
|
||||
def add(self, hyp: Hypothesis) -> None:
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
If `hyp` already exists in `self`, its probability is updated using
|
||||
@ -140,8 +141,10 @@ class HypothesisList(object):
|
||||
"""
|
||||
key = hyp.key
|
||||
if key in self:
|
||||
old_hyp = self._data[key]
|
||||
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
|
||||
old_hyp = self._data[key] # shallow copy
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
else:
|
||||
self._data[key] = hyp
|
||||
|
||||
@ -153,7 +156,8 @@ class HypothesisList(object):
|
||||
length_norm:
|
||||
If True, the `log_prob` of a hypothesis is normalized by the
|
||||
number of tokens in it.
|
||||
|
||||
Returns:
|
||||
Return the hypothesis that has the largest `log_prob`.
|
||||
"""
|
||||
if length_norm:
|
||||
return max(
|
||||
@ -165,6 +169,9 @@ class HypothesisList(object):
|
||||
def remove(self, hyp: Hypothesis) -> None:
|
||||
"""Remove a given hypothesis.
|
||||
|
||||
Caution:
|
||||
`self` is modified **in-place**.
|
||||
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be removed from `self`.
|
||||
@ -175,7 +182,7 @@ class HypothesisList(object):
|
||||
assert key in self, f"{key} does not exist"
|
||||
del self._data[key]
|
||||
|
||||
def filter(self, threshold: float) -> "HypothesisList":
|
||||
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||
|
||||
Caution:
|
||||
@ -183,10 +190,10 @@ class HypothesisList(object):
|
||||
|
||||
Returns:
|
||||
Return a new HypothesisList containing all hypotheses from `self`
|
||||
that have `log_prob` being greater than the given `threshold`.
|
||||
with `log_prob` being greater than the given `threshold`.
|
||||
"""
|
||||
ans = HypothesisList()
|
||||
for key, hyp in self._data.items():
|
||||
for _, hyp in self._data.items():
|
||||
if hyp.log_prob > threshold:
|
||||
ans.add(hyp) # shallow copy
|
||||
return ans
|
||||
@ -216,6 +223,106 @@ class HypothesisList(object):
|
||||
return ", ".join(s)
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[int]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
T = encoder_out.size(1)
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
for t in range(T):
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
||||
# fmt: on
|
||||
A = list(B)
|
||||
B = HypothesisList()
|
||||
|
||||
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
|
||||
# ys_log_probs is of shape (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyp in A],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
# decoder_input is of shape (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
# decoder_output is of shape (num_hyps, 1,1, decoder_output_dim)
|
||||
|
||||
current_encoder_out = current_encoder_out.expand(
|
||||
decoder_out.size(0), 1, 1, -1
|
||||
)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits is of shape (num_hyps, 1, 1, vocab_size)
|
||||
logits = logits.squeeze(1).squeeze(1)
|
||||
|
||||
# now logits is of shape (num_hyps, vocab_size)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
topk_log_probs, topk_indexes = log_probs.topk(beam)
|
||||
|
||||
# topk_hyp_indexes are indexes into `A`
|
||||
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||
|
||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||
topk_token_indexes = topk_token_indexes.tolist()
|
||||
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token != blank_id:
|
||||
new_ys.append(new_token)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
return ys
|
||||
|
||||
|
||||
def beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
@ -246,7 +353,9 @@ def beam_search(
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
[blank_id] * context_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
@ -283,7 +392,9 @@ def beam_search(
|
||||
|
||||
if cached_key not in decoder_cache:
|
||||
decoder_input = torch.tensor(
|
||||
[y_star.ys[-context_size:]], device=device
|
||||
[y_star.ys[-context_size:]],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
@ -297,7 +408,7 @@ def beam_search(
|
||||
current_encoder_out, decoder_out.unsqueeze(1)
|
||||
)
|
||||
|
||||
# TODO(fangjun): Cache the blank posterior
|
||||
# TODO(fangjun): Scale the blank posterior
|
||||
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
@ -309,7 +420,7 @@ def beam_search(
|
||||
|
||||
# First, process the blank symbol
|
||||
skip_log_prob = log_prob[blank_id]
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
|
@ -33,6 +33,15 @@ Usage:
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
@ -46,7 +55,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import beam_search, greedy_search
|
||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -104,6 +113,7 @@ def get_parser():
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
@ -111,7 +121,8 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Used only when --decoding-method is beam_search",
|
||||
help="""Used only when --decoding-method is
|
||||
beam_search or modified_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -125,7 +136,8 @@ def get_parser():
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of symbols per frame",
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -258,6 +270,10 @@ def decode_one_batch(
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
@ -391,11 +407,15 @@ def main():
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in ("greedy_search", "beam_search")
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.decoding_method == "beam_search":
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user