mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

Commit f233b16974 (parent 026fb22076): add modify_beam_search, fast_beam_search
@@ -17,6 +17,7 @@
import math
from typing import List, Optional, Tuple

import k2
import torch
from beam_search import Hypothesis, HypothesisList
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
@@ -47,6 +48,7 @@ class Stream(object):
    def __init__(
        self,
        params: AttributeDict,
        decoding_graph: Optional[k2.Fsa] = None,
        device: torch.device = torch.device("cpu"),
        LOG_EPS: float = math.log(1e-10),
    ) -> None:
@@ -80,6 +82,13 @@ class Stream(object):
                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                )
            )
        elif params.decoding_method == "fast_beam_search":
            # feature_len is needed to get partial results.
            # The rnnt_decoding_stream for fast_beam_search.
            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                k2.RnntDecodingStream(decoding_graph)
            )
            self.hyp: List[int] = None
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
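For orientation, here is a minimal sketch of how a Stream might be constructed for fast_beam_search with the extended __init__ signature above. It is illustrative only: the field values in the AttributeDict are toy numbers (the real ones come from get_params() and the BPE model, and the real params carries additional fields such as feature-extraction options), and it assumes `Stream` is importable from this recipe's stream module.

import math

import k2
import torch
from icefall.utils import AttributeDict

# Toy parameter values; the real script fills params from get_params() and argparse.
params = AttributeDict(
    {
        "decoding_method": "fast_beam_search",
        "blank_id": 0,
        "context_size": 2,
        "vocab_size": 500,
    }
)
device = torch.device("cpu")
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

# `Stream` is the class defined above; the import path depends on the recipe.
stream = Stream(
    params=params,
    decoding_graph=decoding_graph,
    device=device,
    LOG_EPS=math.log(1e-10),
)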
@@ -171,7 +180,9 @@ class Stream(object):
        """Obtain current decoding result."""
        if self.decoding_method == "greedy_search":
            return self.hyp[self.context_size :]
        else:
            assert self.decoding_method == "modified_beam_search"
        elif self.decoding_method == "modified_beam_search":
            best_hyp = self.hyps.get_most_probable(length_norm=True)
            return best_hyp.ys[self.context_size :]
        else:
            assert self.decoding_method == "fast_beam_search"
            return self.hyp
@@ -17,9 +17,7 @@
# limitations under the License.

import argparse
import copy
import logging
import math
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -44,8 +42,10 @@ from icefall.checkpoint import (
    find_checkpoints,
    load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
    AttributeDict,
    get_texts,
    setup_logger,
    store_transcripts,
    str2bool,
@@ -116,7 +116,6 @@ def get_parser():
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
@@ -196,7 +195,15 @@ def greedy_search(
    encoder_out: torch.Tensor,
    streams: List[Stream],
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      streams:
        A list of Stream objects.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

@@ -205,6 +212,8 @@ def greedy_search(
    device = next(model.parameters()).device
    T = encoder_out.size(1)

    encoder_out = model.joiner.encoder_proj(encoder_out)

    decoder_input = torch.tensor(
        [stream.hyp[-context_size:] for stream in streams],
        device=device,
@@ -248,11 +257,216 @@ def greedy_search(
    decoder_out = model.joiner.decoder_proj(decoder_out)

def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[Stream],
    beam: int = 4,
):
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    Args:
      model:
        The RNN-T model.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
        the encoder model.
      streams:
        A list of stream objects.
      beam:
        Number of active paths during the beam search.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert len(streams) == encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device
    batch_size = len(streams)
    T = encoder_out.size(1)

    B = [stream.hyps for stream in streams]

    encoder_out = model.joiner.encoder_proj(encoder_out)

    for t in range(T):
        current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out, decoder_out, project_input=False
        )
        # logits is of shape (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs
        )

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                if new_token != blank_id:
                    new_ys.append(new_token)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                B[i].add(new_hyp)

    for i in range(batch_size):
        streams[i].hyps = B[i]

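For readers unfamiliar with the flattened top-k trick used above, here is a small self-contained sketch (plain PyTorch, toy numbers only, not part of the commit): the per-hypothesis log-prob matrix for one stream is flattened, the top `beam` entries are taken, and each flat index is split back into a hypothesis index and a token index.

import torch

# Toy setting: 3 active hypotheses, vocabulary of 5 tokens, beam of 4.
num_hyps, vocab_size, beam = 3, 5, 4
log_probs = torch.randn(num_hyps, vocab_size).log_softmax(dim=-1)

flat = log_probs.reshape(-1)                    # (num_hyps * vocab_size,)
topk_log_probs, topk_indexes = flat.topk(beam)  # best `beam` expansions overall

topk_hyp_indexes = (topk_indexes // vocab_size).tolist()   # which hypothesis to extend
topk_token_indexes = (topk_indexes % vocab_size).tolist()  # which token to append

for hyp_idx, token, lp in zip(topk_hyp_indexes, topk_token_indexes, topk_log_probs.tolist()):
    print(f"extend hypothesis {hyp_idx} with token {token} (log-prob {lp:.3f})")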
def fast_beam_search_one_best(
    model: nn.Module,
    streams: List[Stream],
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.

    A lattice is first obtained using fast beam search, and then
    the shortest path within the lattice is used as the final output.

    Args:
      model:
        An instance of `Transducer`.
      streams:
        A list of stream objects.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      processed_lens:
        A tensor of shape (N,) containing the number of processed frames
        in `encoder_out` before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    context_size = model.decoder.context_size
    vocab_size = model.decoder.vocab_size

    B, T, C = encoder_out.shape
    assert B == len(streams)

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    individual_streams = []
    for i in range(B):
        individual_streams.append(streams[i].rnnt_decoding_stream)
    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)

    encoder_out = model.joiner.encoder_proj(encoder_out)

    for t in range(T):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()
        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        contexts = contexts.to(torch.int64)
        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts, need_pad=False)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # current_encoder_out is of shape
        # (shape.NumElements(), 1, joiner_dim)
        # fmt: off
        current_encoder_out = torch.index_select(
            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
        )
        # fmt: on
        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        logits = logits.squeeze(1).squeeze(1)
        log_probs = logits.log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(processed_lens.tolist())

    best_path = one_best_decoding(lattice)
    hyps = get_texts(best_path)

    for i in range(B):
        streams[i].hyp = hyps[i]

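Both search functions above lean on k2's ragged shapes to map per-hypothesis (or per-context) rows back to the stream they belong to via `row_ids(1)`. A minimal self-contained illustration with toy sizes (not part of the commit):

import torch
import k2

# Toy example: 2 streams with 2 and 3 active hypotheses respectively.
row_splits = torch.tensor([0, 2, 5], dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(row_splits=row_splits, cached_tot_size=5)

# row_ids(1) tells, for every hypothesis row, which stream it came from.
print(shape.row_ids(1))  # tensor([0, 0, 1, 1, 1], dtype=torch.int32)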
def decode_one_chunk(
    model: nn.Module,
    streams: List[Stream],
    params: AttributeDict,
    sp: spm.SentencePieceProcessor,
    decoding_graph: Optional[k2.Fsa] = None,
) -> List[int]:
    device = next(model.parameters()).device

@@ -292,7 +506,8 @@ def decode_one_chunk(
        mode="constant",
        value=LOG_EPSILON,
    )
    # stack states of all streams
    # Stack states of all streams
    states = stack_states(state_list)

    encoder_out, encoder_out_lens, states = model.encoder.infer(
@@ -301,7 +516,6 @@ def decode_one_chunk(
        states=states,
        num_processed_frames=num_processed_frames,
    )
    encoder_out = model.joiner.encoder_proj(encoder_out)

    if params.decoding_method == "greedy_search":
        greedy_search(
@@ -309,20 +523,29 @@ def decode_one_chunk(
            streams=streams,
            encoder_out=encoder_out,
        )
    # elif params.decoding_method == "modified_beam_search":
    #     modified_beam_search(
    #         model=model,
    #         streams=streams,
    #         encoder_out=encoder_out,
    #         sp=sp,
    #         beam=params.beam_size,
    #     )
    elif params.decoding_method == "modified_beam_search":
        modified_beam_search(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
            beam=params.beam_size,
        )
    elif params.decoding_method == "fast_beam_search":
        fast_beam_search_one_best(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
            processed_lens=(num_processed_frames >> 2) + encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    else:
        raise ValueError(
            f"Unsupported decoding method: {params.decoding_method}"
        )

    # update cached states of each stream
    # Update cached states of each stream
    state_list = unstack_states(states)
    for i, s in enumerate(state_list):
        streams[i].states = s
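The `processed_lens` passed to fast_beam_search_one_best above mixes two frame rates: `num_processed_frames` counts feature frames already consumed in earlier chunks, while the lattice is built over encoder output frames. The right shift by 2 is an integer division by 4, which matches an encoder subsampling factor of 4 (an assumption about this model; adjust if the encoder subsamples differently). A toy illustration, not part of the commit:

import torch

num_processed_frames = torch.tensor([32, 64])  # feature frames seen so far (toy values)
encoder_out_lens = torch.tensor([8, 8])        # valid encoder frames in this chunk

processed_lens = (num_processed_frames >> 2) + encoder_out_lens
print(processed_lens)  # tensor([16, 24])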
@@ -355,9 +578,29 @@ def decode_dataset(
    model: nn.Module,
    params: AttributeDict,
    sp: spm.SentencePieceProcessor,
    decoding_graph: Optional[k2.Fsa] = None,
):
    """Decode dataset.

    Args:
      cuts:
        Lhotse CutSet containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The Transducer model.
      sp:
        The BPE model.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
        only when --decoding_method is fast_beam_search.

    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    device = next(model.parameters()).device

@@ -369,7 +612,12 @@ def decode_dataset(
    streams = []
    for num, cut in enumerate(cuts):
        # Each utterance has a Stream.
        stream = Stream(params=params, device=device, LOG_EPS=LOG_EPSILON)
        stream = Stream(
            params=params,
            decoding_graph=decoding_graph,
            device=device,
            LOG_EPS=LOG_EPSILON,
        )

        audio: np.ndarray = cut.load_audio()
        # audio.shape: (1, num_samples)
@@ -391,7 +639,7 @@ def decode_dataset(
                model=model,
                streams=streams,
                params=params,
                sp=sp,
                decoding_graph=decoding_graph,
            )

            for i in sorted(finished_streams, reverse=True):
@@ -411,7 +659,7 @@ def decode_dataset(
            model=model,
            streams=streams,
            params=params,
            sp=sp,
            decoding_graph=decoding_graph,
        )

        for i in sorted(finished_streams, reverse=True):
@@ -423,10 +671,17 @@ def decode_dataset(
            )
            del streams[i]

    if params.decoding_method == "greedy_search":
        return {"greedy_search": decode_results}
        key = "greedy_search"
    if params.decoding_method == "fast_beam_search":
        key = (
            f"beam_{params.beam}_"
            f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}"
        )
    else:
        return {f"beam_size_{params.beam_size}": decode_results}
        key = f"beam_size_{params.beam_size}"

    return {key: decode_results}
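For reference, the result-dict keys produced by the logic above look like the following (toy parameter values assumed: beam=4.0, max_contexts=4, max_states=8, beam_size=4; this snippet is illustrative and not part of the commit):

beam, max_contexts, max_states, beam_size = 4.0, 4, 8, 4

# fast_beam_search results are stored under a key that records the search settings:
print(f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}")
# -> beam_4.0_max_contexts_4_max_states_8

# modified_beam_search results are stored under:
print(f"beam_size_{beam_size}")
# -> beam_size_4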

def save_results(
@@ -483,6 +738,11 @@ def main():
    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

@@ -616,6 +876,11 @@ def main():
    model.eval()
    model.device = device

    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

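The decoding graph created above for fast_beam_search is simply the trivial graph over the token vocabulary, i.e. an FSA that accepts any token sequence, so the search is constrained only by beam, max_contexts and max_states rather than by an external LM. A minimal sketch (vocab_size = 500 is a toy value; the real value comes from the BPE model):

import torch
import k2

vocab_size = 500
device = torch.device("cpu")

# One self-loop per non-blank token, so any token sequence is accepted.
decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)
print(decoding_graph.num_arcs)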
@@ -633,6 +898,7 @@ def main():
        model=model,
        params=params,
        sp=sp,
        decoding_graph=decoding_graph,
    )

    save_results(