mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
Add streaming feature extractor. (#302)
* Add streaming feature extractor. * Parallel streaming decode with greedy search. * Fix typos. * Use torch.stack() to replace torch.cat()
This commit is contained in:
parent
7f73043219
commit
0f45356ee6
4
.flake8
4
.flake8
@ -15,3 +15,7 @@ exclude =
|
|||||||
**/data/**,
|
**/data/**,
|
||||||
icefall/shared/make_kn_lm.py,
|
icefall/shared/make_kn_lm.py,
|
||||||
icefall/__init__.py
|
icefall/__init__.py
|
||||||
|
|
||||||
|
ignore =
|
||||||
|
# E203 whitespace before ':'
|
||||||
|
E203,
|
||||||
|
@ -32,13 +32,16 @@ class Joiner(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
Output from the encoder. Its shape is (N, T, s_range, C) for
|
||||||
|
training and (N, C) for streaming decoding.
|
||||||
decoder_out:
|
decoder_out:
|
||||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
Output from the decoder. Its shape is (N, T, s_range, C) for
|
||||||
|
training and (N, C) for streaming decoding.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, s_range, C).
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
assert encoder_out.ndim == decoder_out.ndim
|
||||||
|
assert encoder_out.ndim in (2, 4)
|
||||||
assert encoder_out.shape == decoder_out.shape
|
assert encoder_out.shape == decoder_out.shape
|
||||||
|
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
|
@ -27,6 +27,75 @@ from torchaudio.models import Emformer as _Emformer
|
|||||||
LOG_EPSILON = math.log(1e-10)
|
LOG_EPSILON = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def unstack_states(
|
||||||
|
states: List[List[torch.Tensor]],
|
||||||
|
) -> List[List[List[torch.Tensor]]]:
|
||||||
|
"""Unstack the emformer state corresponding to a batch of utterances
|
||||||
|
into a list of states, were the i-th entry is the state from the i-th
|
||||||
|
utterance in the batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
states:
|
||||||
|
A list-of-list of tensors. ``len(states)`` equals to number of
|
||||||
|
layers in the emformer. ``states[i]]`` contains the states for
|
||||||
|
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
|
||||||
|
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``
|
||||||
|
"""
|
||||||
|
batch_size = states[0][0].size(1)
|
||||||
|
num_layers = len(states)
|
||||||
|
|
||||||
|
ans = [None] * batch_size
|
||||||
|
for i in range(batch_size):
|
||||||
|
ans[i] = [[] for _ in range(num_layers)]
|
||||||
|
|
||||||
|
for li, layer in enumerate(states):
|
||||||
|
for s in layer:
|
||||||
|
s_list = s.unbind(dim=1)
|
||||||
|
# We will use stack(dim=1) later in stack_states()
|
||||||
|
for bi, b in enumerate(ans):
|
||||||
|
b[li].append(s_list[bi])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def stack_states(
|
||||||
|
state_list: List[List[List[torch.Tensor]]],
|
||||||
|
) -> List[List[torch.Tensor]]:
|
||||||
|
"""Stack list of emformer states that correspond to separate utterances
|
||||||
|
into a single emformer state so that it can be used as an input for
|
||||||
|
emformer when those utterances are formed into a batch.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
It is the inverse of :func:`unstack_states`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_list:
|
||||||
|
Each element in state_list corresponding to the internal state
|
||||||
|
of the emformer model for a single utterance.
|
||||||
|
Returns:
|
||||||
|
Return a new state corresponding to a batch of utterances.
|
||||||
|
See the input argument of :func:`unstack_states` for the meaning
|
||||||
|
of the returned tensor.
|
||||||
|
"""
|
||||||
|
batch_size = len(state_list)
|
||||||
|
ans = []
|
||||||
|
for layer in state_list[0]:
|
||||||
|
# layer is a list of tensors
|
||||||
|
if batch_size > 1:
|
||||||
|
ans.append([[s] for s in layer])
|
||||||
|
# Note: We will stack ans[layer][s][] later to get ans[layer][s]
|
||||||
|
else:
|
||||||
|
ans.append([s.unsqueeze(1) for s in layer])
|
||||||
|
|
||||||
|
for b, states in enumerate(state_list[1:], 1):
|
||||||
|
for li, layer in enumerate(states):
|
||||||
|
for si, s in enumerate(layer):
|
||||||
|
ans[li][si].append(s)
|
||||||
|
if b == batch_size - 1:
|
||||||
|
ans[li][si] = torch.stack(ans[li][si], dim=1)
|
||||||
|
# We will use unbind(dim=1) later in unstack_states()
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
class Emformer(EncoderInterface):
|
class Emformer(EncoderInterface):
|
||||||
"""This is just a simple wrapper around torchaudio.models.Emformer.
|
"""This is just a simple wrapper around torchaudio.models.Emformer.
|
||||||
We may replace it with our own implementation some time later.
|
We may replace it with our own implementation some time later.
|
||||||
@ -63,11 +132,11 @@ class Emformer(EncoderInterface):
|
|||||||
num_encoder_layers:
|
num_encoder_layers:
|
||||||
Number of encoder layers.
|
Number of encoder layers.
|
||||||
segment_length:
|
segment_length:
|
||||||
Number of frames per segment.
|
Number of frames per segment before subsampling.
|
||||||
left_context_length:
|
left_context_length:
|
||||||
Number of frames in the left context.
|
Number of frames in the left context before subsampling.
|
||||||
right_context_length:
|
right_context_length:
|
||||||
Number of frames in the right context.
|
Number of frames in the right context before subsampling.
|
||||||
max_memory_size:
|
max_memory_size:
|
||||||
TODO.
|
TODO.
|
||||||
dropout:
|
dropout:
|
||||||
@ -94,6 +163,7 @@ class Emformer(EncoderInterface):
|
|||||||
else:
|
else:
|
||||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||||
|
|
||||||
|
self.segment_length = segment_length
|
||||||
self.right_context_length = right_context_length
|
self.right_context_length = right_context_length
|
||||||
|
|
||||||
assert right_context_length % subsampling_factor == 0
|
assert right_context_length % subsampling_factor == 0
|
||||||
|
184
egs/librispeech/ASR/transducer_emformer/export.py
Executable file
184
egs/librispeech/ASR/transducer_emformer/export.py
Executable file
@ -0,0 +1,184 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# This script converts several saved checkpoints
|
||||||
|
# to a single one using model averaging.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
./transducer_emformer/export.py \
|
||||||
|
--exp-dir ./transducer_emformer/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
It will generate a file exp_dir/pretrained.pt
|
||||||
|
|
||||||
|
To use the generated file with `transducer_emformer/decode.py`,
|
||||||
|
you can do:
|
||||||
|
|
||||||
|
cd /path/to/exp_dir
|
||||||
|
ln -s pretrained.pt epoch-9999.pt
|
||||||
|
|
||||||
|
cd /path/to/egs/librispeech/ASR
|
||||||
|
./transducer_emformer/decode.py \
|
||||||
|
--exp-dir ./transducer_emformer/exp \
|
||||||
|
--epoch 9999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="It specifies the checkpoint to use for decoding."
|
||||||
|
"Note: Epoch counts from 0.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch'. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to save a model after applying torch.jit.script.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
assert args.jit is False, "Support torchscript will be added later"
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.jit:
|
||||||
|
logging.info("Using torch.jit.script")
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
model.save(str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
else:
|
||||||
|
logging.info("Not using torch.jit.script")
|
||||||
|
# Save it using a format so that it can be loaded
|
||||||
|
# by :func:`load_checkpoint`
|
||||||
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -18,16 +18,16 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import kaldifeat
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||||
|
from streaming_feature_extractor import FeatureExtractionStream
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -147,10 +147,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sample-rate",
|
"--sampling-rate",
|
||||||
type=int,
|
type=float,
|
||||||
default=16000,
|
default=16000,
|
||||||
help="The sample rate of the input sound file",
|
help="Sample rate of the audio",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
@ -158,115 +158,352 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_feature_extractor(
|
class StreamingAudioSamples(object):
|
||||||
params: AttributeDict,
|
"""This class takes as input a list of audio samples and returns
|
||||||
) -> kaldifeat.Fbank:
|
them in a streaming fashion.
|
||||||
logging.info("Constructing Fbank computer")
|
"""
|
||||||
opts = kaldifeat.FbankOptions()
|
|
||||||
opts.device = params.device
|
|
||||||
opts.frame_opts.dither = 0
|
|
||||||
opts.frame_opts.snip_edges = True
|
|
||||||
opts.frame_opts.samp_freq = params.sample_rate
|
|
||||||
opts.mel_opts.num_bins = params.feature_dim
|
|
||||||
|
|
||||||
return kaldifeat.Fbank(opts)
|
def __init__(self, samples: List[torch.Tensor]) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
samples:
|
||||||
|
A list of audio samples. Each entry is a 1-D tensor of dtype
|
||||||
|
torch.float32, containing the audio samples of an utterance.
|
||||||
|
"""
|
||||||
|
self.samples = samples
|
||||||
|
self.cur_indexes = [0] * len(self.samples)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""Return True if all samples have been processed.
|
||||||
|
Return False otherwise.
|
||||||
|
"""
|
||||||
|
for i, samples in zip(self.cur_indexes, self.samples):
|
||||||
|
if i < samples.numel():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_next(self) -> List[torch.Tensor]:
|
||||||
|
"""Return a list of audio samples. Each entry may have different
|
||||||
|
lengths. It is OK if an entry contains no samples at all, which
|
||||||
|
means it reaches the end of the utterance.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
|
||||||
|
num = [1024] * len(self.samples)
|
||||||
|
|
||||||
|
for i in range(len(self.samples)):
|
||||||
|
start = self.cur_indexes[i]
|
||||||
|
end = start + num[i]
|
||||||
|
self.cur_indexes[i] = end
|
||||||
|
|
||||||
|
s = self.samples[i][start:end]
|
||||||
|
ans.append(s)
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def decode_one_utterance(
|
class StreamList(object):
|
||||||
audio_samples: torch.Tensor,
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
context_size: int,
|
||||||
|
blank_id: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
batch_size:
|
||||||
|
Size of this batch.
|
||||||
|
context_size:
|
||||||
|
Context size of the RNN-T decoder model.
|
||||||
|
blank_id:
|
||||||
|
The ID of the blank symbol of the BPE model.
|
||||||
|
"""
|
||||||
|
self.streams = [
|
||||||
|
FeatureExtractionStream(
|
||||||
|
context_size=context_size, blank_id=blank_id
|
||||||
|
)
|
||||||
|
for _ in range(batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""Return True if all streams have reached end of utterance.
|
||||||
|
That is, no more audio samples are available for all utterances.
|
||||||
|
"""
|
||||||
|
return all(stream.done for stream in self.streams)
|
||||||
|
|
||||||
|
def accept_waveform(
|
||||||
|
self,
|
||||||
|
audio_samples: List[torch.Tensor],
|
||||||
|
sampling_rate: float,
|
||||||
|
):
|
||||||
|
"""Feeed audio samples to each stream.
|
||||||
|
Args:
|
||||||
|
audio_samples:
|
||||||
|
A list of 1-D tensors containing the audio samples for each
|
||||||
|
utterance in the batch. If an entry is empty, it means
|
||||||
|
end-of-utterance has been reached.
|
||||||
|
sampling_rate:
|
||||||
|
Sampling rate of the given audio samples.
|
||||||
|
"""
|
||||||
|
assert len(audio_samples) == len(self.streams)
|
||||||
|
for stream, samples in zip(self.streams, audio_samples):
|
||||||
|
|
||||||
|
if stream.done:
|
||||||
|
assert samples.numel() == 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
stream.accept_waveform(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
waveform=samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
if samples.numel() == 0:
|
||||||
|
stream.input_finished()
|
||||||
|
|
||||||
|
def build_batch(
|
||||||
|
self,
|
||||||
|
chunk_length: int,
|
||||||
|
segment_length: int,
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
chunk_length:
|
||||||
|
Number of frames for each chunk. It equals to
|
||||||
|
``segment_length + right_context_length``.
|
||||||
|
segment_length
|
||||||
|
Number of frames for each segment.
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
|
||||||
|
- active_streams, a list of active streams. We say a stream is
|
||||||
|
active when it has enough feature frames to be fed into the
|
||||||
|
encoder model.
|
||||||
|
"""
|
||||||
|
feature_list = []
|
||||||
|
stream_list = []
|
||||||
|
for stream in self.streams:
|
||||||
|
if len(stream.feature_frames) >= chunk_length:
|
||||||
|
# this_chunk is a list of tensors, each of which
|
||||||
|
# has a shape (1, feature_dim)
|
||||||
|
chunk = stream.feature_frames[:chunk_length]
|
||||||
|
stream.feature_frames = stream.feature_frames[segment_length:]
|
||||||
|
features = torch.cat(chunk, dim=0)
|
||||||
|
feature_list.append(features)
|
||||||
|
stream_list.append(stream)
|
||||||
|
elif stream.done and len(stream.feature_frames) > 0:
|
||||||
|
chunk = stream.feature_frames[:chunk_length]
|
||||||
|
stream.feature_frames = []
|
||||||
|
features = torch.cat(chunk, dim=0)
|
||||||
|
features = torch.nn.functional.pad(
|
||||||
|
features,
|
||||||
|
(0, 0, 0, chunk_length - features.size(0)),
|
||||||
|
mode="constant",
|
||||||
|
value=LOG_EPSILON,
|
||||||
|
)
|
||||||
|
feature_list.append(features)
|
||||||
|
stream_list.append(stream)
|
||||||
|
|
||||||
|
if len(feature_list) == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
features = torch.stack(feature_list, dim=0)
|
||||||
|
return features, stream_list
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
fbank: kaldifeat.Fbank,
|
streams: List[FeatureExtractionStream],
|
||||||
params: AttributeDict,
|
encoder_out: torch.Tensor,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
):
|
):
|
||||||
"""Decode one utterance.
|
"""
|
||||||
Args:
|
Args:
|
||||||
audio_samples:
|
|
||||||
A 1-D float32 tensor of shape (num_samples,) containing the normalized
|
|
||||||
audio samples. Normalized means the samples is in the range [-1, 1].
|
|
||||||
model:
|
model:
|
||||||
The RNN-T model.
|
The RNN-T model.
|
||||||
feature_extractor:
|
stream:
|
||||||
The feature extractor.
|
A stream object.
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
|
||||||
|
the encoder model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
"""
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
if streams[0].decoder_out is None:
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
decoder_out = model.decoder(
|
||||||
|
decoder_input,
|
||||||
|
need_pad=False,
|
||||||
|
).squeeze(1)
|
||||||
|
# decoder_out is of shape (N, decoder_out_dim)
|
||||||
|
else:
|
||||||
|
decoder_out = torch.stack(
|
||||||
|
[stream.decoder_out for stream in streams],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
for t in range(T):
|
||||||
|
current_encoder_out = encoder_out[:, t]
|
||||||
|
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(current_encoder_out, decoder_out)
|
||||||
|
# logits'shape (batch_size, vocab_size)
|
||||||
|
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
y = logits.argmax(dim=1).tolist()
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v != blank_id:
|
||||||
|
streams[i].hyp.ys.append(v)
|
||||||
|
emitted = True
|
||||||
|
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
|
||||||
|
1
|
||||||
|
)
|
||||||
|
|
||||||
|
for k, s in enumerate(streams):
|
||||||
|
logging.info(
|
||||||
|
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_out_list = decoder_out.unbind(dim=0)
|
||||||
|
|
||||||
|
for i, d in enumerate(decoder_out_list):
|
||||||
|
streams[i].decoder_out = d
|
||||||
|
|
||||||
|
|
||||||
|
def process_features(
|
||||||
|
model: nn.Module,
|
||||||
|
features: torch.Tensor,
|
||||||
|
streams: List[FeatureExtractionStream],
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
) -> None:
|
||||||
|
"""Process features for each stream in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The RNN-T model.
|
||||||
|
features:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
streams:
|
||||||
|
A list of streams of size (N,).
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
"""
|
||||||
|
assert features.ndim == 3
|
||||||
|
assert features.size(0) == len(streams)
|
||||||
|
batch_size = features.size(0)
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
features = features.to(device)
|
||||||
|
feature_lens = torch.full(
|
||||||
|
(batch_size,),
|
||||||
|
fill_value=features.size(1),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Caution: It has a limitation as it assumes that
|
||||||
|
# if one of the stream has an empty state, then all other
|
||||||
|
# streams also have empty states.
|
||||||
|
if streams[0].states is None:
|
||||||
|
states = None
|
||||||
|
else:
|
||||||
|
state_list = [stream.states for stream in streams]
|
||||||
|
states = stack_states(state_list)
|
||||||
|
|
||||||
|
(encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
|
||||||
|
features,
|
||||||
|
feature_lens,
|
||||||
|
states,
|
||||||
|
)
|
||||||
|
state_list = unstack_states(states)
|
||||||
|
for i, s in enumerate(state_list):
|
||||||
|
streams[i].states = s
|
||||||
|
|
||||||
|
greedy_search(
|
||||||
|
model=model,
|
||||||
|
streams=streams,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
sp=sp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_batch(
|
||||||
|
batched_samples: List[torch.Tensor],
|
||||||
|
model: nn.Module,
|
||||||
|
params: AttributeDict,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
batched_samples:
|
||||||
|
A list of 1-D tensors containing the audio samples of each utterance.
|
||||||
|
model:
|
||||||
|
The RNN-T model.
|
||||||
params:
|
params:
|
||||||
It is the return value of :func:`get_params`.
|
It is the return value of :func:`get_params`.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
"""
|
"""
|
||||||
sample_rate = params.sample_rate
|
# number of frames before subsampling
|
||||||
frame_shift = sample_rate * fbank.opts.frame_opts.frame_shift_ms / 1000
|
segment_length = model.encoder.segment_length
|
||||||
|
|
||||||
frame_shift = int(frame_shift) # number of samples
|
right_context_length = model.encoder.right_context_length
|
||||||
|
|
||||||
# Note: We add 3 here because the subsampling method ((n-1)//2-1))//2
|
# We add 3 here since the subsampling method is using
|
||||||
# is not equal to n//4. We will switch to a subsampling method that
|
# ((len - 1) // 2 - 1) // 2)
|
||||||
# satisfies n//4, where n is the number of input frames.
|
chunk_length = (segment_length + 3) + right_context_length
|
||||||
segment_length = (params.segment_length + 3) * frame_shift
|
|
||||||
|
|
||||||
right_context_length = params.right_context_length * frame_shift
|
batch_size = len(batched_samples)
|
||||||
chunk_size = segment_length + right_context_length
|
streaming_audio_samples = StreamingAudioSamples(batched_samples)
|
||||||
|
|
||||||
opts = fbank.opts.frame_opts
|
stream_list = StreamList(
|
||||||
chunk_size += (
|
batch_size=batch_size,
|
||||||
(opts.frame_length_ms - opts.frame_shift_ms) / 1000 * sample_rate
|
context_size=params.context_size,
|
||||||
|
blank_id=params.blank_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk_size = int(chunk_size)
|
while not streaming_audio_samples.done:
|
||||||
|
samples = streaming_audio_samples.get_next()
|
||||||
states: Optional[List[List[torch.Tensor]]] = None
|
stream_list.accept_waveform(
|
||||||
|
audio_samples=samples,
|
||||||
blank_id = model.decoder.blank_id
|
sampling_rate=params.sampling_rate,
|
||||||
context_size = model.decoder.context_size
|
|
||||||
|
|
||||||
device = model.device
|
|
||||||
|
|
||||||
hyp = [blank_id] * context_size
|
|
||||||
|
|
||||||
decoder_input = torch.tensor(hyp, device=device, dtype=torch.int64).reshape(
|
|
||||||
1, context_size
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
num_samples = audio_samples.size(0)
|
|
||||||
while i < num_samples:
|
|
||||||
# Note: The current approach of computing the features is not ideal
|
|
||||||
# since it re-computes the features for the right context.
|
|
||||||
chunk = audio_samples[i : i + chunk_size] # noqa
|
|
||||||
i += segment_length
|
|
||||||
if chunk.size(0) < chunk_size:
|
|
||||||
chunk = torch.nn.functional.pad(
|
|
||||||
chunk, pad=(0, chunk_size - chunk.size(0))
|
|
||||||
)
|
|
||||||
features = fbank(chunk)
|
|
||||||
feature_lens = torch.tensor([features.size(0)], device=params.device)
|
|
||||||
|
|
||||||
features = features.unsqueeze(0) # (1, T, C)
|
|
||||||
|
|
||||||
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
|
|
||||||
features,
|
|
||||||
feature_lens,
|
|
||||||
states,
|
|
||||||
)
|
)
|
||||||
for t in range(encoder_out_lens.item()):
|
features, active_streams = stream_list.build_batch(
|
||||||
# fmt: off
|
chunk_length=chunk_length,
|
||||||
current_encoder_out = encoder_out[0:1, t:t+1, :].unsqueeze(2)
|
segment_length=segment_length,
|
||||||
# fmt: on
|
)
|
||||||
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
|
if features is not None:
|
||||||
# logits is (1, 1, 1, vocab_size)
|
process_features(
|
||||||
y = logits.argmax().item()
|
model=model,
|
||||||
if y == blank_id:
|
features=features,
|
||||||
continue
|
streams=active_streams,
|
||||||
|
sp=sp,
|
||||||
hyp.append(y)
|
)
|
||||||
|
results = []
|
||||||
decoder_input = torch.tensor(
|
for s in stream_list.streams:
|
||||||
[hyp[-context_size:]], device=device, dtype=torch.int64
|
text = sp.decode(s.hyp.ys[params.context_size :])
|
||||||
).reshape(1, context_size)
|
results.append(text)
|
||||||
|
return results
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
|
||||||
logging.info(f"Partial result:\n{sp.decode(hyp[context_size:])}")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -333,30 +570,43 @@ def main():
|
|||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
|
||||||
fbank = get_feature_extractor(params)
|
batch_size = 3
|
||||||
|
|
||||||
|
ground_truth = []
|
||||||
|
batched_samples = []
|
||||||
for num, cut in enumerate(test_clean_cuts):
|
for num, cut in enumerate(test_clean_cuts):
|
||||||
logging.info("Processing {num}")
|
|
||||||
|
|
||||||
audio: np.ndarray = cut.load_audio()
|
audio: np.ndarray = cut.load_audio()
|
||||||
# audio.shape: (1, num_samples)
|
# audio.shape: (1, num_samples)
|
||||||
assert len(audio.shape) == 2
|
assert len(audio.shape) == 2
|
||||||
assert audio.shape[0] == 1, "Should be single channel"
|
assert audio.shape[0] == 1, "Should be single channel"
|
||||||
assert audio.dtype == np.float32, audio.dtype
|
assert audio.dtype == np.float32, audio.dtype
|
||||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
|
||||||
decode_one_utterance(
|
|
||||||
audio_samples=torch.from_numpy(audio).squeeze(0).to(device),
|
|
||||||
model=model,
|
|
||||||
fbank=fbank,
|
|
||||||
params=params,
|
|
||||||
sp=sp,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(f"The ground truth is:\n{cut.supervisions[0].text}")
|
# The trained model is using normalized samples
|
||||||
if num >= 0:
|
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||||
|
|
||||||
|
samples = torch.from_numpy(audio).squeeze(0)
|
||||||
|
|
||||||
|
batched_samples.append(samples)
|
||||||
|
ground_truth.append(cut.supervisions[0].text)
|
||||||
|
|
||||||
|
if len(batched_samples) >= batch_size:
|
||||||
|
decoded_results = decode_batch(
|
||||||
|
batched_samples=batched_samples,
|
||||||
|
model=model,
|
||||||
|
params=params,
|
||||||
|
sp=sp,
|
||||||
|
)
|
||||||
|
s = "\n"
|
||||||
|
for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
|
||||||
|
s += f"hyp {i}:\n{hyp}\n"
|
||||||
|
s += f"ref {i}:\n{ref}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
batched_samples = []
|
||||||
|
ground_truth = []
|
||||||
|
# break after processing the first batch for test purposes
|
||||||
break
|
break
|
||||||
time.sleep(2) # So that you can see the decoded results
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20220410)
|
||||||
main()
|
main()
|
||||||
|
@ -0,0 +1,116 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from beam_search import Hypothesis
|
||||||
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
|
def _create_streaming_feature_extractor() -> OnlineFeature:
|
||||||
|
"""Create a CPU streaming feature extractor.
|
||||||
|
|
||||||
|
At present, we assume it returns a fbank feature extractor with
|
||||||
|
fixed options. In the future, we will support passing in the options
|
||||||
|
from outside.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a CPU streaming feature extractor.
|
||||||
|
"""
|
||||||
|
opts = FbankOptions()
|
||||||
|
opts.device = "cpu"
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = 16000
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
return OnlineFbank(opts)
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractionStream(object):
|
||||||
|
def __init__(self, context_size: int, blank_id: int = 0) -> None:
|
||||||
|
"""Context size of the RNN-T decoder model."""
|
||||||
|
self.feature_extractor = _create_streaming_feature_extractor()
|
||||||
|
self.hyp = Hypothesis(
|
||||||
|
ys=([blank_id] * context_size),
|
||||||
|
log_prob=torch.tensor([0.0]),
|
||||||
|
) # for greedy search, will extend it to beam search
|
||||||
|
|
||||||
|
# It contains a list of 1-D tensors representing the feature frames.
|
||||||
|
self.feature_frames: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
self.num_fetched_frames = 0
|
||||||
|
|
||||||
|
# For the emformer model, it contains the states of each
|
||||||
|
# encoder layer.
|
||||||
|
self.states: Optional[List[List[torch.Tensor]]] = None
|
||||||
|
|
||||||
|
# For the RNN-T decoder, it contains the decoder output
|
||||||
|
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
||||||
|
# Its shape is (decoder_out_dim,)
|
||||||
|
self.decoder_out: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# After calling `self.input_finished()`, we set this flag to True
|
||||||
|
self._done = False
|
||||||
|
|
||||||
|
def accept_waveform(
|
||||||
|
self,
|
||||||
|
sampling_rate: float,
|
||||||
|
waveform: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""Feed audio samples to the feature extractor and compute features
|
||||||
|
if there are enough samples available.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
The range of the audio samples should match the one used in the
|
||||||
|
training. That is, if you use the range [-1, 1] in the training, then
|
||||||
|
the input audio samples should also be normalized to [-1, 1].
|
||||||
|
|
||||||
|
Args
|
||||||
|
sampling_rate:
|
||||||
|
The sampling rate of the input audio samples. It is used for sanity
|
||||||
|
check to ensure that the input sampling rate equals to the one
|
||||||
|
used in the extractor. If they are not equal, then no resampling
|
||||||
|
will be performed; instead an error will be thrown.
|
||||||
|
waveform:
|
||||||
|
A 1-D torch tensor of dtype torch.float32 containing audio samples.
|
||||||
|
It should be on CPU.
|
||||||
|
"""
|
||||||
|
self.feature_extractor.accept_waveform(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
waveform=waveform,
|
||||||
|
)
|
||||||
|
self._fetch_frames()
|
||||||
|
|
||||||
|
def input_finished(self) -> None:
|
||||||
|
"""Signal that no more audio samples available and the feature
|
||||||
|
extractor should flush the buffered samples to compute frames.
|
||||||
|
"""
|
||||||
|
self.feature_extractor.input_finished()
|
||||||
|
self._fetch_frames()
|
||||||
|
self._done = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""Return True if `self.input_finished()` has been invoked"""
|
||||||
|
return self._done
|
||||||
|
|
||||||
|
def _fetch_frames(self) -> None:
|
||||||
|
"""Fetch frames from the feature extractor"""
|
||||||
|
while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
|
||||||
|
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
|
||||||
|
self.feature_frames.append(frame)
|
||||||
|
self.num_fetched_frames += 1
|
@ -25,7 +25,7 @@ To run this file, do:
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from emformer import Emformer
|
from emformer import Emformer, stack_states, unstack_states
|
||||||
|
|
||||||
|
|
||||||
def test_emformer():
|
def test_emformer():
|
||||||
@ -65,8 +65,41 @@ def test_emformer():
|
|||||||
print(f"Number of encoder parameters: {num_param}")
|
print(f"Number of encoder parameters: {num_param}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_emformer_streaming_forward():
|
||||||
|
N = 3
|
||||||
|
C = 80
|
||||||
|
|
||||||
|
output_dim = 500
|
||||||
|
|
||||||
|
encoder = Emformer(
|
||||||
|
num_features=C,
|
||||||
|
output_dim=output_dim,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
num_encoder_layers=20,
|
||||||
|
segment_length=16,
|
||||||
|
left_context_length=120,
|
||||||
|
right_context_length=4,
|
||||||
|
vgg_frontend=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.rand(N, 23, C)
|
||||||
|
x_lens = torch.full((N,), 23)
|
||||||
|
y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens)
|
||||||
|
|
||||||
|
state_list = unstack_states(states)
|
||||||
|
states2 = stack_states(state_list)
|
||||||
|
|
||||||
|
for ss, ss2 in zip(states, states2):
|
||||||
|
for s, s2 in zip(ss, ss2):
|
||||||
|
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
test_emformer()
|
# test_emformer()
|
||||||
|
test_emformer_streaming_forward()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
53
egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py
Executable file
53
egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py
Executable file
@ -0,0 +1,53 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./transducer_emformer/test_streaming_feature_extractor.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from streaming_feature_extractor import FeatureExtractionStream
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_feature_extractor():
|
||||||
|
stream = FeatureExtractionStream(context_size=2, blank_id=0)
|
||||||
|
samples = torch.rand(16000)
|
||||||
|
start = 0
|
||||||
|
while True:
|
||||||
|
n = torch.randint(50, 500, (1,)).item()
|
||||||
|
end = start + n
|
||||||
|
this_chunk = samples[start:end]
|
||||||
|
start = end
|
||||||
|
|
||||||
|
if len(this_chunk) == 0:
|
||||||
|
break
|
||||||
|
stream.accept_waveform(sampling_rate=16000, waveform=this_chunk)
|
||||||
|
print(len(stream.feature_frames))
|
||||||
|
stream.input_finished()
|
||||||
|
print(len(stream.feature_frames))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_streaming_feature_extractor()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user