mirror of
https://github.com/k2-fsa/icefall.git
modify emformer states stack and unstack, streaming decoding, to be continued
parent
5df1406684
commit
f8071e9373
@@ -37,6 +37,112 @@ from scaling import (
 from icefall.utils import make_pad_mask
 
+LOG_EPSILON = math.log(1e-10)
+
+
+def unstack_states(
+    states,
+) -> List[List[List[torch.Tensor]]]:
+    # TODO: modify doc
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A list-of-list of tensors. ``len(states)`` equals the number of
+        layers in the emformer. ``states[i]`` contains the states for
+        the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
+        ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
+    """
+    past_lens, attn_caches, conv_caches = states
+    batch_size = past_lens.size(0)
+    num_layers = len(attn_caches)
+
+    list_past_len = past_lens.tolist()
+
+    list_attn_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_attn_caches[i] = [[] for _ in range(num_layers)]
+    for li, layer in enumerate(attn_caches):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            for bi, b in enumerate(list_attn_caches):
+                b[li].append(s_list[bi])
+
+    list_conv_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_conv_caches[i] = [None] * num_layers
+    for li, layer in enumerate(conv_caches):
+        c_list = layer.unbind(dim=0)
+        for bi, b in enumerate(list_conv_caches):
+            b[li] = c_list[bi]
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [list_past_len[i], list_attn_caches[i], list_conv_caches[i]]
+
+    return ans
+
+
+def stack_states(
+    state_list,
+) -> List[List[torch.Tensor]]:
+    # TODO: modify doc
+    """Stack a list of emformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    the emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+    Returns:
+      Return a new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensor.
+    """
+    batch_size = len(state_list)
+
+    past_lens = [states[0] for states in state_list]
+    past_lens = torch.tensor(past_lens)
+
+    attn_caches = []
+    for layer in state_list[0][1]:
+        if batch_size > 1:
+            # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s]  # noqa
+            attn_caches.append([[s] for s in layer])
+        else:
+            attn_caches.append([s.unsqueeze(1) for s in layer])
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[1]):
+            for si, s in enumerate(layer):
+                attn_caches[li][si].append(s)
+                if b == batch_size - 1:
+                    attn_caches[li][si] = torch.stack(
+                        attn_caches[li][si], dim=1
+                    )
+
+    conv_caches = []
+    for layer in state_list[0][2]:
+        if batch_size > 1:
+            # Note: We will stack conv_caches[layer][] later to get conv_caches[layer]  # noqa
+            conv_caches.append([layer])
+        else:
+            conv_caches.append(layer.unsqueeze(0))
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[2]):
+            conv_caches[li].append(layer)
+            if b == batch_size - 1:
+                conv_caches[li] = torch.stack(conv_caches[li], dim=0)
+
+    return [past_lens, attn_caches, conv_caches]
+
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule.
 
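The two helpers above are intended to be inverses of each other: stack_states batches per-utterance states, unstack_states splits a batched state back apart. Below is a minimal round-trip sketch (not part of the commit). It assumes the batched layout described in the docstrings — past_lens of shape (N,), per-layer attention caches of shape (T, N, C), per-layer convolution caches with the batch dimension first — and that the helpers are imported as in the decoding script (from emformer import stack_states, unstack_states); the layer count and tensor sizes are made up for illustration.

import torch

from emformer import stack_states, unstack_states

N, T, C, K, num_layers = 3, 4, 8, 5, 2  # hypothetical sizes

past_lens = torch.zeros(N, dtype=torch.int64)
attn_caches = [
    [torch.randn(T, N, C) for _ in range(4)] for _ in range(num_layers)
]
conv_caches = [torch.randn(N, C, K) for _ in range(num_layers)]
states = [past_lens, attn_caches, conv_caches]

# Split the batched state into per-utterance states, then merge them back.
state_list = unstack_states(states)
assert len(state_list) == N

restacked = stack_states(state_list)
assert torch.equal(restacked[0], past_lens)
assert all(
    torch.equal(a, b)
    for la, lb in zip(restacked[1], attn_caches)
    for a, b in zip(la, lb)
)
assert all(torch.equal(a, b) for a, b in zip(restacked[2], conv_caches))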
@@ -23,6 +23,7 @@ from pathlib import Path
 from typing import List, Optional, Tuple
 
 import k2
+from lhotse import CutSet
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -30,7 +31,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
-from streaming_feature_extractor import FeatureExtractionStream
+from streaming_feature_extractor import Stream
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -157,174 +158,21 @@ def get_parser():
         help="Sample rate of the audio",
     )
 
+    parser.add_argument(
+        "--num-decode-streams",
+        type=int,
+        default=2000,
+        help="The number of streams that can be decoded in parallel",
+    )
+
     add_model_arguments(parser)
 
     return parser
 
 
-class StreamingAudioSamples(object):
-    """This class takes as input a list of audio samples and returns
-    them in a streaming fashion.
-    """
-
-    def __init__(self, samples: List[torch.Tensor]) -> None:
-        """
-        Args:
-          samples:
-            A list of audio samples. Each entry is a 1-D tensor of dtype
-            torch.float32, containing the audio samples of an utterance.
-        """
-        self.samples = samples
-        self.cur_indexes = [0] * len(self.samples)
-
-    @property
-    def done(self) -> bool:
-        """Return True if all samples have been processed.
-        Return False otherwise.
-        """
-        for i, samples in zip(self.cur_indexes, self.samples):
-            if i < samples.numel():
-                return False
-        return True
-
-    def get_next(self) -> List[torch.Tensor]:
-        """Return a list of audio samples. Each entry may have different
-        lengths. It is OK if an entry contains no samples at all, which
-        means it reaches the end of the utterance.
-        """
-        ans = []
-
-        num = [1024] * len(self.samples)
-
-        for i in range(len(self.samples)):
-            start = self.cur_indexes[i]
-            end = start + num[i]
-            self.cur_indexes[i] = end
-
-            s = self.samples[i][start:end]
-            ans.append(s)
-
-        return ans
-
-
-class StreamList(object):
-    def __init__(
-        self,
-        batch_size: int,
-        context_size: int,
-        decoding_method: str,
-    ):
-        """
-        Args:
-          batch_size:
-            Size of this batch.
-          context_size:
-            Context size of the RNN-T decoder model.
-          decoding_method:
-            Decoding method. The possible values are:
-              - greedy_search
-              - modified_beam_search
-        """
-        self.streams = [
-            FeatureExtractionStream(
-                context_size=context_size, decoding_method=decoding_method
-            )
-            for _ in range(batch_size)
-        ]
-
-    def __getitem__(self, i) -> FeatureExtractionStream:
-        return self.streams[i]
-
-    @property
-    def done(self) -> bool:
-        """Return True if all streams have reached end of utterance.
-        That is, no more audio samples are available for all utterances.
-        """
-        return all(stream.done for stream in self.streams)
-
-    def accept_waveform(
-        self,
-        audio_samples: List[torch.Tensor],
-        sampling_rate: float,
-    ):
-        """Feed audio samples to each stream.
-        Args:
-          audio_samples:
-            A list of 1-D tensors containing the audio samples for each
-            utterance in the batch. If an entry is empty, it means
-            end-of-utterance has been reached.
-          sampling_rate:
-            Sampling rate of the given audio samples.
-        """
-        assert len(audio_samples) == len(self.streams)
-        for stream, samples in zip(self.streams, audio_samples):
-
-            if stream.done:
-                assert samples.numel() == 0
-                continue
-
-            stream.accept_waveform(
-                sampling_rate=sampling_rate,
-                waveform=samples,
-            )
-
-            if samples.numel() == 0:
-                stream.input_finished()
-
-    def build_batch(
-        self,
-        chunk_length: int,
-        segment_length: int,
-    ) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
-        """
-        Args:
-          chunk_length:
-            Number of frames for each chunk. It equals to
-            ``segment_length + right_context_length``.
-          segment_length
-            Number of frames for each segment.
-        Returns:
-          Return a tuple containing:
-            - features, a 3-D tensor of shape ``(num_active_streams, T, C)``
-            - active_streams, a list of active streams. We say a stream is
-              active when it has enough feature frames to be fed into the
-              encoder model.
-        """
-        feature_list = []
-        stream_list = []
-        for stream in self.streams:
-            if len(stream.feature_frames) >= chunk_length:
-                # this_chunk is a list of tensors, each of which
-                # has a shape (1, feature_dim)
-                chunk = stream.feature_frames[:chunk_length]
-                stream.feature_frames = stream.feature_frames[segment_length:]
-                features = torch.cat(chunk, dim=0)
-                feature_list.append(features)
-                stream_list.append(stream)
-            elif stream.done and len(stream.feature_frames) > 0:
-                chunk = stream.feature_frames[:chunk_length]
-                stream.feature_frames = []
-                features = torch.cat(chunk, dim=0)
-                features = torch.nn.functional.pad(
-                    features,
-                    (0, 0, 0, chunk_length - features.size(0)),
-                    mode="constant",
-                    value=LOG_EPSILON,
-                )
-                feature_list.append(features)
-                stream_list.append(stream)
-
-        if len(feature_list) == 0:
-            return None, None
-
-        features = torch.stack(feature_list, dim=0)
-        return features, stream_list
-
-
 def greedy_search(
     model: nn.Module,
-    streams: List[FeatureExtractionStream],
+    streams: List[Stream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -401,7 +249,7 @@ def greedy_search(
 
 def modified_beam_search(
     model: nn.Module,
-    streams: List[FeatureExtractionStream],
+    streams: List[Stream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
     beam: int = 4,
@@ -513,10 +361,69 @@ def modified_beam_search(
         logging.info(f"Partial result {i}:\n{result}")
 
 
+def build_batch(
+    decode_streams: List[Stream],
+    chunk_length: int,
+    segment_length: int,
+) -> Tuple[
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[List[Stream]],
+]:
+    """
+    Args:
+      decode_streams:
+        The streams to draw feature frames from.
+      chunk_length:
+        Number of frames for each chunk. It equals
+        ``segment_length + right_context_length``.
+      segment_length:
+        Number of frames for each segment.
+    Returns:
+      Return a tuple containing:
+        - features, a 3-D tensor of shape ``(num_active_streams, T, C)``
+        - feature_lens, a 1-D tensor with the number of valid frames of
+          each active stream
+        - active_streams, a list of active streams. We say a stream is
+          active when it has enough feature frames to be fed into the
+          encoder model.
+    """
+    feature_list = []
+    length_list = []
+    stream_list = []
+    for stream in decode_streams:
+        if len(stream.feature_frames) >= chunk_length:
+            # this_chunk is a list of tensors, each of which
+            # has a shape (1, feature_dim)
+            chunk = stream.feature_frames[:chunk_length]
+            stream.feature_frames = stream.feature_frames[segment_length:]
+            features = torch.cat(chunk, dim=0)
+            feature_list.append(features)
+            length_list.append(chunk_length)
+            stream_list.append(stream)
+        elif stream.done and len(stream.feature_frames) > 0:
+            chunk = stream.feature_frames[:chunk_length]
+            stream.feature_frames = []
+            features = torch.cat(chunk, dim=0)
+            length_list.append(features.size(0))
+            features = torch.nn.functional.pad(
+                features,
+                (0, 0, 0, chunk_length - features.size(0)),
+                mode="constant",
+                value=LOG_EPSILON,
+            )
+            feature_list.append(features)
+            stream_list.append(stream)
+
+    if len(feature_list) == 0:
+        return None, None, None
+
+    features = torch.stack(feature_list, dim=0)
+    lengths = torch.tensor(length_list)
+    return features, lengths, stream_list
+
+
 def process_features(
     model: nn.Module,
     features: torch.Tensor,
-    streams: List[FeatureExtractionStream],
+    feature_lens: torch.Tensor,
+    streams: List[Stream],
     params: AttributeDict,
     sp: spm.SentencePieceProcessor,
 ) -> None:
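As a side note on the padding branch above: when a finished stream has fewer than chunk_length frames left, the missing rows are filled with LOG_EPSILON, i.e. the log of a tiny value standing in for silence in log-mel features. A small standalone sketch (not from the commit; chunk_length and feature_dim are made up):

import math

import torch

LOG_EPSILON = math.log(1e-10)
chunk_length, feature_dim = 32, 80  # hypothetical values

# Pretend a finished stream has only 20 frames left.
features = torch.randn(20, feature_dim)
padded = torch.nn.functional.pad(
    features,
    (0, 0, 0, chunk_length - features.size(0)),  # pad rows at the bottom
    mode="constant",
    value=LOG_EPSILON,
)
assert padded.shape == (chunk_length, feature_dim)
assert torch.all(padded[20:] == LOG_EPSILON)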
@@ -536,30 +443,20 @@ def process_features(
     """
     assert features.ndim == 3
     assert features.size(0) == len(streams)
-    batch_size = features.size(0)
+    assert feature_lens.size(0) == len(streams)
 
     device = model.device
     features = features.to(device)
-    feature_lens = torch.full(
-        (batch_size,),
-        fill_value=features.size(1),
-        device=device,
-    )
-
-    # Caution: It has a limitation as it assumes that
-    # if one of the stream has an empty state, then all other
-    # streams also have empty states.
-    if streams[0].states is None:
-        states = None
-    else:
-        state_list = [stream.states for stream in streams]
-        states = stack_states(state_list)
+    state_list = [stream.states for stream in streams]
+    states = stack_states(state_list)
 
-    (encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
+    encoder_out, encoder_out_lens, states = model.encoder.infer(
         features,
         feature_lens,
         states,
     )
 
     state_list = unstack_states(states)
     for i, s in enumerate(state_list):
         streams[i].states = s
@@ -585,48 +482,55 @@ def process_features(
     )
 
 
-def decode_batch(
-    batched_samples: List[torch.Tensor],
-    model: nn.Module,
+def decode_dataset(
     params: AttributeDict,
+    cuts: CutSet,
+    model: nn.Module,
     sp: spm.SentencePieceProcessor,
-) -> List[str]:
-    """
+):
+    """Decode dataset.
     Args:
-      batched_samples:
-        A list of 1-D tensors containing the audio samples of each utterance.
-      model:
-        The RNN-T model.
-      params:
-        It is the return value of :func:`get_params`.
-      sp:
-        The BPE model.
     """
+    device = next(model.parameters()).device
+
     # number of frames before subsampling
     segment_length = model.encoder.segment_length
 
     right_context_length = model.encoder.right_context_length
-    # We add 3 here since the subsampling method is using
+    # 5 = 3 + 2
+    # 1) add 3 here since the subsampling method is using
     # ((len - 1) // 2 - 1) // 2)
-    chunk_length = (segment_length + 3) + right_context_length
+    # 2) add 2 here because we will drop the first and last frames after subsampling
+    chunk_length = (segment_length + 5) + right_context_length
 
-    batch_size = len(batched_samples)
-    streaming_audio_samples = StreamingAudioSamples(batched_samples)
-
-    stream_list = StreamList(
-        batch_size=batch_size,
-        context_size=params.context_size,
-        decoding_method=params.decoding_method,
-    )
+    decode_results = []
+    streams = []
+    for num, cut in enumerate(cuts):
+        audio: np.ndarray = cut.load_audio()
+        # audio.shape: (1, num_samples)
+        assert len(audio.shape) == 2
+        assert audio.shape[0] == 1, "Should be single channel"
+        assert audio.dtype == np.float32, audio.dtype
+
+        # The trained model is using normalized samples
+        assert audio.max() <= 1, "Should be normalized to [-1, 1]"
+
+        samples = torch.from_numpy(audio).squeeze(0)
+
+        # Each utterance has a Stream
+        stream = Stream(
+            params=params,
+            audio_sample=samples,
+            ground_truth=cut.supervisions[0].text,
+            device=device,
+        )
+        streams.append(stream)
 
-    while not streaming_audio_samples.done:
-        samples = streaming_audio_samples.get_next()
-        stream_list.accept_waveform(
-            audio_samples=samples,
-            sampling_rate=params.sampling_rate,
-        )
-        features, active_streams = stream_list.build_batch(
-            chunk_length=chunk_length,
-            segment_length=segment_length,
-        )
+        while len(streams) >= params.num_decode_streams:
+            for stream in streams:
+                stream.accept_waveform()
+
+            # try to build batch
+            features, active_streams = build_batch(
+                chunk_length=chunk_length,
+                segment_length=segment_length,
+            )
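The frame arithmetic in the comments above can be checked in a few lines (not part of the commit). Only the subsampling formula ((len - 1) // 2 - 1) // 2 and the "+3, +2" reasoning come from the code; the segment_length and right_context_length values below are made up.

segment_length = 16        # hypothetical number of frames before subsampling
right_context_length = 4   # hypothetical number of right-context frames

chunk_length = (segment_length + 5) + right_context_length  # 25 frames


def subsampled_len(num_frames: int) -> int:
    # The convolutional subsampling quoted in the comment above.
    return ((num_frames - 1) // 2 - 1) // 2


out_len = subsampled_len(chunk_length)  # ((25 - 1) // 2 - 1) // 2 == 5
out_len_after_drop = out_len - 2        # the first and last frames are dropped
print(chunk_length, out_len, out_len_after_drop)  # 25 5 3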
@@ -638,11 +542,20 @@ def decode_batch(
                 params=params,
                 sp=sp,
             )
-    results = []
-    for stream in stream_list.streams:
-        text = sp.decode(stream.decoding_result())
-        results.append(text)
-    return results
+            new_streams = []
+            for stream in streams:
+                if stream.done:
+                    decode_results.append(
+                        (
+                            stream.ground_truth.split(),
+                            sp.decode(stream.decoding_result()).split(),
+                        )
+                    )
+                else:
+                    new_streams.append(stream)
+            del streams
+            streams = new_streams
 
 
 @torch.no_grad()
@@ -726,8 +639,8 @@ def main():
 
         samples = torch.from_numpy(audio).squeeze(0)
 
-        batched_samples.append(samples)
-        ground_truth.append(cut.supervisions[0].text)
+        # batched_samples.append(samples)
+        # ground_truth.append(cut.supervisions[0].text)
 
         if len(batched_samples) >= batch_size:
             decoded_results = decode_batch(
@@ -42,12 +42,12 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
     return OnlineFbank(opts)
 
 
-class FeatureExtractionStream(object):
+class Stream(object):
     def __init__(
         self,
         params: AttributeDict,
-        context_size: int,
-        decoding_method: str,
+        audio_sample: torch.Tensor,
+        ground_truth: str,
         device: torch.device = torch.device("cpu"),
     ) -> None:
         """
@@ -63,6 +63,7 @@ class FeatureExtractionStream(object):
         # It contains a list of 1-D tensors representing the feature frames.
         self.feature_frames: List[torch.Tensor] = []
         self.num_fetched_frames = 0
+
         # After calling `self.input_finished()`, we set this flag to True
         self._done = False
 
@@ -87,20 +88,29 @@ class FeatureExtractionStream(object):
         self.states = [past_len, attn_caches, conv_caches]
 
         # It uses different attributes for different decoding methods.
-        self.context_size = context_size
-        self.decoding_method = decoding_method
-        if decoding_method == "greedy_search":
+        self.context_size = params.context_size
+        self.decoding_method = params.decoding_method
+        if params.decoding_method == "greedy_search":
             self.hyp: Optional[List[int]] = None
             self.decoder_out: Optional[torch.Tensor] = None
-        elif decoding_method == "modified_beam_search":
+        elif params.decoding_method == "modified_beam_search":
             self.hyps = HypothesisList()
         else:
-            raise ValueError(f"Unsupported decoding method: {decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
+
+        self.sampling_rate = params.sampling_rate
+        self.audio_sample = audio_sample
+        # Current index of sample
+        self.cur_index = 0
+
+        self.ground_truth = ground_truth
 
     def accept_waveform(
         self,
-        sampling_rate: float,
-        waveform: torch.Tensor,
+        # sampling_rate: float,
+        # waveform: torch.Tensor,
     ) -> None:
         """Feed audio samples to the feature extractor and compute features
         if there are enough samples available.
@@ -120,12 +130,20 @@ class FeatureExtractionStream(object):
             A 1-D torch tensor of dtype torch.float32 containing audio samples.
             It should be on CPU.
         """
+        start = self.cur_index
+        end = self.cur_index + 1024
+        waveform = self.audio_sample[start:end]
+        self.cur_index = end
+
         self.feature_extractor.accept_waveform(
-            sampling_rate=sampling_rate,
+            sampling_rate=self.sampling_rate,
             waveform=waveform,
         )
         self._fetch_frames()
 
+        if waveform.numel() == 0:
+            self.input_finished()
+
     def input_finished(self) -> None:
         """Signal that no more audio samples available and the feature
         extractor should flush the buffered samples to compute frames.
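A tiny standalone illustration (not from the commit) of the slicing pattern used by accept_waveform above: each call consumes the next block of 1024 samples from the pre-loaded utterance, and a call past the end yields an empty slice, which is the cue to call input_finished() and flush the feature extractor.

import torch

audio_sample = torch.randn(2500)  # made-up utterance length
cur_index = 0
blocks = []
while True:
    chunk = audio_sample[cur_index : cur_index + 1024]
    cur_index += 1024
    if chunk.numel() == 0:
        break  # this is where the real code would call input_finished()
    blocks.append(chunk)

assert [b.numel() for b in blocks] == [1024, 1024, 452]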