Add streaming feature extractor. (#302)

* Add streaming feature extractor.

* Parallel streaming decode with greedy search.

* Fix typos.

* Use torch.stack() instead of torch.cat()
Fangjun Kuang 2022-04-18 10:38:56 +08:00 committed by GitHub
parent 7f73043219
commit 0f45356ee6
8 changed files with 829 additions and 116 deletions


@ -15,3 +15,7 @@ exclude =
**/data/**,
icefall/shared/make_kn_lm.py,
icefall/__init__.py
ignore =
# E203 whitespace before ':'
E203,


@ -32,13 +32,16 @@ class Joiner(nn.Module):
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
Output from the encoder. Its shape is (N, T, s_range, C) for
training and (N, C) for streaming decoding.
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
Output from the decoder. Its shape is (N, T, s_range, C) for
training and (N, C) for streaming decoding.
Returns:
Return a tensor of shape (N, T, s_range, C) for training
and (N, C) for streaming decoding.
"""
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape
logit = encoder_out + decoder_out
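For illustration (not from the commit itself): a shape-only sketch with dummy tensors showing what the relaxed assertion above permits during streaming decoding, when both inputs are 2-D.

import torch

encoder_out = torch.rand(3, 512)  # (N, C): one frame per active stream
decoder_out = torch.rand(3, 512)  # (N, C) from the RNN-T decoder
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
logit = encoder_out + decoder_out  # (N, C); the rest of the joiner follows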


@ -27,6 +27,75 @@ from torchaudio.models import Emformer as _Emformer
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: List[List[torch.Tensor]],
) -> List[List[List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list-of-list of tensors. ``len(states)`` equals the number of
layers in the emformer. ``states[i]`` contains the states for
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
"""
batch_size = states[0][0].size(1)
num_layers = len(states)
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [[] for _ in range(num_layers)]
for li, layer in enumerate(states):
for s in layer:
s_list = s.unbind(dim=1)
# We will use stack(dim=1) later in stack_states()
for bi, b in enumerate(ans):
b[li].append(s_list[bi])
return ans
def stack_states(
state_list: List[List[List[torch.Tensor]]],
) -> List[List[torch.Tensor]]:
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
emformer when those utterances are formed into a batch.
Note:
It is the inverse of :func:`unstack_states`.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the emformer model for a single utterance.
Returns:
Return a new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned tensor.
"""
batch_size = len(state_list)
ans = []
for layer in state_list[0]:
# layer is a list of tensors
if batch_size > 1:
ans.append([[s] for s in layer])
# Note: We will stack ans[layer][s][] later to get ans[layer][s]
else:
ans.append([s.unsqueeze(1) for s in layer])
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states):
for si, s in enumerate(layer):
ans[li][si].append(s)
if b == batch_size - 1:
ans[li][si] = torch.stack(ans[li][si], dim=1)
# We will use unbind(dim=1) later in unstack_states()
return ans
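For illustration (not from the commit itself): a round-trip sketch of unstack_states()/stack_states() with dummy tensors shaped like the documented states; the real states come from Emformer.streaming_forward() (see test_emformer_streaming_forward() below).

import torch
from emformer import stack_states, unstack_states

N, T, C = 3, 4, 8  # hypothetical batch size, time frames, channels
num_layers = 2
states = [[torch.rand(T, N, C), torch.rand(C, N)] for _ in range(num_layers)]
state_list = unstack_states(states)  # one entry per utterance
assert len(state_list) == N
states2 = stack_states(state_list)   # back to the batched layout
for layer, layer2 in zip(states, states2):
    for s, s2 in zip(layer, layer2):
        assert torch.allclose(s, s2)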
class Emformer(EncoderInterface):
"""This is just a simple wrapper around torchaudio.models.Emformer.
We may replace it with our own implementation some time later.
@ -63,11 +132,11 @@ class Emformer(EncoderInterface):
num_encoder_layers:
Number of encoder layers.
segment_length:
Number of frames per segment.
Number of frames per segment before subsampling.
left_context_length:
Number of frames in the left context.
Number of frames in the left context before subsampling.
right_context_length:
Number of frames in the right context.
Number of frames in the right context before subsampling.
max_memory_size:
TODO.
dropout:
@ -94,6 +163,7 @@ class Emformer(EncoderInterface):
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.segment_length = segment_length
self.right_context_length = right_context_length
assert right_context_length % subsampling_factor == 0


@ -0,0 +1,184 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_emformer/export.py \
--exp-dir ./transducer_emformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_emformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer_emformer/decode.py \
--exp-dir ./transducer_emformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1000 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support for torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:  # skip checkpoints with negative epoch numbers
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
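For intuition (not from the commit itself): a rough conceptual sketch of what the `average_checkpoints` call above does; the actual implementation lives in icefall.checkpoint and may differ in details.

import torch

def average_state_dicts(filenames, device="cpu"):
    """Average the "model" state_dicts stored in the given checkpoint files."""
    avg = None
    for f in filenames:
        state = torch.load(f, map_location=device)["model"]
        if avg is None:
            avg = {k: v.clone() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= len(filenames)
    return avg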


@ -18,16 +18,16 @@
import argparse
import logging
import time
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
import kaldifeat
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from emformer import LOG_EPSILON, stack_states, unstack_states
from streaming_feature_extractor import FeatureExtractionStream
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@ -147,10 +147,10 @@ def get_parser():
)
parser.add_argument(
"--sample-rate",
type=int,
"--sampling-rate",
type=float,
default=16000,
help="The sample rate of the input sound file",
help="Sample rate of the audio",
)
add_model_arguments(parser)
@ -158,115 +158,352 @@ def get_parser():
return parser
def get_feature_extractor(
params: AttributeDict,
) -> kaldifeat.Fbank:
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = params.device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = True
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
class StreamingAudioSamples(object):
"""This class takes as input a list of audio samples and returns
them in a streaming fashion.
"""
return kaldifeat.Fbank(opts)
def __init__(self, samples: List[torch.Tensor]) -> None:
"""
Args:
samples:
A list of audio samples. Each entry is a 1-D tensor of dtype
torch.float32, containing the audio samples of an utterance.
"""
self.samples = samples
self.cur_indexes = [0] * len(self.samples)
@property
def done(self) -> bool:
"""Return True if all samples have been processed.
Return False otherwise.
"""
for i, samples in zip(self.cur_indexes, self.samples):
if i < samples.numel():
return False
return True
def get_next(self) -> List[torch.Tensor]:
"""Return a list of audio samples. Each entry may have different
lengths. It is OK if an entry contains no samples at all, which
means it reaches the end of the utterance.
"""
ans = []
num = [1024] * len(self.samples)
for i in range(len(self.samples)):
start = self.cur_indexes[i]
end = start + num[i]
self.cur_indexes[i] = end
s = self.samples[i][start:end]
ans.append(s)
return ans
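For illustration (not from the commit itself): a hypothetical usage of StreamingAudioSamples with dummy data, assuming it runs inside streaming_decode.py.

import torch

samples = [torch.rand(3000), torch.rand(4500)]  # two dummy utterances
source = StreamingAudioSamples(samples)
while not source.done:
    chunks = source.get_next()  # up to 1024 samples per utterance per call
    # An empty entry in `chunks` means that utterance has ended;
    # in the real code, `chunks` is fed to StreamList.accept_waveform().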
def decode_one_utterance(
audio_samples: torch.Tensor,
class StreamList(object):
def __init__(
self,
batch_size: int,
context_size: int,
blank_id: int,
):
"""
Args:
batch_size:
Size of this batch.
context_size:
Context size of the RNN-T decoder model.
blank_id:
The ID of the blank symbol of the BPE model.
"""
self.streams = [
FeatureExtractionStream(
context_size=context_size, blank_id=blank_id
)
for _ in range(batch_size)
]
@property
def done(self) -> bool:
"""Return True if all streams have reached end of utterance.
That is, no more audio samples are available for all utterances.
"""
return all(stream.done for stream in self.streams)
def accept_waveform(
self,
audio_samples: List[torch.Tensor],
sampling_rate: float,
):
"""Feeed audio samples to each stream.
Args:
audio_samples:
A list of 1-D tensors containing the audio samples for each
utterance in the batch. If an entry is empty, it means
end-of-utterance has been reached.
sampling_rate:
Sampling rate of the given audio samples.
"""
assert len(audio_samples) == len(self.streams)
for stream, samples in zip(self.streams, audio_samples):
if stream.done:
assert samples.numel() == 0
continue
stream.accept_waveform(
sampling_rate=sampling_rate,
waveform=samples,
)
if samples.numel() == 0:
stream.input_finished()
def build_batch(
self,
chunk_length: int,
segment_length: int,
) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
"""
Args:
chunk_length:
Number of frames for each chunk. It equals
``segment_length + right_context_length``.
segment_length:
Number of frames for each segment.
Returns:
Return a tuple containing:
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
- active_streams, a list of active streams. We say a stream is
active when it has enough feature frames to be fed into the
encoder model.
"""
feature_list = []
stream_list = []
for stream in self.streams:
if len(stream.feature_frames) >= chunk_length:
# chunk is a list of tensors, each of which
# has shape (1, feature_dim)
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(chunk, dim=0)
feature_list.append(features)
stream_list.append(stream)
elif stream.done and len(stream.feature_frames) > 0:
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(chunk, dim=0)
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(0)),
mode="constant",
value=LOG_EPSILON,
)
feature_list.append(features)
stream_list.append(stream)
if len(feature_list) == 0:
return None, None
features = torch.stack(feature_list, dim=0)
return features, stream_list
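For illustration (not from the commit itself): a minimal sketch with made-up sizes of how build_batch() pads the final, shorter chunk of an utterance up to chunk_length with LOG_EPSILON.

import math
import torch

LOG_EPSILON = math.log(1e-10)  # same constant as in emformer.py above
chunk_length = 20              # hypothetical: segment_length + right_context_length
features = torch.rand(13, 80)  # only 13 frames left at the end of the utterance
padded = torch.nn.functional.pad(
    features,
    (0, 0, 0, chunk_length - features.size(0)),  # pad the time axis only
    mode="constant",
    value=LOG_EPSILON,
)
assert padded.shape == (20, 80)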
def greedy_search(
model: nn.Module,
fbank: kaldifeat.Fbank,
params: AttributeDict,
streams: List[FeatureExtractionStream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
"""Decode one utterance.
"""
Args:
audio_samples:
A 1-D float32 tensor of shape (num_samples,) containing the normalized
audio samples. Normalized means the samples is in the range [-1, 1].
model:
The RNN-T model.
feature_extractor:
The feature extractor.
stream:
A stream object.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
"""
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
if streams[0].decoder_out is None:
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
).squeeze(1)
# decoder_out is of shape (N, decoder_out_dim)
else:
decoder_out = torch.stack(
[stream.decoder_out for stream in streams],
dim=0,
)
assert encoder_out.ndim == 3
T = encoder_out.size(1)
for t in range(T):
current_encoder_out = encoder_out[:, t]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out)
# logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.ys.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
1
)
for k, s in enumerate(streams):
logging.info(
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
)
decoder_out_list = decoder_out.unbind(dim=0)
for i, d in enumerate(decoder_out_list):
streams[i].decoder_out = d
def process_features(
model: nn.Module,
features: torch.Tensor,
streams: List[FeatureExtractionStream],
sp: spm.SentencePieceProcessor,
) -> None:
"""Process features for each stream in parallel.
Args:
model:
The RNN-T model.
features:
A 3-D tensor of shape (N, T, C).
streams:
A list of streams of size (N,).
sp:
The BPE model.
"""
assert features.ndim == 3
assert features.size(0) == len(streams)
batch_size = features.size(0)
device = model.device
features = features.to(device)
feature_lens = torch.full(
(batch_size,),
fill_value=features.size(1),
device=device,
)
# Caution: This has a limitation, as it assumes that
# if one of the streams has an empty state, then all other
# streams also have empty states.
if streams[0].states is None:
states = None
else:
state_list = [stream.states for stream in streams]
states = stack_states(state_list)
(encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
features,
feature_lens,
states,
)
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
)
def decode_batch(
batched_samples: List[torch.Tensor],
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> List[str]:
"""
Args:
batched_samples:
A list of 1-D tensors containing the audio samples of each utterance.
model:
The RNN-T model.
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
sample_rate = params.sample_rate
frame_shift = sample_rate * fbank.opts.frame_opts.frame_shift_ms / 1000
# number of frames before subsampling
segment_length = model.encoder.segment_length
frame_shift = int(frame_shift) # number of samples
right_context_length = model.encoder.right_context_length
# Note: We add 3 here because the subsampling method ((n-1)//2-1)//2
# is not equal to n//4. We will switch to a subsampling method that
# satisfies n//4, where n is the number of input frames.
segment_length = (params.segment_length + 3) * frame_shift
# We add 3 here since the subsampling method is using
# ((len - 1) // 2 - 1) // 2
chunk_length = (segment_length + 3) + right_context_length
right_context_length = params.right_context_length * frame_shift
chunk_size = segment_length + right_context_length
batch_size = len(batched_samples)
streaming_audio_samples = StreamingAudioSamples(batched_samples)
opts = fbank.opts.frame_opts
chunk_size += (
(opts.frame_length_ms - opts.frame_shift_ms) / 1000 * sample_rate
stream_list = StreamList(
batch_size=batch_size,
context_size=params.context_size,
blank_id=params.blank_id,
)
chunk_size = int(chunk_size)
states: Optional[List[List[torch.Tensor]]] = None
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, device=device, dtype=torch.int64).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
i = 0
num_samples = audio_samples.size(0)
while i < num_samples:
# Note: The current approach of computing the features is not ideal
# since it re-computes the features for the right context.
chunk = audio_samples[i : i + chunk_size] # noqa
i += segment_length
if chunk.size(0) < chunk_size:
chunk = torch.nn.functional.pad(
chunk, pad=(0, chunk_size - chunk.size(0))
)
features = fbank(chunk)
feature_lens = torch.tensor([features.size(0)], device=params.device)
features = features.unsqueeze(0) # (1, T, C)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
features,
feature_lens,
states,
while not streaming_audio_samples.done:
samples = streaming_audio_samples.get_next()
stream_list.accept_waveform(
audio_samples=samples,
sampling_rate=params.sampling_rate,
)
for t in range(encoder_out_lens.item()):
# fmt: off
current_encoder_out = encoder_out[0:1, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y == blank_id:
continue
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
logging.info(f"Partial result:\n{sp.decode(hyp[context_size:])}")
features, active_streams = stream_list.build_batch(
chunk_length=chunk_length,
segment_length=segment_length,
)
if features is not None:
process_features(
model=model,
features=features,
streams=active_streams,
sp=sp,
)
results = []
for s in stream_list.streams:
text = sp.decode(s.hyp.ys[params.context_size :])
results.append(text)
return results
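For intuition (not from the commit itself): a tiny sketch checking the arithmetic behind the "+ 3" in decode_batch above. The subsampling used here maps T input frames to ((T - 1) // 2 - 1) // 2 output frames, which falls slightly short of T // 4.

def subsampled_len(n: int) -> int:
    return ((n - 1) // 2 - 1) // 2

for k in (4, 8, 16):
    assert subsampled_len(4 * k + 3) == k   # "+ 3" yields exactly k output frames
    assert subsampled_len(4 * k) == k - 1   # without it, we fall one frame short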
@torch.no_grad()
@ -333,30 +570,43 @@ def main():
test_clean_cuts = librispeech.test_clean_cuts()
fbank = get_feature_extractor(params)
batch_size = 3
ground_truth = []
batched_samples = []
for num, cut in enumerate(test_clean_cuts):
logging.info("Processing {num}")
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
decode_one_utterance(
audio_samples=torch.from_numpy(audio).squeeze(0).to(device),
model=model,
fbank=fbank,
params=params,
sp=sp,
)
logging.info(f"The ground truth is:\n{cut.supervisions[0].text}")
if num >= 0:
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
batched_samples.append(samples)
ground_truth.append(cut.supervisions[0].text)
if len(batched_samples) >= batch_size:
decoded_results = decode_batch(
batched_samples=batched_samples,
model=model,
params=params,
sp=sp,
)
s = "\n"
for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
s += f"hyp {i}:\n{hyp}\n"
s += f"ref {i}:\n{ref}\n\n"
logging.info(s)
batched_samples = []
ground_truth = []
# break after processing the first batch for test purposes
break
time.sleep(2) # So that you can see the decoded results
if __name__ == "__main__":
torch.manual_seed(20220410)
main()


@ -0,0 +1,116 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from beam_search import Hypothesis
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def _create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
class FeatureExtractionStream(object):
def __init__(self, context_size: int, blank_id: int = 0) -> None:
"""Context size of the RNN-T decoder model."""
self.feature_extractor = _create_streaming_feature_extractor()
self.hyp = Hypothesis(
ys=([blank_id] * context_size),
log_prob=torch.tensor([0.0]),
) # for greedy search, will extend it to beam search
# It contains a list of 1-D tensors representing the feature frames.
self.feature_frames: List[torch.Tensor] = []
self.num_fetched_frames = 0
# For the emformer model, it contains the states of each
# encoder layer.
self.states: Optional[List[List[torch.Tensor]]] = None
# For the RNN-T decoder, it contains the decoder output
# corresponding to the decoder input self.hyp.ys[-context_size:]
# Its shape is (decoder_out_dim,)
self.decoder_out: Optional[torch.Tensor] = None
# After calling `self.input_finished()`, we set this flag to True
self._done = False
def accept_waveform(
self,
sampling_rate: float,
waveform: torch.Tensor,
) -> None:
"""Feed audio samples to the feature extractor and compute features
if there are enough samples available.
Caution:
The range of the audio samples should match the one used in the
training. That is, if you use the range [-1, 1] in the training, then
the input audio samples should also be normalized to [-1, 1].
Args:
sampling_rate:
The sampling rate of the input audio samples. It is used for a
sanity check to ensure that the input sampling rate equals the one
used in the extractor. If they are not equal, no resampling
will be performed; instead, an error will be thrown.
waveform:
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
"""
self.feature_extractor.accept_waveform(
sampling_rate=sampling_rate,
waveform=waveform,
)
self._fetch_frames()
def input_finished(self) -> None:
"""Signal that no more audio samples available and the feature
extractor should flush the buffered samples to compute frames.
"""
self.feature_extractor.input_finished()
self._fetch_frames()
self._done = True
@property
def done(self) -> bool:
"""Return True if `self.input_finished()` has been invoked"""
return self._done
def _fetch_frames(self) -> None:
"""Fetch frames from the feature extractor"""
while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
self.feature_frames.append(frame)
self.num_fetched_frames += 1


@ -25,7 +25,7 @@ To run this file, do:
import warnings
import torch
from emformer import Emformer
from emformer import Emformer, stack_states, unstack_states
def test_emformer():
@ -65,8 +65,41 @@ def test_emformer():
print(f"Number of encoder parameters: {num_param}")
def test_emformer_streaming_forward():
N = 3
C = 80
output_dim = 500
encoder = Emformer(
num_features=C,
output_dim=output_dim,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=20,
segment_length=16,
left_context_length=120,
right_context_length=4,
vgg_frontend=False,
)
x = torch.rand(N, 23, C)
x_lens = torch.full((N,), 23)
y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens)
state_list = unstack_states(states)
states2 = stack_states(state_list)
for ss, ss2 in zip(states, states2):
for s, s2 in zip(ss, ss2):
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
@torch.no_grad()
def main():
test_emformer()
# test_emformer()
test_emformer_streaming_forward()
if __name__ == "__main__":


@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_emformer/test_streaming_feature_extractor.py
"""
import torch
from streaming_feature_extractor import FeatureExtractionStream
def test_streaming_feature_extractor():
stream = FeatureExtractionStream(context_size=2, blank_id=0)
samples = torch.rand(16000)
start = 0
while True:
n = torch.randint(50, 500, (1,)).item()
end = start + n
this_chunk = samples[start:end]
start = end
if len(this_chunk) == 0:
break
stream.accept_waveform(sampling_rate=16000, waveform=this_chunk)
print(len(stream.feature_frames))
stream.input_finished()
print(len(stream.feature_frames))
def main():
test_streaming_feature_extractor()
if __name__ == "__main__":
main()