Add modified_beam_search for pruned_transducer_stateless/streaming_decode.py

pkufool 2022-07-22 20:03:43 +08:00
parent a8696b36fc
commit 1b6daecc63
3 changed files with 244 additions and 57 deletions

pruned_transducer_stateless/beam_search.py

@@ -751,7 +751,7 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     """Return a ragged shape with axes [utt][num_hyps].
 
     Args:
@@ -847,7 +847,7 @@ def modified_beam_search(
        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

-       hyps_shape = _get_hyps_shape(B).to(device)
+       hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]
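
The rename from _get_hyps_shape to get_hyps_shape makes the helper importable from streaming_decode.py. As a quick illustration (not part of the commit; toy hypothesis counts), the returned shape groups hypotheses by utterance:

    import torch
    import k2

    # Suppose utterance 0 holds 2 active hypotheses and utterance 1 holds 3.
    row_splits = torch.tensor([0, 2, 5], dtype=torch.int32)
    shape = k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=5
    )
    # Axes are [utt][num_hyps]; row_ids(1) maps each hypothesis back to its
    # utterance: [0, 0, 1, 1, 1]
    print(shape.row_ids(1))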

pruned_transducer_stateless/decode_stream.py

@@ -19,6 +19,7 @@ from typing import List, Optional, Tuple
 
 import k2
 import torch
+from beam_search import Hypothesis, HypothesisList
 
 from icefall.utils import AttributeDict
@@ -42,7 +43,8 @@ class DecodeStream(object):
           device:
             The device to run this stream.
         """
-        if decoding_graph is not None:
+        if params.decoding_method == "fast_beam_search":
+            assert decoding_graph is not None
             assert device == decoding_graph.device
 
         self.params = params
@@ -77,15 +79,23 @@ class DecodeStream(object):
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
+        elif params.decoding_method == "modified_beam_search":
+            self.hyps = HypothesisList()
+            self.hyps.add(
+                Hypothesis(
+                    ys=[params.blank_id] * params.context_size,
+                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                )
+            )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
             self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                 k2.RnntDecodingStream(decoding_graph)
             )
         else:
-            assert (
-                False
-            ), f"Decoding method :{params.decoding_method} do not support."
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
 
     @property
     def done(self) -> bool:
@@ -124,3 +134,14 @@ class DecodeStream(object):
             self._done = True
 
         return ret_features, ret_length
+
+    def decoding_result(self) -> List[int]:
+        """Obtain current decoding result."""
+        if self.params.decoding_method == "greedy_search":
+            return self.hyp[self.params.context_size :]  # noqa
+        elif self.params.decoding_method == "modified_beam_search":
+            best_hyp = self.hyps.get_most_probable(length_norm=True)
+            return best_hyp.ys[self.params.context_size :]  # noqa
+        else:
+            assert self.params.decoding_method == "fast_beam_search"
+            return self.hyp
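
For modified_beam_search, decoding_result() returns the length-normalized best hypothesis with the context_size leading blanks stripped. A small sketch (not from the commit; scores and token IDs are made up) of why length_norm=True matters, assuming get_most_probable divides each score by len(ys) as in icefall's beam_search.py:

    import torch
    from beam_search import Hypothesis, HypothesisList

    hyps = HypothesisList()
    # context_size = 2, so each ys begins with two blanks (ID 0).
    hyps.add(Hypothesis(ys=[0, 0, 25], log_prob=torch.tensor(-2.5)))
    hyps.add(Hypothesis(ys=[0, 0, 25, 31], log_prob=torch.tensor(-3.0)))

    # Unnormalized, the shorter path wins (-2.5 > -3.0); normalized by
    # length, the longer one does (-3.0 / 4 = -0.75 beats -2.5 / 3 = -0.83).
    best = hyps.get_most_probable(length_norm=True)
    print(best.ys[2:])  # [25, 31]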

pruned_transducer_stateless/streaming_decode.py

@@ -31,6 +31,7 @@ Usage:
 import argparse
 import logging
 import math
+import warnings
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -40,6 +41,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
@@ -114,10 +116,21 @@ def get_parser():
         "--decoding-method",
         type=str,
         default="greedy_search",
-        help="""Support only greedy_search and fast_beam_search now.
+        help="""Supported decoding methods are:
+        greedy_search
+        modified_beam_search
+        fast_beam_search
         """,
     )
 
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is modified_beam_search.""",
+    )
+
     parser.add_argument(
         "--beam",
         type=float,
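With the new flag in place, a modified_beam_search decode might be invoked like this (hypothetical epoch/avg/exp-dir values, assuming the script's usual checkpoint options):

    ./pruned_transducer_stateless/streaming_decode.py \
        --decoding-method modified_beam_search \
        --beam-size 4 \
        --epoch 29 \
        --avg 13 \
        --exp-dir pruned_transducer_stateless/exp
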
@@ -189,8 +202,17 @@ def greedy_search(
     model: nn.Module,
     encoder_out: torch.Tensor,
     streams: List[DecodeStream],
-) -> List[List[int]]:
+) -> None:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+      streams:
+        A list of Stream objects.
+    """
     assert len(streams) == encoder_out.size(0)
     assert encoder_out.ndim == 3
@@ -237,20 +259,163 @@ def greedy_search(
             need_pad=False,
         )
 
-    hyp_tokens = []
-    for stream in streams:
-        hyp_tokens.append(stream.hyp)
-    return hyp_tokens
+def modified_beam_search(
+    model: nn.Module,
+    encoder_out: torch.Tensor,
+    streams: List[DecodeStream],
+    beam: int = 4,
+):
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The RNN-T model.
+      encoder_out:
+        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
+        the encoder model.
+      streams:
+        A list of stream objects.
+      beam:
+        Number of active paths during the beam search.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert len(streams) == encoder_out.size(0)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+    batch_size = len(streams)
+    T = encoder_out.size(1)
+
+    B = [stream.hyps for stream in streams]
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.stack(
+            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, encoder_out_dim)
+
+        logits = model.joiner(current_encoder_out, decoder_out)
+        # logits is of shape (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    for i in range(batch_size):
+        streams[i].hyps = B[i]
 
 
-def fast_beam_search(
+def fast_beam_search_one_best(
     model: nn.Module,
     encoder_out: torch.Tensor,
     processed_lens: torch.Tensor,
-    decoding_streams: k2.RnntDecodingStreams,
-) -> List[List[int]]:
+    streams: List[DecodeStream],
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+) -> None:
+    """It limits the maximum number of symbols per frame to 1.
+
+    A lattice is first generated by Fsa-based beam search, then we get the
+    recognition by applying shortest path on the lattice.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      processed_lens:
+        A tensor of shape (N,) containing the number of processed frames
+        in `encoder_out` before padding.
+      streams:
+        A list of stream objects.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+    """
+    assert encoder_out.ndim == 3
 
     B, T, C = encoder_out.shape
+    assert B == len(streams)
+
+    context_size = model.decoder.context_size
+    vocab_size = model.decoder.vocab_size
+
+    config = k2.RnntDecodingConfig(
+        vocab_size=vocab_size,
+        decoder_history_len=context_size,
+        beam=beam,
+        max_contexts=max_contexts,
+        max_states=max_states,
+    )
+    individual_streams = []
+    for i in range(B):
+        individual_streams.append(streams[i].rnnt_decoding_stream)
+    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
 
     for t in range(T):
         # shape is a RaggedShape of shape (B, context)
         # contexts is a Tensor of shape (shape.NumElements(), context_size)
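
The per-utterance top-k in modified_beam_search above works because the hypothesis-level row splits are scaled by vocab_size: each flat index then encodes a (hypothesis, token) pair, recovered with // and %. A standalone numeric sketch (not from the commit; toy sizes and made-up scores):

    import torch
    import k2

    vocab_size = 3
    # Utterance 0 has 2 hypotheses, utterance 1 has 1, so the flattened
    # (num_hyps, vocab_size) scores hold 2*3 + 1*3 = 9 entries.
    log_probs = torch.arange(9, dtype=torch.float32)
    row_splits = torch.tensor([0, 2, 3], dtype=torch.int32) * vocab_size  # [0, 6, 9]
    shape = k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=log_probs.numel()
    )
    ragged = k2.RaggedTensor(shape=shape, value=log_probs)

    values, indexes = ragged[0].topk(2)  # best 2 of utterance 0's 6 entries
    hyp_idx = indexes // vocab_size   # which hypothesis each winner extends
    token_idx = indexes % vocab_size  # which token it appends
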
@@ -279,7 +444,9 @@ def fast_beam_search(
     lattice = decoding_streams.format_output(processed_lens.tolist())
     best_path = one_best_decoding(lattice)
     hyp_tokens = get_texts(best_path)
-    return hyp_tokens
+
+    for i in range(B):
+        streams[i].hyp = hyp_tokens[i]
 
 
 def decode_one_chunk(
@@ -305,8 +472,6 @@ def decode_one_chunk(
     features = []
     feature_lens = []
     states = []
-    rnnt_stream_list = []
     processed_lens = []
 
     for stream in decode_streams:
@@ -317,8 +482,6 @@ def decode_one_chunk(
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
-        if params.decoding_method == "fast_beam_search":
-            rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
     feature_lens = torch.tensor(feature_lens, device=device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -330,19 +493,13 @@ def decode_one_chunk(
     # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
-        feature_lens += tail_length - features.size(1)
-        features = torch.cat(
-            [
-                features,
-                torch.tensor(
-                    LOG_EPS, dtype=features.dtype, device=device
-                ).expand(
-                    features.size(0),
-                    tail_length - features.size(1),
-                    features.size(2),
-                ),
-            ],
-            dim=1,
-        )
+        pad_length = tail_length - features.size(1)
+        feature_lens += pad_length
+        features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, pad_length),
+            mode="constant",
+            value=LOG_EPS,
+        )
 
     states = [
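torch.nn.functional.pad pads dimensions starting from the last, so (0, 0, 0, pad_length) leaves the feature axis alone and pads the time axis on the right. A shape check with made-up sizes (e.g. right_context=0 and subsampling_factor=4 give tail_length = 7 + (2 + 0) * 4 = 15):

    import torch
    import torch.nn.functional as F

    features = torch.zeros(2, 10, 80)  # (N, T, C)
    pad_length = 15 - features.size(1)
    padded = F.pad(
        features, (0, 0, 0, pad_length), mode="constant", value=0.0
    )  # 0.0 stands in for LOG_EPS here
    assert padded.shape == (2, 15, 80)
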
@@ -362,22 +519,31 @@ def decode_one_chunk(
     )
 
     if params.decoding_method == "greedy_search":
-        hyp_tokens = greedy_search(model, encoder_out, decode_streams)
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
+        )
     elif params.decoding_method == "fast_beam_search":
-        config = k2.RnntDecodingConfig(
-            vocab_size=params.vocab_size,
-            decoder_history_len=params.context_size,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
-        )
-        decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
         processed_lens = processed_lens + encoder_out_lens
-        hyp_tokens = fast_beam_search(
-            model, encoder_out, processed_lens, decoding_streams
-        )
+        fast_beam_search_one_best(
+            model=model,
+            encoder_out=encoder_out,
+            processed_lens=processed_lens,
+            streams=decode_streams,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        modified_beam_search(
+            model=model,
+            streams=decode_streams,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
     else:
-        assert False
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
@@ -385,8 +551,6 @@ def decode_one_chunk(
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
         decode_streams[i].done_frames += encoder_out_lens[i]
-        if params.decoding_method == "fast_beam_search":
-            decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
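The state handling above splits each batched state tensor back into per-stream tensors with torch.unbind along the stream axis (dim=2), then reassigns one slice per stream. A standalone shape sketch (hypothetical sizes):

    import torch

    # Toy batched state with 4 streams stacked on dim=2.
    states = torch.zeros(2, 3, 4)
    per_stream = torch.unbind(states, dim=2)
    assert len(per_stream) == 4
    assert per_stream[0].shape == (2, 3)
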
@@ -469,13 +633,10 @@ def decode_dataset(
                 params=params, model=model, decode_streams=decode_streams
             )
             for i in sorted(finished_streams, reverse=True):
-                hyp = decode_streams[i].hyp
-                if params.decoding_method == "greedy_search":
-                    hyp = hyp[params.context_size :]  # noqa
                 decode_results.append(
                     (
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(hyp).split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]
@@ -489,24 +650,29 @@ def decode_dataset(
             params=params, model=model, decode_streams=decode_streams
         )
         for i in sorted(finished_streams, reverse=True):
-            hyp = decode_streams[i].hyp
-            if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(hyp).split(),
+                    sp.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]
 
-    key = "greedy_search"
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method == "greedy_search":
+        key = "greedy_search"
+    elif params.decoding_method == "fast_beam_search":
         key = (
             f"beam_{params.beam}_"
             f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}"
         )
+    elif params.decoding_method == "modified_beam_search":
+        key = f"beam_size_{params.beam_size}"
+    else:
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     return {key: decode_results}
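
For reference, with hypothetical parameter values the result-dict keys built above come out like this:

    # Hypothetical values; shows the key strings produced above.
    beam, max_contexts, max_states, beam_size = 4.0, 4, 32, 4
    print(f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}")
    # beam_4.0_max_contexts_4_max_states_32
    print(f"beam_size_{beam_size}")
    # beam_size_4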