Add streaming feature extractor.

Fangjun Kuang 2022-04-10 23:07:41 +08:00
parent 189ca555b1
commit f16b759397
6 changed files with 515 additions and 94 deletions

View File

@ -14,3 +14,7 @@ exclude =
.git,
**/data/**,
icefall/shared/make_kn_lm.py
ignore =
# E203 whitespace before ':'
E203,
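
For context, a tiny illustrative snippet of the kind of code that triggers E203; ignoring it keeps flake8 from conflicting with formatters such as black, which put a space before the colon in compound slices (that motivation is an inference, not stated in the config):

x = list(range(10))
offset = 2
# black formats compound slice bounds with surrounding spaces, e.g.:
y = x[offset + 1 : len(x) - 1]  # flake8 reports E203 ("whitespace before ':'") here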

View File

@ -63,11 +63,11 @@ class Emformer(EncoderInterface):
num_encoder_layers:
Number of encoder layers.
segment_length:
Number of frames per segment.
Number of frames per segment before subsampling.
left_context_length:
Number of frames in the left context.
Number of frames in the left context before subsampling.
right_context_length:
Number of frames in the right context.
Number of frames in the right context before subsampling.
max_memory_size:
TODO.
dropout:
@ -94,6 +94,7 @@ class Emformer(EncoderInterface):
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.segment_length = segment_length
self.right_context_length = right_context_length
assert right_context_length % subsampling_factor == 0
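
A small illustrative sketch (not part of the diff) of the frame bookkeeping these docstrings describe; frames_needed_per_chunk is a hypothetical helper and the example values 16 and 4 are assumptions:

def frames_needed_per_chunk(segment_length: int, right_context_length: int) -> int:
    # Both counts are in feature frames *before* subsampling, as documented above.
    # Conv2dSubsampling maps n input frames to ((n - 1) // 2 - 1) // 2 output
    # frames, which is slightly less than n // 4, so segment_length + 3 input
    # frames are needed to obtain segment_length // 4 frames per segment.
    return (segment_length + 3) + right_context_length

# e.g. segment_length=16 and right_context_length=4 (both pre-subsampling)
# give a chunk of 23 feature frames; the model then advances by 16 frames.
print(frames_needed_per_chunk(16, 4))  # 23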

View File

@ -0,0 +1,184 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_emformer/export.py \
--exp-dir ./transducer_emformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_emformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer_emformer/decode.py \
--exp-dir ./transducer_emformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1000 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Torchscript support will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
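
A minimal sketch (not part of the commit) of consuming the generated exp_dir/pretrained.pt directly, instead of the epoch-9999.pt symlink shown in the usage string; it assumes get_params/get_transducer_model from train.py and that the parameters are filled in the same way main() does:

import torch
from train import get_params, get_transducer_model

params = get_params()
# Fill in params.vocab_size, params.blank_id, and the model arguments exactly
# as main() above does (from the BPE model and add_model_arguments).
model = get_transducer_model(params)

ckpt = torch.load("transducer_emformer/exp/pretrained.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])  # saved above as {"model": model.state_dict()}
model.eval()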

View File

@ -20,14 +20,14 @@ import argparse
import logging
import time
from pathlib import Path
from typing import List, Optional
import kaldifeat
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from emformer import LOG_EPSILON
from streaming_feature_extractor import Stream
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@ -147,10 +147,10 @@ def get_parser():
)
parser.add_argument(
"--sample-rate",
type=int,
"--sampling-rate",
type=float,
default=16000,
help="The sample rate of the input sound file",
help="Sample rate of the audio",
)
add_model_arguments(parser)
@ -158,32 +158,159 @@ def get_parser():
return parser
def get_feature_extractor(
params: AttributeDict,
) -> kaldifeat.Fbank:
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = params.device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = True
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
def greedy_search(
model: nn.Module,
stream: Stream,
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
"""
Args:
model:
The RNN-T model.
stream:
A stream object.
encoder_out:
A 2-D tensor of shape (T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
"""
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
return kaldifeat.Fbank(opts)
if stream.decoder_out is None:
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:]],
device=device,
dtype=torch.int64,
)
stream.decoder_out = model.decoder(
decoder_input,
need_pad=False,
).unsqueeze(1)
# stream.decoder_out is of shape (1, 1, decoder_out_dim)
assert encoder_out.ndim == 2
T = encoder_out.size(0)
for t in range(T):
current_encoder_out = encoder_out[t].reshape(
1, 1, 1, encoder_out.size(-1)
)
logits = model.joiner(current_encoder_out, stream.decoder_out)
# logits is of shape (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y == blank_id:
continue
stream.hyp.ys.append(y)
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:]],
device=device,
dtype=torch.int64,
)
stream.decoder_out = model.decoder(
decoder_input,
need_pad=False,
).unsqueeze(1)
logging.info(
f"Partial result:\n{sp.decode(stream.hyp.ys[context_size:])}"
)
def process_feature_frames(
model: nn.Module,
stream: Stream,
sp: spm.SentencePieceProcessor,
):
"""Process the feature frames contained in ``stream.feature_frames``.
Args:
model:
The RNN-T model.
stream:
The stream corresponding to the input audio samples.
sp:
The BPE model.
"""
# number of frames before subsampling
segment_length = model.encoder.segment_length
right_context_length = model.encoder.right_context_length
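# The +3 below compensates for the Conv2dSubsampling in the encoder, which
# maps n input frames to ((n - 1) // 2 - 1) // 2 output frames rather than
# n // 4.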
chunk_length = (segment_length + 3) + right_context_length
device = model.device
while len(stream.feature_frames) >= chunk_length:
# a list of tensors, each of shape (1, feature_dim)
this_chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(this_chunk, dim=0).to(device) # (T, feature_dim)
features = features.unsqueeze(0) # (1, T, feature_dim)
feature_lens = torch.tensor([features.size(1)], device=device)
(
encoder_out,
encoder_out_lens,
stream.states,
) = model.encoder.streaming_forward(
features,
feature_lens,
stream.states,
)
greedy_search(
model=model,
stream=stream,
encoder_out=encoder_out[0],
sp=sp,
)
if stream.feature_extractor.is_last_frame(stream.num_fetched_frames - 1):
assert len(stream.feature_frames) < chunk_length
if len(stream.feature_frames) > 0:
this_chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(this_chunk, dim=0) # (T, feature_dim)
features = features.to(device).unsqueeze(0) # (1, T, feature_dim)
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(1)),
value=LOG_EPSILON,
)
feature_lens = torch.tensor([features.size(1)], device=device)
(
encoder_out,
encoder_out_lens,
stream.states,
) = model.encoder.streaming_forward(
features,
feature_lens,
stream.states,
)
greedy_search(
model=model,
stream=stream,
encoder_out=encoder_out[0],
sp=sp,
)
def decode_one_utterance(
audio_samples: torch.Tensor,
model: nn.Module,
fbank: kaldifeat.Fbank,
stream: Stream,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
):
"""Decode one utterance.
Args:
audio_samples:
A 1-D float32 tensor of shape (num_samples,) containing the normalized
audio samples. Normalized means the samples is in the range [-1, 1].
A 1-D float32 tensor of shape (num_samples,) containing the
audio samples.
model:
The RNN-T model.
feature_extractor:
@ -193,80 +320,23 @@ def decode_one_utterance(
sp:
The BPE model.
"""
sample_rate = params.sample_rate
frame_shift = sample_rate * fbank.opts.frame_opts.frame_shift_ms / 1000
frame_shift = int(frame_shift) # number of samples
# Note: We add 3 here because the subsampling method ((n-1)//2-1))//2
# is not equal to n//4. We will switch to a subsampling method that
# satisfies n//4, where n is the number of input frames.
segment_length = (params.segment_length + 3) * frame_shift
right_context_length = params.right_context_length * frame_shift
chunk_size = segment_length + right_context_length
opts = fbank.opts.frame_opts
chunk_size += (
(opts.frame_length_ms - opts.frame_shift_ms) / 1000 * sample_rate
)
chunk_size = int(chunk_size)
states: Optional[List[List[torch.Tensor]]] = None
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, device=device, dtype=torch.int64).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
i = 0
num_samples = audio_samples.size(0)
while i < num_samples:
# Note: The current approach of computing the features is not ideal
# since it re-computes the features for the right context.
chunk = audio_samples[i : i + chunk_size] # noqa
i += segment_length
if chunk.size(0) < chunk_size:
chunk = torch.nn.functional.pad(
chunk, pad=(0, chunk_size - chunk.size(0))
)
features = fbank(chunk)
feature_lens = torch.tensor([features.size(0)], device=params.device)
# Simulate streaming.
this_chunk_num_samples = torch.randint(2000, 5000, (1,)).item()
features = features.unsqueeze(0) # (1, T, C)
this_chunk_samples = audio_samples[i : (i + this_chunk_num_samples)]
i += this_chunk_num_samples
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
features,
feature_lens,
states,
stream.accept_waveform(
sampling_rate=params.sampling_rate,
waveform=this_chunk_samples,
)
for t in range(encoder_out_lens.item()):
# fmt: off
current_encoder_out = encoder_out[0:1, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y == blank_id:
continue
process_feature_frames(model=model, stream=stream, sp=sp)
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
logging.info(f"Partial result:\n{sp.decode(hyp[context_size:])}")
stream.input_finished()
process_feature_frames(model=model, stream=stream, sp=sp)
@torch.no_grad()
@ -333,10 +403,12 @@ def main():
test_clean_cuts = librispeech.test_clean_cuts()
fbank = get_feature_extractor(params)
for num, cut in enumerate(test_clean_cuts):
logging.info("Processing {num}")
logging.info(f"Processing {num}")
stream = Stream(
context_size=model.decoder.context_size,
blank_id=model.decoder.blank_id,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
@ -347,16 +419,17 @@ def main():
decode_one_utterance(
audio_samples=torch.from_numpy(audio).squeeze(0).to(device),
model=model,
fbank=fbank,
stream=stream,
params=params,
sp=sp,
)
logging.info(f"The ground truth is:\n{cut.supervisions[0].text}")
if num >= 0:
if num >= 2:
break
time.sleep(2) # So that you can see the decoded results
if __name__ == "__main__":
torch.manual_seed(20220410)
main()
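
A condensed, illustrative sketch of the streaming pipeline these functions implement, assuming model, params, and sp are prepared as in main() and audio_samples is a 1-D float32 tensor of one utterance; it also shows how to read the final hypothesis out of the stream, which the code above only logs as partial results:

stream = Stream(
    context_size=model.decoder.context_size,
    blank_id=model.decoder.blank_id,
)

i = 0
while i < audio_samples.numel():
    n = torch.randint(2000, 5000, (1,)).item()  # simulate packets arriving
    stream.accept_waveform(
        sampling_rate=params.sampling_rate,
        waveform=audio_samples[i : i + n],
    )
    i += n
    process_feature_frames(model=model, stream=stream, sp=sp)

stream.input_finished()  # flush the remaining buffered samples
process_feature_frames(model=model, stream=stream, sp=sp)

context_size = model.decoder.context_size
final_text = sp.decode(stream.hyp.ys[context_size:])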

View File

@ -0,0 +1,106 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from beam_search import Hypothesis
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def _create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, it returns an fbank feature extractor with fixed options.
In the future, we will support passing in the options from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
class Stream(object):
def __init__(self, context_size: int, blank_id: int = 0) -> None:
"""Context size of the RNN-T decoder model."""
self.feature_extractor = _create_streaming_feature_extractr()
self.hyp = Hypothesis(
ys=([blank_id] * context_size),
log_prob=torch.tensor([0.0]),
) # for greedy search, will extend it to beam search
# It contains a list of 2-D tensors, each of shape (1, feature_dim),
# representing the feature frames.
self.feature_frames: List[torch.Tensor] = []
self.num_fetched_frames = 0
# For the emformer model, it contains the states of each
# encoder layer.
self.states: Optional[List[List[torch.Tensor]]] = None
# For the RNN-T decoder, it contains the decoder output
# corresponding to the decoder input self.hyp.ys[-context_size:]
self.decoder_out: Optional[torch.Tensor] = None
def accept_waveform(
self,
sampling_rate: float,
waveform: torch.Tensor,
) -> None:
"""Feed audio samples to the feature extractor and compute features
if there are enough samples available.
Caution:
The range of the audio samples should match the one used in training.
That is, if the range [-1, 1] was used in training, then the input
audio samples should also be normalized to [-1, 1].
Args:
sampling_rate:
The sampling rate of the input audio samples. It is used as a sanity
check to ensure that the input sampling rate equals the one used in
the extractor. If they are not equal, no resampling will be
performed; instead, an error will be thrown.
waveform:
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
"""
self.feature_extractor.accept_waveform(
sampling_rate=sampling_rate,
waveform=waveform,
)
self._fetch_frames()
def input_finished(self) -> None:
"""Signal that no more audio samples available and the feature
extractor should flush the buffered samples to compute frames.
"""
self.feature_extractor.input_finished()
self._fetch_frames()
def _fetch_frames(self) -> None:
"""Fetch frames from the feature extractor"""
while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
self.feature_frames.append(frame)
self.num_fetched_frames += 1
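
A short usage sketch for Stream (illustrative, not part of the commit), highlighting the Caution in accept_waveform: audio read as int16 must be scaled to the float range used in training before being fed in; the 1/32768 scale is an assumption matching [-1, 1] normalization:

import torch
from streaming_feature_extractor import Stream

stream = Stream(context_size=2, blank_id=0)

int16_chunk = torch.randint(-32768, 32767, (1600,), dtype=torch.int16)
float_chunk = int16_chunk.to(torch.float32) / 32768.0  # scale to roughly [-1, 1]

stream.accept_waveform(sampling_rate=16000, waveform=float_chunk)
print(len(stream.feature_frames), "frames ready so far")

stream.input_finished()  # flush buffered samples into final frames
print(len(stream.feature_frames), "frames after flushing")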

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_emformer/test_streaming_feature_extractor.py
"""
import torch
from streaming_feature_extractor import Stream
def test_streaming_feature_extractor():
stream = Stream(context_size=2, blank_id=0)
samples = torch.rand(16000)
start = 0
while True:
n = torch.randint(50, 500, (1,)).item()
end = start + n
this_chunk = samples[start:end]
start = end
if len(this_chunk) == 0:
break
stream.accept_waveform(sampling_rate=16000, waveform=this_chunk)
print(len(stream.feature_frames))
stream.input_finished()
print(len(stream.feature_frames))
def main():
test_streaming_feature_extractor()
if __name__ == "__main__":
main()