Fixed streaming decoding code for the Emformer model.

yaozengwei 2022-04-21 19:48:35 +08:00
parent e74654c2a2
commit cf0ce8db32
3 changed files with 74 additions and 98 deletions

View File: beam_search.py

@@ -367,7 +367,7 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     """Return a ragged shape with axes [utt][num_hyps].
 
     Args:
@@ -431,7 +431,7 @@ def modified_beam_search(
         current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
 
-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)
 
         A = [list(b) for b in B]
         B = [HypothesisList() for _ in range(batch_size)]
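
Note: the underscore prefix is dropped because the helper is now imported by the streaming decode script in the next file. Its body is unchanged and therefore not shown in this diff; the sketch below is a plausible reconstruction from the docstring alone (create_ragged_shape2 is the usual way to build a two-axis ragged shape in k2, but treat the exact body as an assumption):

from typing import List

import k2
import torch


def get_hyps_shape_sketch(hyps: List) -> k2.RaggedShape:
    """Return a ragged shape with axes [utt][num_hyps]."""
    # Row i has len(hyps[i]) entries, one per active hypothesis of utterance i.
    num_hyps = torch.tensor([0] + [len(h) for h in hyps])
    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
    return k2.ragged.create_ragged_shape2(
        row_splits=row_splits,
        row_ids=None,
        cached_tot_size=row_splits[-1].item(),
    )

For example, two utterances with 3 and 1 active hypotheses give the shape [ [x x x] [x] ].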

View File

@@ -28,16 +28,16 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
+from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
-from streaming_feature_extractor import (
-    FeatureExtractionStream,
-    GreedySearchStream,
-    ModifiedBeamSearchStream,
-)
+from streaming_feature_extractor import FeatureExtractionStream
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import AttributeDict, setup_logger
@@ -225,16 +225,12 @@ class StreamList(object):
            - greedy_search
            - modified_beam_search
        """
-        decoding_classes = {
-            "greedy_search": GreedySearchStream,
-            "modified_beam_search": ModifiedBeamSearchStream,
-        }
-        assert decoding_method in decoding_classes
-        cls = decoding_classes[decoding_method]
         self.streams = [
-            cls(context_size=context_size) for _ in range(batch_size)
+            FeatureExtractionStream(
+                context_size=context_size, decoding_method=decoding_method
+            )
+            for _ in range(batch_size)
         ]
 
     @property
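
Note: with the subclass dispatch gone, every stream is the same class and the decoding method is plain data. A hypothetical construction, assuming (as this hunk suggests) that StreamList.__init__ takes batch_size, context_size, and decoding_method:

stream_list = StreamList(
    batch_size=4,
    context_size=2,
    decoding_method="modified_beam_search",
)
assert len(stream_list.streams) == 4  # one FeatureExtractionStream per utterance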
@@ -325,7 +321,7 @@
 
 def greedy_search(
     model: nn.Module,
-    streams: List[GreedySearchStream],
+    streams: List[FeatureExtractionStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -333,36 +329,31 @@ def greedy_search(
     Args:
       model:
         The RNN-T model.
-      stream:
-        A stream object.
+      streams:
+        A list of GreedySearchDecodingStream objects.
       encoder_out:
         A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
         the encoder model.
       sp:
         The BPE model.
     """
-    blank_id = model.decoder.blank_id
-    context_size = model.decoder.context_size
-    device = model.device
     assert len(streams) == encoder_out.size(0)
     assert encoder_out.ndim == 3
 
-    for s in streams:
-        if s.hyp is None:
-            s.hyp = Hypothesis(
-                ys=([blank_id] * context_size),
-                log_prob=torch.tensor([0.0], device=device),
-            )
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+    T = encoder_out.size(1)
+
     if streams[0].decoder_out is None:
+        for stream in streams:
+            stream.hyp = [blank_id] * context_size
         decoder_input = torch.tensor(
-            [stream.hyp.ys[-context_size:] for stream in streams],
+            [stream.hyp[-context_size:] for stream in streams],
             device=device,
             dtype=torch.int64,
         )
-        decoder_out = model.decoder(
-            decoder_input,
-            need_pad=False,
-        ).squeeze(1)
+        decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
         # decoder_out is of shape (N, decoder_out_dim)
     else:
         decoder_out = torch.stack(
@@ -370,7 +361,6 @@ def greedy_search(
             dim=0,
         )
 
-    T = encoder_out.size(1)
     for t in range(T):
         current_encoder_out = encoder_out[:, t]
         # current_encoder_out's shape: (batch_size, encoder_out_dim)
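
Note: the lines between this hunk and the next are unchanged and therefore omitted; they run the joiner on the current frame and greedily pick one token per stream. A sketch of that elided step in the style of standard icefall transducer recipes (the exact joiner call is an assumption, not shown in this diff):

logits = model.joiner(current_encoder_out, decoder_out)  # (batch_size, vocab_size)
y = logits.argmax(dim=1).tolist()  # greedy pick; blank_id means no emission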
@@ -383,22 +373,23 @@ def greedy_search(
         emitted = False
         for i, v in enumerate(y):
             if v != blank_id:
-                streams[i].hyp.ys.append(v)
+                streams[i].hyp.append(v)
                 emitted = True
-
         if emitted:
             # update decoder output
             decoder_input = torch.tensor(
-                [stream.hyp.ys[-context_size:] for stream in streams],
+                [stream.hyp[-context_size:] for stream in streams],
                 device=device,
                 dtype=torch.int64,
             )
-            decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
-                1
-            )
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=False,
+            ).squeeze(1)
 
-    for k, s in enumerate(streams):
-        logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")
+    for k, stream in enumerate(streams):
+        result = sp.decode(stream.decoding_result())
+        logging.info(f"Partial result {k}:\n{result}")
 
     decoder_out_list = decoder_out.unbind(dim=0)
     for i, d in enumerate(decoder_out_list):
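
Note: the hunk is cut off just before the loop body. Given the `if streams[0].decoder_out is None ... else` branch above, the body evidently writes each row of decoder_out back onto its stream so the decoder is not re-run across chunk boundaries. The implied (not shown) body would be along these lines:

decoder_out_list = decoder_out.unbind(dim=0)
for i, d in enumerate(decoder_out_list):
    streams[i].decoder_out = d  # consumed by the `else` branch on the next chunk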
@@ -407,7 +398,7 @@
 
 def modified_beam_search(
     model: nn.Module,
-    streams: List[ModifiedBeamSearchStream],
+    streams: List[FeatureExtractionStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
     beam: int = 4,
@@ -426,36 +417,35 @@ def modified_beam_search(
       beam:
         Number of active paths during the beam search.
     """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert len(streams) == encoder_out.size(0)
+
     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
     device = model.device
-
-    assert encoder_out.ndim == 3, encoder_out.shape
-    assert len(streams) == encoder_out.size(0)
-
     batch_size = len(streams)
+    T = encoder_out.size(1)
 
-    for s in streams:
-        if len(s.hyps) == 0:
-            s.hyps.add(
+    for stream in streams:
+        if len(stream.hyps) == 0:
+            stream.hyps.add(
                 Hypothesis(
                     ys=[blank_id] * context_size,
                     log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 )
             )
-    B = [s.hyps for s in streams]
 
-    T = encoder_out.size(1)
+    B = [stream.hyps for stream in streams]
     for t in range(T):
         current_encoder_out = encoder_out[:, t]
         # current_encoder_out's shape: (batch_size, encoder_out_dim)
 
-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)
 
         A = [list(b) for b in B]
         B = [HypothesisList() for _ in range(batch_size)]
 
-        ys_log_probs = torch.cat(
-            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        ys_log_probs = torch.stack(
+            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
         )  # (num_hyps, 1)
 
         decoder_input = torch.tensor(
@@ -516,7 +506,8 @@ def modified_beam_search(
             B[i].add(new_hyp)
 
         streams[i].hyps = B[i]
-        logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
+        result = sp.decode(streams[i].decoding_result())
+        logging.info(f"Partial result {i}:\n{result}")
 
 
 def process_features(
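
Note: the ys_log_probs rewrite above is shape-preserving: concatenating (1, 1) tensors along dim 0 and stacking (1,) tensors along a new dim 0 both yield (num_hyps, 1). A quick self-contained check:

import torch

log_probs = [torch.tensor([-0.5]), torch.tensor([-1.2]), torch.tensor([-0.3])]
old = torch.cat([p.reshape(1, 1) for p in log_probs])        # shape (3, 1)
new = torch.stack([p.reshape(1) for p in log_probs], dim=0)  # shape (3, 1)
assert torch.equal(old, new)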
@@ -645,8 +636,8 @@ def decode_batch(
         sp=sp,
     )
 
     results = []
-    for s in stream_list.streams:
-        text = sp.decode(s.result)
+    for stream in stream_list.streams:
+        text = sp.decode(stream.decoding_result())
         results.append(text)
     return results

View File: streaming_feature_extractor.py

@@ -14,11 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Optional
+
+import torch
-from beam_search import Hypothesis, HypothesisList
+from beam_search import HypothesisList
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
-from typing import List, Optional
-
-import torch
 
 
 def _create_streaming_feature_extractor() -> OnlineFeature:
@@ -41,21 +40,28 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
 
 class FeatureExtractionStream(object):
-    def __init__(
-        self,
-    ) -> None:
+    def __init__(self, context_size: int, decoding_method: str) -> None:
         self.feature_extractor = _create_streaming_feature_extractor()
         # It contains a list of 1-D tensors representing the feature frames.
         self.feature_frames: List[torch.Tensor] = []
         self.num_fetched_frames = 0
-        # After calling `self.input_finished()`, we set this flag to True
-        self._done = False
 
         # For the emformer model, it contains the states of each
         # encoder layer.
         self.states: Optional[List[List[torch.Tensor]]] = None
 
+        # After calling `self.input_finished()`, we set this flag to True
+        self._done = False
+
+        # It uses different attributes for different decoding methods.
+        self.context_size = context_size
+        self.decoding_method = decoding_method
+        if decoding_method == "greedy_search":
+            self.hyp: List[int] = None
+            self.decoder_out: Optional[torch.Tensor] = None
+        elif decoding_method == "modified_beam_search":
+            self.hyps = HypothesisList()
+        else:
+            raise ValueError(f"Unsupported decoding method: {decoding_method}")
 
     def accept_waveform(
         self,
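
Note: selecting the per-method state in the constructor means an unsupported method now fails at stream creation instead of deep inside decoding. Hypothetical usage, with names taken from the diff:

stream = FeatureExtractionStream(context_size=2, decoding_method="greedy_search")
assert stream.hyp is None and stream.decoder_out is None  # filled in by greedy_search()

FeatureExtractionStream(context_size=2, decoding_method="fast_beam_search")
# raises ValueError: Unsupported decoding method: fast_beam_search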
@@ -106,32 +112,11 @@ class FeatureExtractionStream(object):
             self.feature_frames.append(frame)
             self.num_fetched_frames += 1
 
-
-class GreedySearchStream(FeatureExtractionStream):
-    def __init__(self, context_size: int) -> None:
-        """FeatureExtractionStream class for greedy search."""
-        super().__init__()
-        self.context_size = context_size
-        # For the RNN-T decoder, it contains the decoder output
-        # corresponding to the decoder input self.hyp.ys[-context_size:]
-        # Its shape is (decoder_out_dim,)
-        self.hyp: Hypothesis = None
-        self.decoder_out: Optional[torch.Tensor] = None
-
-    @property
-    def result(self) -> List[int]:
-        return self.hyp.ys[self.context_size :]
-
-
-class ModifiedBeamSearchStream(FeatureExtractionStream):
-    def __init__(self, context_size: int) -> None:
-        """FeatureExtractionStream class for modified beam search decoding."""
-        super().__init__()
-        self.context_size = context_size
-        self.hyps = HypothesisList()
-        self.best_hyp = None
-
-    @property
-    def result(self) -> List[int]:
-        best_hyp = self.hyps.get_most_probable(length_norm=True)
-        return best_hyp.ys[self.context_size :]
+    def decoding_result(self) -> List[int]:
+        """Obtain current decoding result."""
+        if self.decoding_method == "greedy_search":
+            return self.hyp[self.context_size :]
+        else:
+            assert self.decoding_method == "modified_beam_search"
+            best_hyp = self.hyps.get_most_probable(length_norm=True)
+            return best_hyp.ys[self.context_size :]
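
Note: both branches slice off the first context_size tokens, which are the blanks seeded into every hypothesis purely as decoder context. A worked example of the greedy branch:

blank_id, context_size = 0, 2
hyp = [blank_id] * context_size  # seeded context: [0, 0]
hyp += [23, 55, 102]             # token ids emitted during decoding
result = hyp[context_size:]      # [23, 55, 102], then passed to sp.decode()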