Fixed streaming decoding code for the Emformer model.

yaozengwei 2022-04-21 19:48:35 +08:00
parent e74654c2a2
commit cf0ce8db32
3 changed files with 74 additions and 98 deletions

beam_search.py

@@ -367,7 +367,7 @@ class HypothesisList(object):
         return ", ".join(s)


-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     """Return a ragged shape with axes [utt][num_hyps].

     Args:
@@ -431,7 +431,7 @@ def modified_beam_search(
        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)

-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

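For reference, the renamed get_hyps_shape builds a ragged shape with axes [utt][num_hyps], whose row splits record how many active hypotheses each utterance currently holds. A minimal sketch of that row-splits bookkeeping in plain PyTorch (the helper name below is illustrative; the real function additionally wraps the result in a k2.RaggedShape):

import torch

def hyps_row_splits(num_hyps_per_utt):
    # Row splits for a ragged shape with axes [utt][num_hyps]:
    # utterance i owns hypotheses row_splits[i] .. row_splits[i + 1].
    counts = torch.tensor(num_hyps_per_utt, dtype=torch.int32)
    zero = torch.zeros(1, dtype=torch.int32)
    return torch.cat([zero, torch.cumsum(counts, dim=0).to(torch.int32)])

# Three utterances holding 2, 4, and 1 active hypotheses:
print(hyps_row_splits([2, 4, 1]))  # tensor([0, 2, 6, 7], dtype=torch.int32)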
streaming_decode.py

@@ -28,16 +28,16 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
+from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
-from streaming_feature_extractor import (
-    FeatureExtractionStream,
-    GreedySearchStream,
-    ModifiedBeamSearchStream,
-)
+from streaming_feature_extractor import FeatureExtractionStream
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import AttributeDict, setup_logger
@@ -225,16 +225,12 @@ class StreamList(object):
              - greedy_search
              - modified_beam_search
        """
-        decoding_classes = {
-            "greedy_search": GreedySearchStream,
-            "modified_beam_search": ModifiedBeamSearchStream,
-        }
-        assert decoding_method in decoding_classes
-        cls = decoding_classes[decoding_method]
        self.streams = [
-            cls(context_size=context_size) for _ in range(batch_size)
+            FeatureExtractionStream(
+                context_size=context_size, decoding_method=decoding_method
+            )
+            for _ in range(batch_size)
        ]

    @property
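With the per-method stream subclasses gone, StreamList builds every stream the same way and passes the decoding method along as plain data. A usage sketch (assuming streaming_feature_extractor.py and its kaldifeat dependency are importable):

from streaming_feature_extractor import FeatureExtractionStream

# One stream per utterance in the batch; the search strategy is an argument,
# so this batching code no longer needs to know which search is running.
streams = [
    FeatureExtractionStream(context_size=2, decoding_method="greedy_search")
    for _ in range(4)
]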
@@ -325,7 +321,7 @@ class StreamList(object):

 def greedy_search(
     model: nn.Module,
-    streams: List[GreedySearchStream],
+    streams: List[FeatureExtractionStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -333,36 +329,31 @@ def greedy_search(
    Args:
      model:
        The RNN-T model.
-      stream:
-        A stream object.
+      streams:
+        A list of FeatureExtractionStream objects.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
        the encoder model.
      sp:
        The BPE model.
    """
-    blank_id = model.decoder.blank_id
-    context_size = model.decoder.context_size
-    device = model.device
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3
-    for s in streams:
-        if s.hyp is None:
-            s.hyp = Hypothesis(
-                ys=([blank_id] * context_size),
-                log_prob=torch.tensor([0.0], device=device),
-            )
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+    T = encoder_out.size(1)

    if streams[0].decoder_out is None:
+        for stream in streams:
+            stream.hyp = [blank_id] * context_size
        decoder_input = torch.tensor(
-            [stream.hyp.ys[-context_size:] for stream in streams],
+            [stream.hyp[-context_size:] for stream in streams],
            device=device,
            dtype=torch.int64,
        )
-        decoder_out = model.decoder(
-            decoder_input,
-            need_pad=False,
-        ).squeeze(1)
+        decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
        # decoder_out is of shape (N, decoder_out_dim)
    else:
        decoder_out = torch.stack(
@@ -370,7 +361,6 @@ def greedy_search(
            dim=0,
        )

-    T = encoder_out.size(1)
    for t in range(T):
        current_encoder_out = encoder_out[:, t]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
@@ -383,22 +373,23 @@ def greedy_search(
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
-                streams[i].hyp.ys.append(v)
+                streams[i].hyp.append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = torch.tensor(
-                [stream.hyp.ys[-context_size:] for stream in streams],
+                [stream.hyp[-context_size:] for stream in streams],
                device=device,
                dtype=torch.int64,
            )
-            decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
-                1
-            )
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=False,
+            ).squeeze(1)

-    for k, s in enumerate(streams):
-        logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")
+    for k, stream in enumerate(streams):
+        result = sp.decode(stream.decoding_result())
+        logging.info(f"Partial result {k}:\n{result}")

    decoder_out_list = decoder_out.unbind(dim=0)
    for i, d in enumerate(decoder_out_list):
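The emitted flag above is what batches the decoder updates: the prediction network is re-run, once for the whole batch, only if at least one stream produced a non-blank token at this frame. A toy illustration (blank_id and the token values are made up):

import torch

blank_id = 0
y = torch.tensor([0, 7, 0, 3])  # per-stream argmax tokens at frame t (toy)
emitted = bool((y != blank_id).any())
print(emitted)  # True -> recompute decoder_out once for all streams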
@@ -407,7 +398,7 @@ def greedy_search(

 def modified_beam_search(
     model: nn.Module,
-    streams: List[ModifiedBeamSearchStream],
+    streams: List[FeatureExtractionStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
     beam: int = 4,
@@ -426,36 +417,35 @@ def modified_beam_search(
      beam:
        Number of active paths during the beam search.
    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert len(streams) == encoder_out.size(0)
+
    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device
-    assert encoder_out.ndim == 3, encoder_out.shape
-    assert len(streams) == encoder_out.size(0)
-
    batch_size = len(streams)
+    T = encoder_out.size(1)

-    for s in streams:
-        if len(s.hyps) == 0:
-            s.hyps.add(
+    for stream in streams:
+        if len(stream.hyps) == 0:
+            stream.hyps.add(
                Hypothesis(
                    ys=[blank_id] * context_size,
                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                )
            )
+    B = [stream.hyps for stream in streams]

-    B = [s.hyps for s in streams]
-
-    T = encoder_out.size(1)
    for t in range(T):
        current_encoder_out = encoder_out[:, t]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)

-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

-        ys_log_probs = torch.cat(
-            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        ys_log_probs = torch.stack(
+            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
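The switch from torch.cat over reshape(1, 1) to torch.stack over reshape(1) is behavior-preserving; both build the same (num_hyps, 1) tensor, as this quick check shows:

import torch

log_probs = [torch.tensor([-0.5]), torch.tensor([-1.2]), torch.tensor([-0.1])]
a = torch.cat([p.reshape(1, 1) for p in log_probs])
b = torch.stack([p.reshape(1) for p in log_probs], dim=0)
assert torch.equal(a, b) and a.shape == (3, 1)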
@@ -516,7 +506,8 @@ def modified_beam_search(
                B[i].add(new_hyp)

        streams[i].hyps = B[i]
-        logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
+        result = sp.decode(streams[i].decoding_result())
+        logging.info(f"Partial result {i}:\n{result}")


 def process_features(
@@ -645,8 +636,8 @@ def decode_batch(
        sp=sp,
    )
    results = []
-    for s in stream_list.streams:
-        text = sp.decode(s.result)
+    for stream in stream_list.streams:
+        text = sp.decode(stream.decoding_result())
        results.append(text)
    return results

streaming_feature_extractor.py

@@ -14,11 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional
-
-import torch
-from beam_search import Hypothesis, HypothesisList
+from beam_search import HypothesisList
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+from typing import List, Optional
+
+import torch


 def _create_streaming_feature_extractor() -> OnlineFeature:
@@ -41,21 +40,28 @@ def _create_streaming_feature_extractor() -> OnlineFeature:


 class FeatureExtractionStream(object):
-    def __init__(
-        self,
-    ) -> None:
+    def __init__(self, context_size: int, decoding_method: str) -> None:
        self.feature_extractor = _create_streaming_feature_extractor()
        # It contains a list of 1-D tensors representing the feature frames.
        self.feature_frames: List[torch.Tensor] = []
        self.num_fetched_frames = 0
+        # After calling `self.input_finished()`, we set this flag to True
+        self._done = False
        # For the emformer model, it contains the states of each
        # encoder layer.
        self.states: Optional[List[List[torch.Tensor]]] = None
-        # After calling `self.input_finished()`, we set this flag to True
-        self._done = False
+        # Use different attributes for different decoding methods.
+        self.context_size = context_size
+        self.decoding_method = decoding_method
+        if decoding_method == "greedy_search":
+            self.hyp: Optional[List[int]] = None
+            self.decoder_out: Optional[torch.Tensor] = None
+        elif decoding_method == "modified_beam_search":
+            self.hyps = HypothesisList()
+        else:
+            raise ValueError(f"Unsupported decoding method: {decoding_method}")

    def accept_waveform(
        self,
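A quick sanity check of the per-method state this constructor sets up (a sketch, assuming the module and its beam_search/kaldifeat dependencies import cleanly):

from streaming_feature_extractor import FeatureExtractionStream

g = FeatureExtractionStream(context_size=2, decoding_method="greedy_search")
assert g.hyp is None and g.decoder_out is None  # filled lazily by greedy_search()

m = FeatureExtractionStream(context_size=2, decoding_method="modified_beam_search")
assert len(m.hyps) == 0  # empty HypothesisList, seeded on the first decode call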
@@ -106,32 +112,11 @@ class FeatureExtractionStream(object):
            self.feature_frames.append(frame)
            self.num_fetched_frames += 1

-
-class GreedySearchStream(FeatureExtractionStream):
-    def __init__(self, context_size: int) -> None:
-        """FeatureExtractionStream class for greedy search."""
-        super().__init__()
-        self.context_size = context_size
-        # For the RNN-T decoder, it contains the decoder output
-        # corresponding to the decoder input self.hyp.ys[-context_size:]
-        # Its shape is (decoder_out_dim,)
-        self.hyp: Hypothesis = None
-        self.decoder_out: Optional[torch.Tensor] = None
-
-    @property
-    def result(self) -> List[int]:
-        return self.hyp.ys[self.context_size :]
-
-
-class ModifiedBeamSearchStream(FeatureExtractionStream):
-    def __init__(self, context_size: int) -> None:
-        """FeatureExtractionStream class for modified beam search decoding."""
-        super().__init__()
-        self.context_size = context_size
-        self.hyps = HypothesisList()
-        self.best_hyp = None
-
-    @property
-    def result(self) -> List[int]:
-        best_hyp = self.hyps.get_most_probable(length_norm=True)
-        return best_hyp.ys[self.context_size :]
+    def decoding_result(self) -> List[int]:
+        """Obtain current decoding result."""
+        if self.decoding_method == "greedy_search":
+            return self.hyp[self.context_size :]
+        else:
+            assert self.decoding_method == "modified_beam_search"
+            best_hyp = self.hyps.get_most_probable(length_norm=True)
+            return best_hyp.ys[self.context_size :]
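decoding_result() drops the context_size blanks that seed the decoder input. A toy check of that slicing for the greedy case (token values are made up):

context_size = 2
blank_id = 0
hyp = [blank_id] * context_size + [25, 7, 133]  # as maintained by greedy_search
assert hyp[context_size:] == [25, 7, 133]  # token IDs handed to sp.decode()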