Support modified beam search decoding for streaming inference with Emformer model.

yaozengwei 2022-04-19 22:00:47 +08:00
parent 0f45356ee6
commit 5228b44de7
2 changed files with 210 additions and 40 deletions
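In short: StreamList no longer takes blank_id; it takes a decoding_method string and instantiates one GreedySearchStream or ModifiedBeamSearchStream per utterance, and decoding results are read through a common `result` property. A minimal sketch of the new call path (illustrative only: the parameter values are made up, and the argparse definitions of --decoding-method and --beam-size are not shown in this excerpt):

    # Illustrative sketch, not part of the diff: dispatching on the new
    # decoding_method argument. StreamList maps the string to a per-stream
    # class (see the diff below).
    stream_list = StreamList(
        batch_size=2,
        context_size=2,  # context size of the RNN-T decoder (assumed value)
        decoding_method="modified_beam_search",  # or "greedy_search"
    )
    for s in stream_list.streams:
        assert s.context_size == 2  # each stream is a ModifiedBeamSearchStream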

File 1 of 2: the streaming decoding script

@@ -18,16 +18,23 @@
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from typing import List, Optional, Tuple

+import k2
 import numpy as np
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
-from streaming_feature_extractor import FeatureExtractionStream
+from streaming_feature_extractor import (
+    FeatureExtractionStream,
+    GreedySearchStream,
+    ModifiedBeamSearchStream,
+)
 from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
@@ -50,6 +57,7 @@ def get_parser():
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
+
     parser.add_argument(
         "--avg",
         type=int,
@@ -208,7 +216,7 @@ class StreamList(object):
         self,
         batch_size: int,
         context_size: int,
-        blank_id: int,
+        decoding_method: str,
     ):
         """
         Args:
@@ -216,14 +224,21 @@
             Size of this batch.
           context_size:
             Context size of the RNN-T decoder model.
-          blank_id:
-            The ID of the blank symbol of the BPE model.
+          decoding_method:
+            Decoding method. The possible values are:
+              - greedy_search
+              - modified_beam_search
         """
+        decoding_classes = {
+            "greedy_search": GreedySearchStream,
+            "modified_beam_search": ModifiedBeamSearchStream,
+        }
+        assert decoding_method in decoding_classes
+        cls = decoding_classes[decoding_method]
         self.streams = [
-            FeatureExtractionStream(
-                context_size=context_size, blank_id=blank_id
-            )
-            for _ in range(batch_size)
+            cls(context_size=context_size) for _ in range(batch_size)
         ]

     @property
@@ -238,7 +253,7 @@ class StreamList(object):
         audio_samples: List[torch.Tensor],
         sampling_rate: float,
     ):
-        """Feeed audio samples to each stream.
+        """Feed audio samples to each stream.

         Args:
           audio_samples:
             A list of 1-D tensors containing the audio samples for each
@@ -314,7 +329,7 @@
 def greedy_search(
     model: nn.Module,
-    streams: List[FeatureExtractionStream],
+    streams: List[GreedySearchStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -333,7 +348,15 @@ def greedy_search(
     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
     device = model.device
+    assert len(streams) == encoder_out.size(0)
+    assert encoder_out.ndim == 3
+    for s in streams:
+        if s.hyp is None:
+            s.hyp = Hypothesis(
+                ys=([blank_id] * context_size),
+                log_prob=torch.tensor([0.0], device=device),
+            )

     if streams[0].decoder_out is None:
         decoder_input = torch.tensor(
             [stream.hyp.ys[-context_size:] for stream in streams],
@@ -351,8 +374,6 @@ def greedy_search(
             dim=0,
         )

-    assert encoder_out.ndim == 3
-
     T = encoder_out.size(1)
     for t in range(T):
         current_encoder_out = encoder_out[:, t]
@@ -381,20 +402,132 @@ def greedy_search(
         )

     for k, s in enumerate(streams):
-        logging.info(
-            f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
-        )
+        logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")

     decoder_out_list = decoder_out.unbind(dim=0)
     for i, d in enumerate(decoder_out_list):
         streams[i].decoder_out = d
+
+
+def modified_beam_search(
+    model: nn.Module,
+    streams: List[ModifiedBeamSearchStream],
+    encoder_out: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    beam: int = 4,
+):
+    """
+    Args:
+      model:
+        The RNN-T model.
+      streams:
+        A list of stream objects.
+      encoder_out:
+        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
+        the encoder model.
+      sp:
+        The BPE model.
+      beam:
+        Number of active paths during the beam search.
+    """
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert len(streams) == encoder_out.size(0)
+
+    batch_size = len(streams)
+    for s in streams:
+        if len(s.hyps) == 0:
+            s.hyps.add(
+                Hypothesis(
+                    ys=[blank_id] * context_size,
+                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                )
+            )
+    B = [s.hyps for s in streams]
+
+    T = encoder_out.size(1)
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
+        # decoder_out is of shape (num_hyps, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, encoder_out_dim)
+
+        logits = model.joiner(current_encoder_out, decoder_out)
+        # logits is of shape (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    for i in range(batch_size):
+        streams[i].hyps = B[i]
+        logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
+
+
 def process_features(
     model: nn.Module,
     features: torch.Tensor,
     streams: List[FeatureExtractionStream],
+    params: AttributeDict,
     sp: spm.SentencePieceProcessor,
 ) -> None:
     """Process features for each stream in parallel.
@@ -406,6 +539,8 @@ def process_features(
         A 3-D tensor of shape (N, T, C).
       streams:
         A list of streams of size (N,).
+      params:
+        It is the return value of :func:`get_params`.
      sp:
        The BPE model.
    """
@@ -439,12 +574,25 @@ def process_features(
     for i, s in enumerate(state_list):
         streams[i].states = s

-    greedy_search(
-        model=model,
-        streams=streams,
-        encoder_out=encoder_out,
-        sp=sp,
-    )
+    if params.decoding_method == "greedy_search":
+        greedy_search(
+            model=model,
+            streams=streams,
+            encoder_out=encoder_out,
+            sp=sp,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        modified_beam_search(
+            model=model,
+            streams=streams,
+            encoder_out=encoder_out,
+            sp=sp,
+            beam=params.beam_size,
+        )
+    else:
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )


 def decode_batch(
@@ -479,7 +627,7 @@ def decode_batch(
     stream_list = StreamList(
         batch_size=batch_size,
         context_size=params.context_size,
-        blank_id=params.blank_id,
+        decoding_method=params.decoding_method,
     )

     while not streaming_audio_samples.done:
@@ -497,11 +645,12 @@ def decode_batch(
             model=model,
             features=features,
             streams=active_streams,
+            params=params,
             sp=sp,
         )

     results = []
     for s in stream_list.streams:
-        text = sp.decode(s.hyp.ys[params.context_size :])
+        text = sp.decode(s.result)
         results.append(text)
     return results
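The k2.RaggedTensor bookkeeping above batches the search over all streams and all of their hypotheses at once. Stripped of that batching, one frame of modified beam search for a single stream reduces to the following sketch (illustrative only: log_probs_fn stands in for decoder + joiner + log_softmax, and hypotheses are plain (ys, log_prob) pairs):

    def step_one_stream(hyps, log_probs_fn, beam, blank_id):
        # hyps: list of (ys, log_prob) pairs -- the currently active paths.
        # log_probs_fn(ys): per-token log-probabilities for the next frame.
        candidates = []
        for ys, log_prob in hyps:
            for token, token_log_prob in enumerate(log_probs_fn(ys)):
                # A blank emission keeps ys unchanged; any other token
                # extends the path (cf. `if new_token != blank_id` above).
                new_ys = ys if token == blank_id else ys + [token]
                candidates.append((new_ys, log_prob + token_log_prob))
        # Keep the `beam` highest-scoring candidates, like topk(beam) over
        # the flattened (hypothesis, token) scores.
        candidates.sort(key=lambda c: c[1], reverse=True)
        return candidates[:beam]

One detail this sketch omits: icefall's HypothesisList.add merges candidates that end up with the same token sequence by log-adding their scores, rather than keeping duplicates.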

File 2 of 2: streaming_feature_extractor.py

@@ -17,7 +17,7 @@
 from typing import List, Optional

 import torch
-from beam_search import Hypothesis
+from beam_search import Hypothesis, HypothesisList
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
@@ -41,14 +41,10 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
 class FeatureExtractionStream(object):
-    def __init__(self, context_size: int, blank_id: int = 0) -> None:
-        """Context size of the RNN-T decoder model."""
+    def __init__(
+        self,
+    ) -> None:
         self.feature_extractor = _create_streaming_feature_extractor()
-        self.hyp = Hypothesis(
-            ys=([blank_id] * context_size),
-            log_prob=torch.tensor([0.0]),
-        )  # for greedy search, will extend it to beam search

         # It contains a list of 1-D tensors representing the feature frames.
         self.feature_frames: List[torch.Tensor] = []
@@ -58,11 +54,6 @@ class FeatureExtractionStream(object):
         # encoder layer.
         self.states: Optional[List[List[torch.Tensor]]] = None

-        # For the RNN-T decoder, it contains the decoder output
-        # corresponding to the decoder input self.hyp.ys[-context_size:]
-        # Its shape is (decoder_out_dim,)
-        self.decoder_out: Optional[torch.Tensor] = None
-
         # After calling `self.input_finished()`, we set this flag to True
         self._done = False
@@ -85,9 +76,9 @@ class FeatureExtractionStream(object):
             check to ensure that the input sampling rate equals to the one
             used in the extractor. If they are not equal, then no resampling
             will be performed; instead an error will be thrown.
           waveform:
             A 1-D torch tensor of dtype torch.float32 containing audio samples.
             It should be on CPU.
         """
         self.feature_extractor.accept_waveform(
             sampling_rate=sampling_rate,
@@ -114,3 +105,33 @@ class FeatureExtractionStream(object):
             frame = self.feature_extractor.get_frame(self.num_fetched_frames)
             self.feature_frames.append(frame)
             self.num_fetched_frames += 1
+
+
+class GreedySearchStream(FeatureExtractionStream):
+    def __init__(self, context_size: int) -> None:
+        """FeatureExtractionStream class for greedy search."""
+        super().__init__()
+        self.context_size = context_size
+        self.hyp: Optional[Hypothesis] = None
+        # For the RNN-T decoder, it contains the decoder output
+        # corresponding to the decoder input self.hyp.ys[-context_size:]
+        # Its shape is (decoder_out_dim,)
+        self.decoder_out: Optional[torch.Tensor] = None
+
+    @property
+    def result(self) -> List[int]:
+        return self.hyp.ys[self.context_size :]
+
+
+class ModifiedBeamSearchStream(FeatureExtractionStream):
+    def __init__(self, context_size: int) -> None:
+        """FeatureExtractionStream class for modified beam search decoding."""
+        super().__init__()
+        self.context_size = context_size
+        self.hyps = HypothesisList()
+        self.best_hyp = None
+
+    @property
+    def result(self) -> List[int]:
+        best_hyp = self.hyps.get_most_probable(length_norm=True)
+        return best_hyp.ys[self.context_size :]
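The `result` property defers final selection to HypothesisList.get_most_probable(length_norm=True). Assuming the usual icefall semantics, length normalization divides each hypothesis's log probability by the length of its token sequence before taking the argmax, so longer transcripts are not penalized merely for accumulating more negative log-prob terms. A sketch of that selection (illustrative, operating on Hypothesis-like objects with ys and log_prob fields):

    def most_probable(hyps, context_size, length_norm=True):
        if length_norm:
            best = max(hyps, key=lambda h: h.log_prob / len(h.ys))
        else:
            best = max(hyps, key=lambda h: h.log_prob)
        # Strip the context_size leading blanks used to prime the decoder.
        return best.ys[context_size:]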