modify emformer states stack and unstack, streaming decoding, to be continued

yaozengwei 2022-06-07 23:57:20 +08:00
parent 5df1406684
commit f8071e9373
3 changed files with 274 additions and 237 deletions


@@ -37,6 +37,112 @@ from scaling import (
from icefall.utils import make_pad_mask
LOG_EPSILON = math.log(1e-10)
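# LOG_EPSILON is used as the filler value when padding feature chunks
# in the streaming decoding script (see build_batch there).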
def unstack_states(
states,
) -> List[List[List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state of the i-th
utterance in the batch.
Args:
states:
A tuple of ``(past_lens, attn_caches, conv_caches)``, where
``past_lens`` is a 1-D tensor of shape ``(N,)``, ``attn_caches``
is a list with one entry per emformer layer, each entry being a
list of tensors whose batch dimension is dim 1, and
``conv_caches`` is a list with one entry per layer, each entry
being a tensor whose batch dimension is dim 0.
Returns:
A list of length ``N``. The i-th entry is
``[past_len, attn_caches, conv_caches]`` for the i-th utterance,
with the batch dimension removed.
"""
past_lens, attn_caches, conv_caches = states
batch_size = past_lens.size(0)
num_layers = len(attn_caches)
list_past_len = past_lens.tolist()
list_attn_caches = [None] * batch_size
for i in range(batch_size):
list_attn_caches[i] = [[] for _ in range(num_layers)]
for li, layer in enumerate(attn_caches):
for s in layer:
s_list = s.unbind(dim=1)
for bi, b in enumerate(list_attn_caches):
b[li].append(s_list[bi])
list_conv_caches = [None] * batch_size
for i in range(batch_size):
list_conv_caches[i] = [None] * num_layers
for li, layer in enumerate(conv_caches):
c_list = layer.unbind(dim=0)
for bi, b in enumerate(list_conv_caches):
b[li] = c_list[bi]
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [list_past_len[i], list_attn_caches[i], list_conv_caches[i]]
return ans
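# A minimal usage sketch (illustrative only, with made-up sizes: 2 layers,
# a batch of 3 utterances, 3 cached attention tensors per layer) of the
# batched state layout that unstack_states() expects and what it returns:
#
#   num_layers, batch, cache_len, channels, kernel = 2, 3, 4, 8, 5
#   past_lens = torch.zeros(batch, dtype=torch.int64)
#   attn_caches = [
#       [torch.zeros(cache_len, batch, channels) for _ in range(3)]
#       for _ in range(num_layers)
#   ]
#   conv_caches = [torch.zeros(batch, channels, kernel) for _ in range(num_layers)]
#   states = [past_lens, attn_caches, conv_caches]
#
#   per_utt = unstack_states(states)
#   assert len(per_utt) == batch
#   # per_utt[i] is [past_len_i, attn_caches_i, conv_caches_i] with the batch
#   # dimension removed, e.g. per_utt[0][1][0][0].shape == (cache_len, channels)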
def stack_states(
state_list,
) -> List[List[torch.Tensor]]:
"""Stack a list of emformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for the
emformer when those utterances are formed into a batch.
Note:
It is the inverse of :func:`unstack_states`.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the emformer model for a single utterance.
Returns:
Return a new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned state.
"""
batch_size = len(state_list)
past_lens = [states[0] for states in state_list]
# past_lens: (N,), one entry per utterance
past_lens = torch.tensor(past_lens)
attn_caches = []
for layer in state_list[0][1]:
if batch_size > 1:
# Note: attn_caches[li][si] is a list of per-utterance tensors here; it will be stacked along dim 1 below. # noqa
attn_caches.append([[s] for s in layer])
else:
attn_caches.append([s.unsqueeze(1) for s in layer])
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states[1]):
for si, s in enumerate(layer):
attn_caches[li][si].append(s)
if b == batch_size - 1:
attn_caches[li][si] = torch.stack(
attn_caches[li][si], dim=1
)
conv_caches = []
for layer in state_list[0][2]:
if batch_size > 1:
# Note: conv_caches[li] is a list of per-utterance tensors here; it will be stacked along dim 0 below. # noqa
conv_caches.append([layer])
else:
conv_caches.append(layer.unsqueeze(0))
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states[2]):
conv_caches[li].append(layer)
if b == batch_size - 1:
conv_caches[li] = torch.stack(conv_caches[li], dim=0)
return [past_lens, attn_caches, conv_caches]
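# Continuing the sketch above, a quick sanity check that stacking the
# unstacked states reproduces the batched layout:
#
#   restacked = stack_states(unstack_states(states))
#   assert restacked[0].tolist() == past_lens.tolist()
#   assert restacked[1][0][0].shape == attn_caches[0][0].shape  # (cache_len, batch, channels)
#   assert restacked[2][0].shape == conv_caches[0].shape  # (batch, channels, kernel)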
class ConvolutionModule(nn.Module):
"""ConvolutionModule.


@@ -23,6 +23,7 @@ from pathlib import Path
from typing import List, Optional, Tuple
import k2
from lhotse import CutSet
import numpy as np
import sentencepiece as spm
import torch
@@ -30,7 +31,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from emformer import LOG_EPSILON, stack_states, unstack_states
from streaming_feature_extractor import FeatureExtractionStream
from streaming_feature_extractor import Stream
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@@ -157,174 +158,21 @@ def get_parser():
help="Sample rate of the audio",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel",
)
add_model_arguments(parser)
return parser
class StreamingAudioSamples(object):
"""This class takes as input a list of audio samples and returns
them in a streaming fashion.
"""
def __init__(self, samples: List[torch.Tensor]) -> None:
"""
Args:
samples:
A list of audio samples. Each entry is a 1-D tensor of dtype
torch.float32, containing the audio samples of an utterance.
"""
self.samples = samples
self.cur_indexes = [0] * len(self.samples)
@property
def done(self) -> bool:
"""Return True if all samples have been processed.
Return False otherwise.
"""
for i, samples in zip(self.cur_indexes, self.samples):
if i < samples.numel():
return False
return True
def get_next(self) -> List[torch.Tensor]:
"""Return a list of audio samples. Each entry may have different
lengths. It is OK if an entry contains no samples at all, which
means it reaches the end of the utterance.
"""
ans = []
num = [1024] * len(self.samples)
for i in range(len(self.samples)):
start = self.cur_indexes[i]
end = start + num[i]
self.cur_indexes[i] = end
s = self.samples[i][start:end]
ans.append(s)
return ans
class StreamList(object):
def __init__(
self,
batch_size: int,
context_size: int,
decoding_method: str,
):
"""
Args:
batch_size:
Size of this batch.
context_size:
Context size of the RNN-T decoder model.
decoding_method:
Decoding method. The possible values are:
- greedy_search
- modified_beam_search
"""
self.streams = [
FeatureExtractionStream(
context_size=context_size, decoding_method=decoding_method
)
for _ in range(batch_size)
]
def __getitem__(self, i) -> FeatureExtractionStream:
return self.streams[i]
@property
def done(self) -> bool:
"""Return True if all streams have reached end of utterance.
That is, no more audio samples are available for any of the utterances.
"""
return all(stream.done for stream in self.streams)
def accept_waveform(
self,
audio_samples: List[torch.Tensor],
sampling_rate: float,
):
"""Feed audio samples to each stream.
Args:
audio_samples:
A list of 1-D tensors containing the audio samples for each
utterance in the batch. If an entry is empty, it means
end-of-utterance has been reached.
sampling_rate:
Sampling rate of the given audio samples.
"""
assert len(audio_samples) == len(self.streams)
for stream, samples in zip(self.streams, audio_samples):
if stream.done:
assert samples.numel() == 0
continue
stream.accept_waveform(
sampling_rate=sampling_rate,
waveform=samples,
)
if samples.numel() == 0:
stream.input_finished()
def build_batch(
self,
chunk_length: int,
segment_length: int,
) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
"""
Args:
chunk_length:
Number of frames for each chunk. It equals to
``segment_length + right_context_length``.
segment_length:
Number of frames for each segment.
Returns:
Return a tuple containing:
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
- active_streams, a list of active streams. We say a stream is
active when it has enough feature frames to be fed into the
encoder model.
"""
feature_list = []
stream_list = []
for stream in self.streams:
if len(stream.feature_frames) >= chunk_length:
# this_chunk is a list of tensors, each of which
# has a shape (1, feature_dim)
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(chunk, dim=0)
feature_list.append(features)
stream_list.append(stream)
elif stream.done and len(stream.feature_frames) > 0:
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(chunk, dim=0)
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(0)),
mode="constant",
value=LOG_EPSILON,
)
feature_list.append(features)
stream_list.append(stream)
if len(feature_list) == 0:
return None, None
features = torch.stack(feature_list, dim=0)
return features, stream_list
def greedy_search(
model: nn.Module,
streams: List[FeatureExtractionStream],
streams: List[Stream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
@@ -401,7 +249,7 @@ def greedy_search(
def modified_beam_search(
model: nn.Module,
streams: List[FeatureExtractionStream],
streams: List[Stream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
beam: int = 4,
@@ -513,10 +361,69 @@ def modified_beam_search(
logging.info(f"Partial result {i}:\n{result}")
def build_batch(
decode_streams: List[Stream],
chunk_length: int,
segment_length: int,
) -> Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[List[Stream]],
]:
"""
Args:
decode_streams:
A list of streams to be decoded.
chunk_length:
Number of frames for each chunk. In :func:`decode_dataset` it is
set to ``segment_length + 3 + right_context_length``.
segment_length:
Number of frames for each segment.
Returns:
Return a tuple containing:
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
- lengths, a 1-D tensor containing the number of valid feature
frames for each active stream
- active_streams, a list of active streams. We say a stream is
active when it has enough feature frames to be fed into the
encoder model.
"""
feature_list = []
length_list = []
stream_list = []
for stream in decode_streams:
if len(stream.feature_frames) >= chunk_length:
# chunk is a list of tensors, each of which
# has a shape (1, feature_dim)
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(chunk, dim=0)
feature_list.append(features)
length_list.append(chunk_length)
stream_list.append(stream)
elif stream.done and len(stream.feature_frames) > 0:
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(chunk, dim=0)
length_list.append(features.size(0))
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(0)),
mode="constant",
value=LOG_EPSILON,
)
feature_list.append(features)
stream_list.append(stream)
if len(feature_list) == 0:
return None, None, None
features = torch.stack(feature_list, dim=0)
lengths = torch.tensor(length_list)
return features, lengths, stream_list
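# A small illustration (assumption: SimpleNamespace stands in for a Stream,
# since build_batch only touches the feature_frames and done attributes) of
# how a finished stream with fewer than chunk_length frames is padded with
# LOG_EPSILON:
#
#   from types import SimpleNamespace
#   feature_dim = 80
#   fake = SimpleNamespace(
#       feature_frames=[torch.zeros(1, feature_dim) for _ in range(10)],
#       done=True,
#   )
#   features, lengths, active = build_batch([fake], chunk_length=32, segment_length=16)
#   assert features.shape == (1, 32, feature_dim)  # padded up to chunk_length
#   assert lengths.tolist() == [10]  # only 10 frames are real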
def process_features(
model: nn.Module,
features: torch.Tensor,
streams: List[FeatureExtractionStream],
feature_lens: torch.Tensor,
streams: List[Stream],
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
@@ -536,30 +443,20 @@ def process_features(
"""
assert features.ndim == 3
assert features.size(0) == len(streams)
batch_size = features.size(0)
assert feature_lens.size(0) == len(streams)
device = model.device
features = features.to(device)
feature_lens = feature_lens.to(device)
feature_lens = torch.full(
(batch_size,),
fill_value=features.size(1),
device=device,
)
# Caution: This has a limitation: it assumes that if one of the
# streams has an empty state, then all other streams also have
# empty states.
if streams[0].states is None:
states = None
else:
state_list = [stream.states for stream in streams]
states = stack_states(state_list)
state_list = [stream.states for stream in streams]
states = stack_states(state_list)
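# Every Stream creates its initial states in __init__, so state_list always
# contains valid states and no special case for empty states is needed here.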
(encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
encoder_out, encoder_out_lens, states = model.encoder.infer(
features,
feature_lens,
states,
)
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
@@ -585,64 +482,80 @@ def process_features(
)
def decode_batch(
batched_samples: List[torch.Tensor],
model: nn.Module,
def decode_dataset(
params: AttributeDict,
cuts: CutSet,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> List[str]:
"""
):
"""Decode dataset.
Args:
batched_samples:
A list of 1-D tensors containing the audio samples of each utterance.
model:
The RNN-T model.
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
device = next(model.parameters()).device
# number of frames before subsampling
segment_length = model.encoder.segment_length
right_context_length = model.encoder.right_context_length
# 5 = 3 + 2
# 1) add 3 here since the subsampling method is using
# ((len - 1) // 2 - 1) // 2
# 2) add 2 here because we will drop the first and last frames after subsampling
chunk_length = (segment_length + 5) + right_context_length
# We add 3 here since the subsampling method is using
# ((len - 1) // 2 - 1) // 2
chunk_length = (segment_length + 3) + right_context_length
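# Illustrative arithmetic (made-up size): with segment_length = 16, an input
# of 16 + 3 = 19 frames gives ((19 - 1) // 2 - 1) // 2 = 4 = 16 // 4 subsampled
# frames, so the extra 3 frames keep the subsampled segment length exact.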
decode_results = []
streams = []
for num, cut in enumerate(cuts):
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
batch_size = len(batched_samples)
streaming_audio_samples = StreamingAudioSamples(batched_samples)
# The trained model uses normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
stream_list = StreamList(
batch_size=batch_size,
context_size=params.context_size,
decoding_method=params.decoding_method,
)
samples = torch.from_numpy(audio).squeeze(0)
while not streaming_audio_samples.done:
samples = streaming_audio_samples.get_next()
stream_list.accept_waveform(
audio_samples=samples,
sampling_rate=params.sampling_rate,
# Each utterance has its own Stream
stream = Stream(
params=params,
audio_sample=samples,
ground_truth=cut.supervisions[0].text,
device=device,
)
features, active_streams = stream_list.build_batch(
chunk_length=chunk_length,
segment_length=segment_length,
)
if features is not None:
process_features(
model=model,
features=features,
streams=active_streams,
params=params,
sp=sp,
streams.append(stream)
while len(streams) >= params.num_decode_streams:
for stream in streams:
stream.accept_waveform()
# try to build a batch
features, feature_lens, active_streams = build_batch(
decode_streams=streams,
chunk_length=chunk_length,
segment_length=segment_length,
)
results = []
for stream in stream_list.streams:
text = sp.decode(stream.decoding_result())
results.append(text)
return results
if features is not None:
process_features(
model=model,
features=features,
feature_lens=feature_lens,
streams=active_streams,
params=params,
sp=sp,
)
new_streams = []
for stream in streams:
if stream.done:
decode_results.append(
(
stream.ground_truth.split(),
sp.decode(stream.decoding_result()).split(),
)
)
else:
new_streams.append(stream)
del streams
streams = new_streams
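# Note: the while-loop above only runs once at least num_decode_streams
# streams are pending, so any streams remaining after the last cut still need
# a final drain pass. A possible sketch (an assumption, reusing the batching
# logic above):
#
#   while len(streams) > 0:
#       for stream in streams:
#           if not stream.done:
#               stream.accept_waveform()
#       features, feature_lens, active_streams = build_batch(
#           streams, chunk_length, segment_length
#       )
#       if features is not None:
#           process_features(
#               model=model, features=features, feature_lens=feature_lens,
#               streams=active_streams, params=params, sp=sp,
#           )
#       # collect finished streams into decode_results as above
#       streams = [s for s in streams if not s.done]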
@torch.no_grad()
@@ -726,8 +639,8 @@ def main():
samples = torch.from_numpy(audio).squeeze(0)
batched_samples.append(samples)
ground_truth.append(cut.supervisions[0].text)
# batched_samples.append(samples)
# ground_truth.append(cut.supervisions[0].text)
if len(batched_samples) >= batch_size:
decoded_results = decode_batch(


@@ -42,12 +42,12 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
return OnlineFbank(opts)
class FeatureExtractionStream(object):
class Stream(object):
def __init__(
self,
params: AttributeDict,
context_size: int,
decoding_method: str,
audio_sample: torch.Tensor,
ground_truth: str,
device: torch.device = torch.device("cpu"),
) -> None:
"""
@@ -63,6 +63,7 @@ class FeatureExtractionStream(object):
# It contains a list of 1-D tensors representing the feature frames.
self.feature_frames: List[torch.Tensor] = []
self.num_fetched_frames = 0
# After calling `self.input_finished()`, we set this flag to True
self._done = False
@@ -87,20 +88,29 @@ class FeatureExtractionStream(object):
self.states = [past_len, attn_caches, conv_caches]
# It uses different attributes for different decoding methods.
self.context_size = context_size
self.decoding_method = decoding_method
if decoding_method == "greedy_search":
self.context_size = params.context_size
self.decoding_method = params.decoding_method
if params.decoding_method == "greedy_search":
self.hyp: Optional[List[int]] = None
self.decoder_out: Optional[torch.Tensor] = None
elif decoding_method == "modified_beam_search":
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
else:
raise ValueError(f"Unsupported decoding method: {decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
self.sampling_rate = params.sampling_rate
self.audio_sample = audio_sample
# Current index into self.audio_sample
self.cur_index = 0
self.ground_truth = ground_truth
def accept_waveform(
self,
sampling_rate: float,
waveform: torch.Tensor,
# sampling_rate: float,
# waveform: torch.Tensor,
) -> None:
"""Feed audio samples to the feature extractor and compute features
if there are enough samples available.
@@ -120,12 +130,20 @@ class FeatureExtractionStream(object):
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
"""
start = self.cur_index
end = self.cur_index + 1024
waveform = self.audio_sample[start:end]
self.cur_index = end
self.feature_extractor.accept_waveform(
sampling_rate=sampling_rate,
sampling_rate=self.sampling_rate,
waveform=waveform,
)
self._fetch_frames()
if waveform.numel() == 0:
self.input_finished()
def input_finished(self) -> None:
"""Signal that no more audio samples available and the feature
extractor should flush the buffered samples to compute frames.