modify emformer states stack and unstack, streaming decoding, to be continued
This commit is contained in:
parent 5df1406684
commit f8071e9373
@@ -37,6 +37,112 @@ from scaling import (
 from icefall.utils import make_pad_mask

 LOG_EPSILON = math.log(1e-10)
+
+
+def unstack_states(
+    states,
+) -> List[List[List[torch.Tensor]]]:
+    # TODO: modify doc
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A list-of-list of tensors. ``len(states)`` equals the number of
+        layers in the emformer. ``states[i]`` contains the states for
+        the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
+        ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
+    """
+    past_lens, attn_caches, conv_caches = states
+    batch_size = past_lens.size(0)
+    num_layers = len(attn_caches)
+
+    list_past_len = past_lens.tolist()
+
+    list_attn_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_attn_caches[i] = [[] for _ in range(num_layers)]
+    for li, layer in enumerate(attn_caches):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            for bi, b in enumerate(list_attn_caches):
+                b[li].append(s_list[bi])
+
+    list_conv_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_conv_caches[i] = [None] * num_layers
+    for li, layer in enumerate(conv_caches):
+        c_list = layer.unbind(dim=0)
+        for bi, b in enumerate(list_conv_caches):
+            b[li] = c_list[bi]
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [list_past_len[i], list_attn_caches[i], list_conv_caches[i]]
+
+    return ans
+
+
+def stack_states(
+    state_list,
+) -> List[List[torch.Tensor]]:
+    # TODO: modify doc
+    """Stack a list of emformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    the emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+    Returns:
+      Return a new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensor.
+    """
+    batch_size = len(state_list)
+
+    past_lens = [states[0] for states in state_list]
+    # 1-D tensor of shape (batch_size,)
+    past_lens = torch.tensor(past_lens)
+
+    attn_caches = []
+    for layer in state_list[0][1]:
+        if batch_size > 1:
+            # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s]  # noqa
+            attn_caches.append([[s] for s in layer])
+        else:
+            attn_caches.append([s.unsqueeze(1) for s in layer])
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[1]):
+            for si, s in enumerate(layer):
+                attn_caches[li][si].append(s)
+                if b == batch_size - 1:
+                    attn_caches[li][si] = torch.stack(
+                        attn_caches[li][si], dim=1
+                    )
+
+    conv_caches = []
+    for layer in state_list[0][2]:
+        if batch_size > 1:
+            # Note: We will stack conv_caches[layer][] later to get conv_caches[layer]  # noqa
+            conv_caches.append([layer])
+        else:
+            conv_caches.append(layer.unsqueeze(0))
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[2]):
+            conv_caches[li].append(layer)
+            if b == batch_size - 1:
+                conv_caches[li] = torch.stack(conv_caches[li], dim=0)
+
+    return [past_lens, attn_caches, conv_caches]


 class ConvolutionModule(nn.Module):
     """ConvolutionModule.
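For orientation, here is a minimal round-trip sketch of the two helpers above. It is not part of the commit; the cache shapes are illustrative assumptions (the real shapes depend on the Emformer configuration), and the import mirrors the one used by streaming_decode.py below.

import torch

from emformer import stack_states, unstack_states

num_layers, batch_size = 2, 3
cache_len, feat_dim, conv_cache_len = 4, 8, 5

past_lens = torch.zeros(batch_size, dtype=torch.int64)
# Per layer: a list of 3-D attention caches of shape (T, N, C).
attn_caches = [
    [torch.randn(cache_len, batch_size, feat_dim) for _ in range(3)]
    for _ in range(num_layers)
]
# Per layer: one convolution cache of shape (N, C, T).
conv_caches = [
    torch.randn(batch_size, feat_dim, conv_cache_len)
    for _ in range(num_layers)
]
states = [past_lens, attn_caches, conv_caches]

state_list = unstack_states(states)   # one entry per utterance
assert len(state_list) == batch_size

restacked = stack_states(state_list)  # back to a single batched state
assert torch.equal(restacked[0], past_lens)
assert torch.equal(restacked[1][0][0], attn_caches[0][0])
assert torch.equal(restacked[2][0], conv_caches[0])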
@@ -23,6 +23,7 @@ from pathlib import Path
 from typing import List, Optional, Tuple

 import k2
+from lhotse import CutSet
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -30,7 +31,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
-from streaming_feature_extractor import FeatureExtractionStream
+from streaming_feature_extractor import Stream
 from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
@@ -157,174 +158,21 @@ def get_parser():
         help="Sample rate of the audio",
     )

+    parser.add_argument(
+        "--num-decode-streams",
+        type=int,
+        default=2000,
+        help="The number of streams that can be decoded in parallel",
+    )
+
     add_model_arguments(parser)

     return parser


-class StreamingAudioSamples(object):
-    """This class takes as input a list of audio samples and returns
-    them in a streaming fashion.
-    """
-
-    def __init__(self, samples: List[torch.Tensor]) -> None:
-        """
-        Args:
-          samples:
-            A list of audio samples. Each entry is a 1-D tensor of dtype
-            torch.float32, containing the audio samples of an utterance.
-        """
-        self.samples = samples
-        self.cur_indexes = [0] * len(self.samples)
-
-    @property
-    def done(self) -> bool:
-        """Return True if all samples have been processed.
-        Return False otherwise.
-        """
-        for i, samples in zip(self.cur_indexes, self.samples):
-            if i < samples.numel():
-                return False
-        return True
-
-    def get_next(self) -> List[torch.Tensor]:
-        """Return a list of audio samples. Each entry may have a different
-        length. It is OK if an entry contains no samples at all, which
-        means it has reached the end of the utterance.
-        """
-        ans = []
-
-        num = [1024] * len(self.samples)
-
-        for i in range(len(self.samples)):
-            start = self.cur_indexes[i]
-            end = start + num[i]
-            self.cur_indexes[i] = end
-
-            s = self.samples[i][start:end]
-            ans.append(s)
-
-        return ans
-
-
-class StreamList(object):
-    def __init__(
-        self,
-        batch_size: int,
-        context_size: int,
-        decoding_method: str,
-    ):
-        """
-        Args:
-          batch_size:
-            Size of this batch.
-          context_size:
-            Context size of the RNN-T decoder model.
-          decoding_method:
-            Decoding method. The possible values are:
-              - greedy_search
-              - modified_beam_search
-        """
-        self.streams = [
-            FeatureExtractionStream(
-                context_size=context_size, decoding_method=decoding_method
-            )
-            for _ in range(batch_size)
-        ]
-
-    def __getitem__(self, i) -> FeatureExtractionStream:
-        return self.streams[i]
-
-    @property
-    def done(self) -> bool:
-        """Return True if all streams have reached end of utterance.
-        That is, no more audio samples are available for any utterance.
-        """
-        return all(stream.done for stream in self.streams)
-
-    def accept_waveform(
-        self,
-        audio_samples: List[torch.Tensor],
-        sampling_rate: float,
-    ):
-        """Feed audio samples to each stream.
-        Args:
-          audio_samples:
-            A list of 1-D tensors containing the audio samples for each
-            utterance in the batch. If an entry is empty, it means
-            end-of-utterance has been reached.
-          sampling_rate:
-            Sampling rate of the given audio samples.
-        """
-        assert len(audio_samples) == len(self.streams)
-        for stream, samples in zip(self.streams, audio_samples):
-            if stream.done:
-                assert samples.numel() == 0
-                continue
-
-            stream.accept_waveform(
-                sampling_rate=sampling_rate,
-                waveform=samples,
-            )
-
-            if samples.numel() == 0:
-                stream.input_finished()
-
-    def build_batch(
-        self,
-        chunk_length: int,
-        segment_length: int,
-    ) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
-        """
-        Args:
-          chunk_length:
-            Number of frames for each chunk. It equals
-            ``segment_length + right_context_length``.
-          segment_length:
-            Number of frames for each segment.
-        Returns:
-          Return a tuple containing:
-            - features, a 3-D tensor of shape ``(num_active_streams, T, C)``
-            - active_streams, a list of active streams. We say a stream is
-              active when it has enough feature frames to be fed into the
-              encoder model.
-        """
-        feature_list = []
-        stream_list = []
-        for stream in self.streams:
-            if len(stream.feature_frames) >= chunk_length:
-                # this_chunk is a list of tensors, each of which
-                # has a shape (1, feature_dim)
-                chunk = stream.feature_frames[:chunk_length]
-                stream.feature_frames = stream.feature_frames[segment_length:]
-                features = torch.cat(chunk, dim=0)
-                feature_list.append(features)
-                stream_list.append(stream)
-            elif stream.done and len(stream.feature_frames) > 0:
-                chunk = stream.feature_frames[:chunk_length]
-                stream.feature_frames = []
-                features = torch.cat(chunk, dim=0)
-                features = torch.nn.functional.pad(
-                    features,
-                    (0, 0, 0, chunk_length - features.size(0)),
-                    mode="constant",
-                    value=LOG_EPSILON,
-                )
-                feature_list.append(features)
-                stream_list.append(stream)
-
-        if len(feature_list) == 0:
-            return None, None
-
-        features = torch.stack(feature_list, dim=0)
-        return features, stream_list


 def greedy_search(
     model: nn.Module,
-    streams: List[FeatureExtractionStream],
+    streams: List[Stream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -401,7 +249,7 @@ def greedy_search(

 def modified_beam_search(
     model: nn.Module,
-    streams: List[FeatureExtractionStream],
+    streams: List[Stream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
     beam: int = 4,
@@ -513,10 +361,69 @@ def modified_beam_search(
         logging.info(f"Partial result {i}:\n{result}")


+def build_batch(
+    decode_streams: List[Stream],
+    chunk_length: int,
+    segment_length: int,
+) -> Tuple[
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[List[Stream]],
+]:
+    """
+    Args:
+      decode_streams:
+        The list of streams to select active streams from.
+      chunk_length:
+        Number of frames for each chunk. It equals
+        ``segment_length + right_context_length``.
+      segment_length:
+        Number of frames for each segment.
+    Returns:
+      Return a tuple containing:
+        - features, a 3-D tensor of shape ``(num_active_streams, T, C)``
+        - lengths, a 1-D tensor containing the number of valid frames
+          for each active stream
+        - active_streams, a list of active streams. We say a stream is
+          active when it has enough feature frames to be fed into the
+          encoder model.
+    """
+    feature_list = []
+    length_list = []
+    stream_list = []
+    for stream in decode_streams:
+        if len(stream.feature_frames) >= chunk_length:
+            # stream.feature_frames is a list of tensors, each of which
+            # has shape (1, feature_dim)
+            chunk = stream.feature_frames[:chunk_length]
+            stream.feature_frames = stream.feature_frames[segment_length:]
+            features = torch.cat(chunk, dim=0)
+            feature_list.append(features)
+            length_list.append(chunk_length)
+            stream_list.append(stream)
+        elif stream.done and len(stream.feature_frames) > 0:
+            chunk = stream.feature_frames[:chunk_length]
+            stream.feature_frames = []
+            features = torch.cat(chunk, dim=0)
+            length_list.append(features.size(0))
+            features = torch.nn.functional.pad(
+                features,
+                (0, 0, 0, chunk_length - features.size(0)),
+                mode="constant",
+                value=LOG_EPSILON,
+            )
+            feature_list.append(features)
+            stream_list.append(stream)
+
+    if len(feature_list) == 0:
+        return None, None, None
+
+    features = torch.stack(feature_list, dim=0)
+    # length_list holds Python ints, so build a new tensor from it
+    lengths = torch.tensor(length_list)
+    return features, lengths, stream_list


 def process_features(
     model: nn.Module,
     features: torch.Tensor,
-    streams: List[FeatureExtractionStream],
+    feature_lens: torch.Tensor,
+    streams: List[Stream],
     params: AttributeDict,
     sp: spm.SentencePieceProcessor,
 ) -> None:
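The ``elif`` branch of build_batch pads a final, short chunk up to ``chunk_length`` with ``LOG_EPSILON`` so that every batch entry has the same number of frames. A toy illustration with assumed sizes (not from the commit):

import math

import torch

LOG_EPSILON = math.log(1e-10)

chunk_length, feature_dim = 8, 4
features = torch.zeros(5, feature_dim)  # only 5 of the 8 frames exist
padded = torch.nn.functional.pad(
    features,
    # (left, right, top, bottom): append rows at the bottom (time axis)
    (0, 0, 0, chunk_length - features.size(0)),
    mode="constant",
    value=LOG_EPSILON,
)
assert padded.shape == (chunk_length, feature_dim)
assert padded[5:].eq(LOG_EPSILON).all()  # 3 padded frames of log(1e-10)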
@@ -536,30 +443,20 @@ def process_features(
     """
     assert features.ndim == 3
     assert features.size(0) == len(streams)
-    batch_size = features.size(0)
+    assert feature_lens.size(0) == len(streams)

     device = model.device
     features = features.to(device)
-    feature_lens = torch.full(
-        (batch_size,),
-        fill_value=features.size(1),
-        device=device,
-    )

-    # Caution: It has a limitation as it assumes that
-    # if one of the streams has an empty state, then all other
-    # streams also have empty states.
-    if streams[0].states is None:
-        states = None
-    else:
-        state_list = [stream.states for stream in streams]
-        states = stack_states(state_list)
+    state_list = [stream.states for stream in streams]
+    states = stack_states(state_list)

-    (encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
+    encoder_out, encoder_out_lens, states = model.encoder.infer(
         features,
         feature_lens,
         states,
     )

     state_list = unstack_states(states)
     for i, s in enumerate(state_list):
         streams[i].states = s
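The rewritten body always stacks the per-stream states, runs one chunk through the encoder, and scatters the updated states back. A self-contained sketch of that flow with a mock encoder and a deliberately simplified state layout (only the argument order of ``model.encoder.infer`` is taken from the diff; everything else is an illustrative stand-in):

from typing import List

import torch


def toy_stack(state_list: List[torch.Tensor]) -> torch.Tensor:
    # per-utterance (C,) states -> one batched (N, C) state
    return torch.stack(state_list, dim=0)


def toy_unstack(states: torch.Tensor) -> List[torch.Tensor]:
    # batched (N, C) state -> per-utterance (C,) states
    return list(states.unbind(dim=0))


class ToyEncoder:
    def infer(self, features, feature_lens, states):
        # Identity "encoding"; bump the state so the write-back is visible.
        return features, feature_lens, states + 1


streams = [{"states": torch.zeros(4)} for _ in range(3)]
features = torch.randn(3, 8, 10)  # (N, T, C)
feature_lens = torch.full((3,), 8)

states = toy_stack([s["states"] for s in streams])
encoder_out, encoder_out_lens, states = ToyEncoder().infer(
    features, feature_lens, states
)
for stream, s in zip(streams, toy_unstack(states)):
    stream["states"] = s  # write the updated state back to each stream
assert all(s["states"].eq(1).all() for s in streams)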
@@ -585,64 +482,80 @@ def process_features(
     )


-def decode_batch(
-    batched_samples: List[torch.Tensor],
-    model: nn.Module,
-    sp: spm.SentencePieceProcessor,
-) -> List[str]:
-    """
-    Args:
-      batched_samples:
-        A list of 1-D tensors containing the audio samples of each utterance.
-      model:
-        The RNN-T model.
-      sp:
-        The BPE model.
-    """
-    # number of frames before subsampling
-    segment_length = model.encoder.segment_length
-
-    right_context_length = model.encoder.right_context_length
-
-    # We add 3 here since the subsampling method is using
-    # ((len - 1) // 2 - 1) // 2)
-    chunk_length = (segment_length + 3) + right_context_length
-
-    batch_size = len(batched_samples)
-    streaming_audio_samples = StreamingAudioSamples(batched_samples)
-
-    stream_list = StreamList(
-        batch_size=batch_size,
-        context_size=params.context_size,
-        decoding_method=params.decoding_method,
-    )
-
-    while not streaming_audio_samples.done:
-        samples = streaming_audio_samples.get_next()
-        stream_list.accept_waveform(
-            audio_samples=samples,
-            sampling_rate=params.sampling_rate,
-        )
-        features, active_streams = stream_list.build_batch(
-            chunk_length=chunk_length,
-            segment_length=segment_length,
-        )
-        if features is not None:
-            process_features(
-                model=model,
-                features=features,
-                streams=active_streams,
-                params=params,
-                sp=sp,
-            )
-
-    results = []
-    for stream in stream_list.streams:
-        text = sp.decode(stream.decoding_result())
-        results.append(text)
-    return results
+def decode_dataset(
+    params: AttributeDict,
+    cuts: CutSet,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+):
+    """Decode dataset.
+    Args:
+      params:
+        It is the return value of :func:`get_params`.
+      cuts:
+        The CutSet to decode.
+      model:
+        The RNN-T model.
+      sp:
+        The BPE model.
+    """
+    device = next(model.parameters()).device
+
+    # number of frames before subsampling
+    segment_length = model.encoder.segment_length
+
+    right_context_length = model.encoder.right_context_length
+
+    # 5 = 3 + 2
+    # 1) add 3 here since the subsampling method is using
+    #    ((len - 1) // 2 - 1) // 2)
+    # 2) add 2 here since we will drop the first and last frame
+    #    after subsampling
+    chunk_length = (segment_length + 5) + right_context_length
+
+    decode_results = []
+    streams = []
+    for num, cut in enumerate(cuts):
+        audio: np.ndarray = cut.load_audio()
+        # audio.shape: (1, num_samples)
+        assert len(audio.shape) == 2
+        assert audio.shape[0] == 1, "Should be single channel"
+        assert audio.dtype == np.float32, audio.dtype
+        # The trained model is using normalized samples
+        assert audio.max() <= 1, "Should be normalized to [-1, 1]"
+
+        samples = torch.from_numpy(audio).squeeze(0)
+
+        # Each utterance has a Stream
+        stream = Stream(
+            params=params,
+            audio_sample=samples,
+            ground_truth=cut.supervisions[0].text,
+            device=device,
+        )
+        streams.append(stream)
+
+        while len(streams) >= params.num_decode_streams:
+            for stream in streams:
+                stream.accept_waveform()
+
+            # try to build a batch
+            features, feature_lens, active_streams = build_batch(
+                decode_streams=streams,
+                chunk_length=chunk_length,
+                segment_length=segment_length,
+            )
+            if features is not None:
+                process_features(
+                    model=model,
+                    features=features,
+                    feature_lens=feature_lens,
+                    streams=active_streams,
+                    params=params,
+                    sp=sp,
+                )
+
+            new_streams = []
+            for stream in streams:
+                if stream.done:
+                    decode_results.append(
+                        (
+                            stream.ground_truth.split(),
+                            sp.decode(stream.decoding_result()).split(),
+                        )
+                    )
+                else:
+                    new_streams.append(stream)
+            del streams
+            streams = new_streams


 @torch.no_grad()
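A quick check of the "+ 3" in the chunk_length comments above (illustrative, not from the commit): with the subsampling formula ((len - 1) // 2 - 1) // 2, an input of 4*k + 3 frames yields exactly k output frames, so adding 3 lets a segment of 4*k input frames survive subsampling intact.

def subsampled(length: int) -> int:
    return ((length - 1) // 2 - 1) // 2


for k in (1, 4, 16, 64):
    # e.g. k = 16: 67 input frames -> exactly 16 output frames
    assert subsampled(4 * k + 3) == k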
@@ -726,8 +639,8 @@ def main():

         samples = torch.from_numpy(audio).squeeze(0)

-        batched_samples.append(samples)
-        ground_truth.append(cut.supervisions[0].text)
+        # batched_samples.append(samples)
+        # ground_truth.append(cut.supervisions[0].text)

         if len(batched_samples) >= batch_size:
             decoded_results = decode_batch(
@@ -42,12 +42,12 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
     return OnlineFbank(opts)


-class FeatureExtractionStream(object):
+class Stream(object):
     def __init__(
         self,
-        context_size: int,
-        decoding_method: str,
+        params: AttributeDict,
+        audio_sample: torch.Tensor,
+        ground_truth: str,
+        device: torch.device = torch.device("cpu"),
     ) -> None:
         """
@@ -63,6 +63,7 @@ class FeatureExtractionStream(object):
         # It contains a list of 1-D tensors representing the feature frames.
         self.feature_frames: List[torch.Tensor] = []
         self.num_fetched_frames = 0

         # After calling `self.input_finished()`, we set this flag to True
         self._done = False
@@ -87,20 +88,29 @@ class FeatureExtractionStream(object):
         self.states = [past_len, attn_caches, conv_caches]

         # It uses different attributes for different decoding methods.
-        self.context_size = context_size
-        self.decoding_method = decoding_method
-        if decoding_method == "greedy_search":
+        self.context_size = params.context_size
+        self.decoding_method = params.decoding_method
+        if params.decoding_method == "greedy_search":
             self.hyp: Optional[List[int]] = None
             self.decoder_out: Optional[torch.Tensor] = None
-        elif decoding_method == "modified_beam_search":
+        elif params.decoding_method == "modified_beam_search":
             self.hyps = HypothesisList()
         else:
-            raise ValueError(f"Unsupported decoding method: {decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
+
+        self.sample_rate = params.sample_rate
+        self.audio_sample = audio_sample
+        # Current index into the samples
+        self.cur_index = 0
+
+        self.ground_truth = ground_truth

     def accept_waveform(
         self,
-        sampling_rate: float,
-        waveform: torch.Tensor,
+        # sampling_rate: float,
+        # waveform: torch.Tensor,
     ) -> None:
         """Feed audio samples to the feature extractor and compute features
         if there are enough samples available.
@@ -120,12 +130,20 @@ class FeatureExtractionStream(object):
             A 1-D torch tensor of dtype torch.float32 containing audio samples.
             It should be on CPU.
         """
+        start = self.cur_index
+        end = self.cur_index + 1024
+        waveform = self.audio_sample[start:end]
+        self.cur_index = end
+
         self.feature_extractor.accept_waveform(
-            sampling_rate=sampling_rate,
+            sampling_rate=self.sample_rate,
             waveform=waveform,
         )
         self._fetch_frames()

+        if waveform.numel() == 0:
+            self.input_finished()
+
     def input_finished(self) -> None:
         """Signal that no more audio samples are available and the feature
         extractor should flush the buffered samples to compute frames.
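The new accept_waveform takes no arguments: each call pulls the next 1024 samples out of the utterance stored on the Stream itself, and an empty slice signals end of input. A self-contained toy version of that pattern (the block size of 1024 mirrors the diff; the class and the feature extractor are stand-ins):

import torch


class ToyStream:
    def __init__(self, audio_sample: torch.Tensor, block: int = 1024) -> None:
        self.audio_sample = audio_sample
        self.cur_index = 0
        self.block = block
        self._done = False

    def accept_waveform(self) -> None:
        start, end = self.cur_index, self.cur_index + self.block
        waveform = self.audio_sample[start:end]
        self.cur_index = end
        # ... feed `waveform` to the online feature extractor here ...
        if waveform.numel() == 0:
            self._done = True  # the real code calls self.input_finished()


stream = ToyStream(torch.randn(2500))
while not stream._done:
    stream.accept_waveform()  # consumes 1024, 1024, 452, then 0 samples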