Parallel streaming decode with greedy search.

Fangjun Kuang 2022-04-12 12:58:20 +08:00
parent 85db12caac
commit 985707f38a
5 changed files with 427 additions and 148 deletions

View File

@@ -27,6 +27,66 @@ from torchaudio.models import Emformer as _Emformer
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: List[List[torch.Tensor]],
) -> List[List[List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list of lists of tensors. ``len(states)`` equals the number of
layers in the emformer. ``states[i]`` contains the states for
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
"""
batch_size = states[0][0].size(1)
num_layers = len(states)
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [[] for _ in range(num_layers)]
for li, layer in enumerate(states):
for s in layer:
s_list = s.unbind(dim=1)
for bi, b in enumerate(ans):
b[li].append(s_list[bi].unsqueeze(dim=1))
return ans
def stack_states(
state_list: List[List[List[torch.Tensor]]],
) -> List[List[torch.Tensor]]:
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
emformer when those utterances are formed into a batch.
Note:
It is the inverse of :func:`unstack_states`.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the emformer model for a single utterance.
Returns:
Return a new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned state.
"""
ans = []
for layer in state_list[0]:
# layer is a list of tensors
ans.append([s for s in layer])
for states in state_list[1:]:
for li, layer in enumerate(states):
for si, s in enumerate(layer):
ans[li][si] = torch.cat([ans[li][si], s], dim=1)
return ans
class Emformer(EncoderInterface):
"""This is just a simple wrapper around torchaudio.models.Emformer.
We may replace it with our own implementation some time later.
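A minimal round-trip sketch of the two helpers above. The toy state layout (two layers, each holding one ``(T, N, C)`` cache and one ``(C, N)`` tensor) is an assumption chosen to match the docstrings, not code from this commit:

# Illustration only; assumes emformer.py above is on the import path.
import torch
from emformer import stack_states, unstack_states

T, N, C = 4, 3, 8
states = [
    [torch.rand(T, N, C), torch.rand(C, N)]  # one layer's state
    for _ in range(2)  # two layers
]

per_utt = unstack_states(states)
assert len(per_utt) == N
assert per_utt[0][0][0].shape == (T, 1, C)  # batch dim kept with size 1

restacked = stack_states(per_utt)
for layer, layer2 in zip(states, restacked):
    for s, s2 in zip(layer, layer2):
        assert torch.equal(s, s2)  # exact round trip, no numeric ops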

View File

@@ -18,16 +18,16 @@
import argparse
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from emformer import LOG_EPSILON
from streaming_feature_extractor import Stream
from emformer import LOG_EPSILON, stack_states, unstack_states
from streaming_feature_extractor import FeatureExtractionStream
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@@ -158,9 +158,166 @@ def get_parser():
return parser
class StreamingAudioSamples(object):
"""This class takes as input a list of audio samples and returns
them in a streaming fashion.
"""
def __init__(self, samples: List[torch.Tensor]) -> None:
"""
Args:
samples:
A list of audio samples. Each entry is a 1-D tensor of dtype
torch.float32, containing the audio samples of an utterance.
"""
self.samples = samples
self.cur_indexes = [0] * len(self.samples)
@property
def done(self) -> bool:
"""Return True if all samples have been processed.
Return False otherwise.
"""
for i, samples in zip(self.cur_indexes, self.samples):
if i < samples.numel():
return False
return True
def get_next(self) -> List[torch.Tensor]:
"""Return a list of audio samples. Each entry may have different
lengths. It is OK if an entry contains no samples at all, which
means it reaches the end of the utterance.
"""
ans = []
# Note: Either branch is fine. The purpose is to simulate streaming
if False:
num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
else:
num = [1024] * len(self.samples)
for i in range(len(self.samples)):
start = self.cur_indexes[i]
end = start + num[i]
self.cur_indexes[i] = end
s = self.samples[i][start:end]
ans.append(s)
return ans
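A toy driver (not part of the commit) showing the intended call pattern:

# Illustration only.
import torch

source = StreamingAudioSamples([torch.rand(3000), torch.rand(1500)])
while not source.done:
    chunks = source.get_next()  # here: up to 1024 samples per utterance
    # Once an utterance is exhausted, its entry is an empty tensor,
    # which downstream code treats as end-of-utterance.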
class StreamList(object):
def __init__(
self,
batch_size: int,
context_size: int,
blank_id: int,
):
"""
Args:
batch_size:
Size of this batch.
context_size:
Context size of the RNN-T decoder model.
blank_id:
The ID of the blank symbol of the BPE model.
"""
self.streams = [
FeatureExtractionStream(
context_size=context_size, blank_id=blank_id
)
for _ in range(batch_size)
]
@property
def done(self) -> bool:
"""Return True if all streams have reached end of utterance.
That is, no more audio samples are available for all utterances.
"""
return all(stream.done for stream in self.streams)
def accept_waveform(
self,
audio_samples: List[torch.Tensor],
sampling_rate: float,
):
"""Feeed audio samples to each stream.
Args:
audio_samples:
A list of 1-D tensors containing the audio samples for each
utterance in the batch. If an entry is empty, it means
end-of-utterance has been reached.
sampling_rate:
Sampling rate of the given audio samples.
"""
assert len(audio_samples) == len(self.streams)
for stream, samples in zip(self.streams, audio_samples):
if stream.done:
assert samples.numel() == 0
continue
stream.accept_waveform(
sampling_rate=sampling_rate,
waveform=samples,
)
if samples.numel() == 0:
stream.input_finished()
def build_batch(
self,
chunk_length: int,
segment_length: int,
) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
"""
Args:
chunk_length:
Number of frames for each chunk. It equals
``segment_length + right_context_length``.
segment_length:
Number of frames for each segment.
Returns:
Return a tuple containing:
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
- active_streams, a list of active streams. We say a stream is
active when it has enough feature frames to be fed into the
encoder model.
"""
feature_list = []
stream_list = []
for stream in self.streams:
if len(stream.feature_frames) >= chunk_length:
# chunk is a list of tensors, each of which
# has shape (1, feature_dim)
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(chunk, dim=0).unsqueeze(0)
feature_list.append(features)
stream_list.append(stream)
elif stream.done and len(stream.feature_frames) > 0:
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(chunk, dim=0).unsqueeze(0)
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(1)),
value=LOG_EPSILON,
)
feature_list.append(features)
stream_list.append(stream)
if len(feature_list) == 0:
return None, None
features = torch.cat(feature_list, dim=0)
return features, stream_list
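The padding branch above fills missing frames with ``LOG_EPSILON`` so that a short final chunk still has shape ``(1, chunk_length, C)``. A standalone sketch of just that step (toy shapes, with an assumed feature_dim of 80):

# Illustration only.
import math
import torch

LOG_EPSILON = math.log(1e-10)
chunk_length = 23
features = torch.rand(1, 17, 80)  # (1, T, C) with T < chunk_length
padded = torch.nn.functional.pad(
    features,
    # (left, right) pads for the feature dim, then the time dim
    (0, 0, 0, chunk_length - features.size(1)),
    value=LOG_EPSILON,
)
assert padded.shape == (1, 23, 80)  # the 6 trailing frames are LOG_EPSILON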
def greedy_search(
model: nn.Module,
stream: Stream,
streams: List[FeatureExtractionStream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
@@ -171,7 +328,7 @@ def greedy_search(
stream:
A stream object.
encoder_out:
A 2-D tensor of shape (T, encoder_out_dim) containing the output of
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
@@ -180,59 +337,130 @@
context_size = model.decoder.context_size
device = model.device
if stream.decoder_out is None:
if streams[0].decoder_out is None:
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:]],
[stream.hyp.ys[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
stream.decoder_out = model.decoder(
decoder_out = model.decoder(
decoder_input,
need_pad=False,
).unsqueeze(1)
# stream.decoder_out is of shape (1, 1, decoder_out_dim)
# decoder_out is of shape (N, 1, decoder_out_dim)
else:
decoder_out = torch.cat(
[stream.decoder_out for stream in streams],
dim=0,
)
assert encoder_out.ndim == 2
assert encoder_out.ndim == 3
T = encoder_out.size(0)
T = encoder_out.size(1)
for t in range(T):
current_encoder_out = encoder_out[t].reshape(
1, 1, 1, encoder_out.size(-1)
)
logits = model.joiner(current_encoder_out, stream.decoder_out)
# logits is of shape (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y == blank_id:
continue
stream.hyp.ys.append(y)
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:]],
device=device,
dtype=torch.int64,
)
logits = model.joiner(current_encoder_out, decoder_out)
# logits' shape: (batch_size, 1, 1, vocab_size)
stream.decoder_out = model.decoder(
decoder_input,
need_pad=False,
).unsqueeze(1)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.ys.append(v)
emitted = True
logging.info(
f"Partial result:\n{sp.decode(stream.hyp.ys[context_size:])}"
)
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input, need_pad=False
).unsqueeze(1)
for k, s in enumerate(streams):
logging.info(
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
)
decoder_out_list = decoder_out.unbind(dim=0)
for i, d in enumerate(decoder_out_list):
streams[i].decoder_out = d.unsqueeze(0)
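In isolation, the per-frame update is: take the argmax for every stream, append non-blank symbols, and rebuild the decoder output only if at least one stream emitted. A toy sketch with a made-up 5-symbol vocabulary (the logits and hypotheses here are fabricated for illustration):

# Illustration only.
import torch

blank_id = 0
hyps = [[], []]  # token history per stream
logits = torch.tensor(
    [
        [0.1, 0.2, 3.0, 0.1, 0.1],  # stream 0: argmax is 2 -> emit
        [9.0, 0.1, 0.2, 0.1, 0.1],  # stream 1: argmax is blank -> skip
    ]
)  # (batch_size, vocab_size)

y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
    if v != blank_id:
        hyps[i].append(v)
        emitted = True
assert hyps == [[2], []] and emitted
# Only now would greedy_search recompute decoder_out for the whole batch.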
def process_feature_frames(
def process_features(
model: nn.Module,
stream: Stream,
features: torch.Tensor,
streams: List[FeatureExtractionStream],
sp: spm.SentencePieceProcessor,
):
"""Process the feature frames contained in ``stream.feature_frames``.
) -> None:
"""Process features for each stream in parallel.
Args:
model:
The RNN-T model.
stream:
The stream corresponding to the input audio samples.
features:
A 3-D tensor of shape (N, T, C).
streams:
A list of streams; ``len(streams)`` equals ``N``.
sp:
The BPE model.
"""
assert features.ndim == 3
assert features.size(0) == len(streams)
batch_size = features.size(0)
device = model.device
features = features.to(device)
feature_lens = torch.full(
(batch_size,),
fill_value=features.size(1),
device=device,
)
if streams[0].states is None:
states = None
else:
state_list = [stream.states for stream in streams]
states = stack_states(state_list)
(encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
features,
feature_lens,
states,
)
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
)
def decode_batch(
batched_samples: List[torch.Tensor],
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> List[str]:
"""
Args:
batched_samples:
A list of 1-D tensors containing the audio samples of each utterance.
model:
The RNN-T model.
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
@@ -241,102 +469,41 @@ def process_feature_frames(
right_context_length = model.encoder.right_context_length
# We add 3 here since the subsampling method uses
# ((len - 1) // 2 - 1) // 2
chunk_length = (segment_length + 3) + right_context_length
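# Worked example, using the values from test_emformer_streaming_forward
# below: segment_length = 16 and right_context_length = 4 give
# chunk_length = (16 + 3) + 4 = 23 input frames, and the subsampling
# maps 23 -> ((23 - 1) // 2 - 1) // 2 = 5 output frames.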
device = model.device
while len(stream.feature_frames) >= chunk_length:
# a list of tensor, each with a shape (1, feature_dim)
this_chunk = stream.feature_frames[:chunk_length]
batch_size = len(batched_samples)
streaming_audio_samples = StreamingAudioSamples(batched_samples)
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(this_chunk, dim=0).to(device) # (T, feature_dim)
features = features.unsqueeze(0) # (1, T, feature_dim)
feature_lens = torch.tensor([features.size(1)], device=device)
(
encoder_out,
encoder_out_lens,
stream.states,
) = model.encoder.streaming_forward(
features,
feature_lens,
stream.states,
stream_list = StreamList(
batch_size=batch_size,
context_size=params.context_size,
blank_id=params.blank_id,
)
while not streaming_audio_samples.done:
samples = streaming_audio_samples.get_next()
stream_list.accept_waveform(
audio_samples=samples,
sampling_rate=params.sampling_rate,
)
greedy_search(
model=model,
stream=stream,
encoder_out=encoder_out[0],
sp=sp,
features, active_streams = stream_list.build_batch(
chunk_length=chunk_length,
segment_length=segment_length,
)
if stream.feature_extractor.is_last_frame(stream.num_fetched_frames - 1):
assert len(stream.feature_frames) < chunk_length
if len(stream.feature_frames) > 0:
this_chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(this_chunk, dim=0) # (T, feature_dim)
features = features.to(device).unsqueeze(0) # (1, T, feature_dim)
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(1)),
value=LOG_EPSILON,
)
feature_lens = torch.tensor([features.size(1)], device=device)
(
encoder_out,
encoder_out_lens,
stream.states,
) = model.encoder.streaming_forward(
features,
feature_lens,
stream.states,
)
greedy_search(
if features is not None:
process_features(
model=model,
stream=stream,
encoder_out=encoder_out[0],
features=features,
streams=active_streams,
sp=sp,
)
def decode_one_utterance(
audio_samples: torch.Tensor,
model: nn.Module,
stream: Stream,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
):
"""Decode one utterance.
Args:
audio_samples:
A 1-D float32 tensor of shape (num_samples,) containing the
audio samples.
model:
The RNN-T model.
feature_extractor:
The feature extractor.
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
i = 0
num_samples = audio_samples.size(0)
while i < num_samples:
# Simulate streaming.
this_chunk_num_samples = torch.randint(2000, 5000, (1,)).item()
this_chunk_samples = audio_samples[i : (i + this_chunk_num_samples)]
i += this_chunk_num_samples
stream.accept_waveform(
sampling_rate=params.sampling_rate,
waveform=this_chunk_samples,
)
process_feature_frames(model=model, stream=stream, sp=sp)
stream.input_finished()
process_feature_frames(model=model, stream=stream, sp=sp)
results = []
for s in stream_list.streams:
text = sp.decode(s.hyp.ys[params.context_size :])
results.append(text)
return results
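A hypothetical call site for ``decode_batch`` (``model``, ``params``, and ``sp`` are assumed to be constructed as in ``main()`` below; the waveforms are random placeholders):

# Illustration only.
batched_samples = [torch.rand(16000), torch.rand(24000), torch.rand(8000)]
hyps = decode_batch(
    batched_samples=batched_samples,
    model=model,
    params=params,
    sp=sp,
)
for i, text in enumerate(hyps):
    logging.info(f"utterance {i}: {text}")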
@torch.no_grad()
@@ -403,31 +570,41 @@ def main():
test_clean_cuts = librispeech.test_clean_cuts()
for num, cut in enumerate(test_clean_cuts):
logging.info(f"Processing {num}")
stream = Stream(
context_size=model.decoder.context_size,
blank_id=model.decoder.blank_id,
)
batch_size = 3
ground_truth = []
batched_samples = []
for num, cut in enumerate(test_clean_cuts):
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
decode_one_utterance(
audio_samples=torch.from_numpy(audio).squeeze(0).to(device),
model=model,
stream=stream,
params=params,
sp=sp,
)
logging.info(f"The ground truth is:\n{cut.supervisions[0].text}")
if num >= 2:
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
batched_samples.append(samples)
ground_truth.append(cut.supervisions[0].text)
if len(batched_samples) >= batch_size:
decoded_results = decode_batch(
batched_samples=batched_samples,
model=model,
params=params,
sp=sp,
)
s = "\n"
for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
s += f"hyp {i}:\n{hyp}\n"
s += f"ref {i}:\n{ref}\n\n"
logging.info(s)
batched_samples = []
ground_truth = []
# break after processing the first batch for test purposes
break
time.sleep(2) # So that you can see the decoded results
if __name__ == "__main__":

View File

@@ -40,7 +40,7 @@ def _create_streaming_feature_extractr() -> OnlineFeature:
return OnlineFbank(opts)
class Stream(object):
class FeatureExtractionStream(object):
def __init__(self, context_size: int, blank_id: int = 0) -> None:
"""Context size of the RNN-T decoder model."""
self.feature_extractor = _create_streaming_feature_extractr()
@@ -62,6 +62,9 @@ class Stream(object):
# corresponding to the decoder input self.hyp.ys[-context_size:]
self.decoder_out: Optional[torch.Tensor] = None
# After calling `self.input_finished()`, we set this flag to True
self._done = False
def accept_waveform(
self,
sampling_rate: float,
@@ -97,6 +100,12 @@ class Stream(object):
"""
self.feature_extractor.input_finished()
self._fetch_frames()
self._done = True
@property
def done(self) -> bool:
"""Return True if `self.input_finished()` has been invoked"""
return self._done
def _fetch_frames(self) -> None:
"""Fetch frames from the feature extractor"""

View File

@@ -25,7 +25,7 @@ To run this file, do:
import warnings
import torch
from emformer import Emformer
from emformer import Emformer, stack_states, unstack_states
def test_emformer():
@@ -65,8 +65,41 @@ def test_emformer():
print(f"Number of encoder parameters: {num_param}")
def test_emformer_streaming_forward():
N = 3
C = 80
output_dim = 500
encoder = Emformer(
num_features=C,
output_dim=output_dim,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=20,
segment_length=16,
left_context_length=120,
right_context_length=4,
vgg_frontend=False,
)
x = torch.rand(N, 23, C)
x_lens = torch.full((N,), 23)
y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens)
state_list = unstack_states(states)
states2 = stack_states(state_list)
for ss, ss2 in zip(states, states2):
for s, s2 in zip(ss, ss2):
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
@torch.no_grad()
def main():
test_emformer()
# test_emformer()
test_emformer_streaming_forward()
if __name__ == "__main__":

View File

@@ -24,11 +24,11 @@ To run this file, do:
"""
import torch
from streaming_feature_extractor import Stream
from streaming_feature_extractor import FeatureExtractionStream
def test_streaming_feature_extractor():
stream = Stream(context_size=2, blank_id=0)
stream = FeatureExtractionStream(context_size=2, blank_id=0)
samples = torch.rand(16000)
start = 0
while True: