mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Support modified beam search in batch mode. (#264)
* Support modified beam search in batch mode. * Update k2 versions in GitHub CI.
This commit is contained in:
parent
d5c78a2238
commit
8c7995d493
@ -188,7 +188,7 @@ def greedy_search_batch(
|
|||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list integers containing the decoded results.
|
Return a list-of-list of token IDs containing the decoded results.
|
||||||
len(ans) equals to encoder_out.size(0).
|
len(ans) equals to encoder_out.size(0).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
@ -362,13 +362,156 @@ class HypothesisList(object):
|
|||||||
return ", ".join(s)
|
return ", ".join(s)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||||
|
"""Return a ragged shape with axes [utt][num_hyps].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hyps:
|
||||||
|
len(hyps) == batch_size. It contains the current hypothesis for
|
||||||
|
each utterance in the batch.
|
||||||
|
Returns:
|
||||||
|
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
|
||||||
|
the shape is on CPU.
|
||||||
|
"""
|
||||||
|
num_hyps = [len(h) for h in hyps]
|
||||||
|
|
||||||
|
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
|
||||||
|
# to get exclusive sum later.
|
||||||
|
num_hyps.insert(0, 0)
|
||||||
|
|
||||||
|
num_hyps = torch.tensor(num_hyps)
|
||||||
|
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
|
||||||
|
ans = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
|
||||||
|
)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def modified_beam_search(
|
def modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
beam:
|
||||||
|
Number of active paths during the beam search.
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||||
|
for the i-th utterance.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = model.device
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
for i in range(batch_size):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
hyps_shape = _get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
) # (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
# decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
|
||||||
|
|
||||||
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
|
# as index, so we use `to(torch.int64)` below.
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(
|
||||||
|
shape=log_probs_shape, value=log_probs
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
if new_token != blank_id:
|
||||||
|
new_ys.append(new_token)
|
||||||
|
|
||||||
|
new_log_prob = topk_log_probs[k]
|
||||||
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def _deprecated_modified_beam_search(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
beam: int = 4,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
It decodes only one utterance at a time. We keep it only for reference.
|
||||||
|
The function :func:`modified_beam_search` should be preferred as it
|
||||||
|
supports batch decoding.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
|
@ -272,6 +272,14 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -291,12 +299,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "modified_beam_search":
|
|
||||||
hyp = modified_beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
@ -50,7 +50,12 @@ import kaldifeat
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
@ -224,6 +229,23 @@ def main():
|
|||||||
if params.method == "beam_search":
|
if params.method == "beam_search":
|
||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
if params.method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
@ -236,11 +258,9 @@ def main():
|
|||||||
)
|
)
|
||||||
elif params.method == "beam_search":
|
elif params.method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model,
|
||||||
)
|
encoder_out=encoder_out_i,
|
||||||
elif params.method == "modified_beam_search":
|
beam=params.beam_size,
|
||||||
hyp = modified_beam_search(
|
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
@ -11,7 +11,7 @@ graphviz==0.19.1
|
|||||||
-f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu
|
-f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu
|
||||||
-f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu
|
-f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu
|
||||||
|
|
||||||
-f https://k2-fsa.org/nightly/ k2==1.9.dev20211101+cpu.torch1.10.0
|
-f https://k2-fsa.org/nightly/ k2==1.14.dev20220316+cpu.torch1.10.0
|
||||||
|
|
||||||
git+https://github.com/lhotse-speech/lhotse
|
git+https://github.com/lhotse-speech/lhotse
|
||||||
kaldilm==1.11
|
kaldilm==1.11
|
||||||
|
Loading…
x
Reference in New Issue
Block a user