Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)

Add non-streaming ASR server implementation

commit 4b59573473
parent 2900ed8f8f
bin/offline_client.py (executable file, 57 lines)
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A client for offline ASR recognition.
"""
import torch
import torchaudio
import websockets
import asyncio


async def main():
    test_wavs = [
        "/ceph-fj/fangjun/open-source-2/icefall-models/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav",
        "/ceph-fj/fangjun/open-source-2/icefall-models/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav",
        "/ceph-fj/fangjun/open-source-2/icefall-models/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav",
    ]
    async with websockets.connect("ws://localhost:6006") as websocket:
        while True:
            for test_wav in test_wavs:
                print(f"Sending {test_wav}")
                wave, sample_rate = torchaudio.load(test_wav)
                wave = wave.squeeze(0)
                num_bytes = wave.numel() * wave.element_size()
                print(f"Sending {num_bytes}, {wave.shape}")
                await websocket.send(
                    (num_bytes).to_bytes(8, "big", signed=True)
                )

                frame_size = 1048576 // 4  # max payload is 1MB
                num_sent_samples = 0
                start = 0
                while start < wave.numel():
                    end = start + frame_size
                    await websocket.send(wave.numpy().data[start:end])
                    start = end
                decoding_results = await websocket.recv()
                print(decoding_results)


if __name__ == "__main__":
    asyncio.run(main())
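The client frames each utterance as an 8-byte big-endian byte count, followed by the raw float32 samples split across messages of at most 1 MB. Below is a minimal sketch of that framing as a standalone helper; the helper name frame_messages and the use of numpy().tobytes() are illustrative assumptions, not part of this commit.

from typing import List

import torch


def frame_messages(wave: torch.Tensor, max_payload: int = 1048576) -> List[bytes]:
    """Split a 1-D float32 tensor into the messages the server expects:
    an 8-byte big-endian byte-count header, then the raw samples in chunks
    of at most max_payload bytes."""
    assert wave.dtype == torch.float32 and wave.dim() == 1
    payload = wave.numpy().tobytes()
    messages = [len(payload).to_bytes(8, "big", signed=True)]
    for start in range(0, len(payload), max_payload):
        messages.append(payload[start : start + max_payload])
    return messages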
bin/offline_server.py (executable file, 339 lines)
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A server for offline ASR recognition. Offline means you send all the content
of the audio at once for recognition. It supports multiple clients sending
requests at the same time.

TODO(fangjun): Run CPU-bound tasks such as neural network computation and
decoding in C++ with the global interpreter lock (GIL) being released.
"""

import asyncio
import logging
import math
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import kaldifeat
import sentencepiece as spm
import torch
import websockets
from beam_search import greedy_search_batch
from torch.nn.utils.rnn import pad_sequence

from icefall.utils import setup_logger

LOG_EPS = math.log(1e-10)


def run_model_and_do_greedy_search(
    model: torch.jit.ScriptModule,
    features: List[torch.Tensor],
) -> List[List[int]]:
    """Run the RNN-T model on the given features and use greedy search
    to decode the output of the model.

    TODO:
      Split this function into two parts: one for computing the encoder output
      and another for decoding.

    TODO:
      Move it to C++.

    Args:
      model:
        The RNN-T model.
      features:
        A list of 2-D tensors. Each entry is of shape (num_frames, feature_dim).
    Returns:
      Return a list of token ID lists, one per utterance.
    """
    feature_lengths = torch.tensor([f.size(0) for f in features])
    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=LOG_EPS,
    )

    device = next(model.parameters()).device
    features = features.to(device)
    feature_lengths = feature_lengths.to(device)

    encoder_out, encoder_out_lens = model.encoder(features, feature_lengths)

    hyp_tokens = greedy_search_batch(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
    )
    return hyp_tokens


class OfflineServer:
    def __init__(
        self,
        nn_model_filename: str,
        bpe_model_filename: str,
        num_device: int,
        feature_extractor_pool_size: int = 3,
        nn_pool_size: int = -1,
    ):
        """
        Args:
          nn_model_filename:
            Path to the torch script model.
          bpe_model_filename:
            Path to the BPE model used to convert token IDs to words.
          num_device:
            If 0, use CPU for neural network computation and decoding.
            If positive, it means the number of GPUs to use for NN computation
            and decoding. For each device, there will be a corresponding
            torchscript model. We assume the available device IDs are
            0, 1, ..., num_device - 1. You can use the environment variable
            CUDA_VISIBLE_DEVICES to achieve this.
          feature_extractor_pool_size:
            Number of threads to create for the feature extractor thread pool.
          nn_pool_size:
            Number of threads for the thread pool that is used for NN
            computation and decoding.
        """
        self.feature_extractor = self._build_feature_extractor()
        self.nn_models = self._build_nn_model(nn_model_filename, num_device)

        assert nn_pool_size > 0

        self.feature_extractor_pool = ThreadPoolExecutor(
            max_workers=feature_extractor_pool_size
        )
        self.nn_pool = ThreadPoolExecutor(max_workers=nn_pool_size)

        self.feature_queue = asyncio.Queue()

        self.sp = spm.SentencePieceProcessor()
        self.sp.load(bpe_model_filename)

        self.counter = 0

    def _build_feature_extractor(self):
        """Build a fbank feature extractor for extracting features.

        TODO:
          Pass the options as arguments.
        """
        opts = kaldifeat.FbankOptions()
        opts.device = "cpu"  # Note: It also supports CUDA, e.g., "cuda:0"
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        opts.frame_opts.samp_freq = 16000
        opts.mel_opts.num_bins = 80

        fbank = kaldifeat.Fbank(opts)

        return fbank

    def _build_nn_model(
        self, nn_model_filename: str, num_device: int
    ) -> List[torch.jit.ScriptModule]:
        """Build a torch script model for each given device.

        Args:
          nn_model_filename:
            The path to the torch script model.
          num_device:
            Number of devices to use for NN computation and decoding.
            If it is 0, only the CPU is used and a single model on CPU is
            returned. If it is positive, it creates a model for each device
            and returns them.
        Returns:
          Return a list of torch script models.
        """
        model = torch.jit.load(nn_model_filename, map_location="cpu")
        model.eval()
        if num_device < 1:
            return [model]

        ans = []
        for i in range(num_device):
            device = torch.device("cuda", i)
            ans.append(model.to(device))

        return ans

    async def loop(self, port: int):
        """Start the feature consumer task and serve websocket connections
        on the given port."""
        logging.info("started")
        asyncio.create_task(self.feature_consumer_task())
        # asyncio.create_task(self.feature_consumer_task())
        # asyncio.create_task(self.feature_consumer_task())
        async with websockets.serve(self.handle_connection, "", port):
            await asyncio.Future()  # run forever

    async def recv_audio_samples(
        self,
        socket: websockets.WebSocketServerProtocol,
    ) -> torch.Tensor:
        """Receive a tensor from the client.

        The data from the client has the following format:

          - a header of 8 bytes, containing the number of bytes of the tensor.
            The header is in big endian format.
          - a binary representation of the 1-D torch.float32 tensor.

        Args:
          socket:
            The socket for communicating with the client.
        Returns:
          Return a 1-D torch.float32 tensor.
        """
        expected_num_bytes = None
        received = b""
        async for message in socket:
            if expected_num_bytes is None:
                assert len(message) >= 8, (len(message), message)
                expected_num_bytes = int.from_bytes(
                    message[:8], "big", signed=True
                )
                received += message[8:]
                if len(received) == expected_num_bytes:
                    break
            else:
                received += message
                if len(received) == expected_num_bytes:
                    break
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # PyTorch warns that the underlying buffer is not writable.
            # We ignore it here as we are not going to write it anyway.
            return torch.frombuffer(received, dtype=torch.float32)

    async def feature_consumer_task(self):
        sleep_time = 20 / 1000.0  # wait for 20 ms
        batch_size = 5
        while True:
            if self.feature_queue.empty():
                logging.info("empty")
                await asyncio.sleep(sleep_time)
                continue
            batch = []
            try:
                while len(batch) < batch_size:
                    item = self.feature_queue.get_nowait()
                    batch.append(item)
                    self.feature_queue.task_done()
            except asyncio.QueueEmpty:
                pass
            logging.info(f"batch size: {len(batch)}")

            feature_list = [b[0] for b in batch]

            loop = asyncio.get_running_loop()
            self.counter = (self.counter + 1) % len(self.nn_models)
            model = self.nn_models[self.counter]

            hyp_tokens = await loop.run_in_executor(
                self.nn_pool,
                run_model_and_do_greedy_search,
                model,
                feature_list,
            )
            logging.info(f"batch_size: {len(hyp_tokens)}")

            for i, hyp in enumerate(hyp_tokens):
                future = batch[i][1]
                future.set_result(hyp)

    async def compute_features(self, samples: torch.Tensor) -> torch.Tensor:
        """Compute the fbank features for the given audio samples.

        Args:
          samples:
            A 1-D torch.float32 tensor containing the audio samples. Its
            sampling rate should be the one expected by the feature
            extractor. Also, its range should match the one used in
            training.
        Returns:
          Return a 2-D tensor of shape (num_frames, feature_dim) containing
          the features.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.feature_extractor_pool,
            self.feature_extractor,  # it releases the GIL
            samples,
        )

    async def compute_encoder_out(
        self,
        features: torch.Tensor,
    ) -> List[int]:
        """Submit the features to the batch queue and wait for the decoding
        result. Despite its name, this method currently triggers the full
        pipeline (encoder plus greedy search) via feature_consumer_task.

        Args:
          features:
            A 2-D tensor of shape (num_frames, feature_dim).
        Returns:
          Return a list of token IDs obtained from greedy search.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        await self.feature_queue.put((features, future))
        await future
        return future.result()

    async def handle_connection(
        self,
        socket: websockets.WebSocketServerProtocol,
    ):
        """Receive audio samples from the client, process them, and send
        the decoding results back to the client.

        Args:
          socket:
            The socket for communicating with the client.
        """
        logging.info(f"Connected: {socket.remote_address}")
        while True:
            samples = await self.recv_audio_samples(socket)
            features = await self.compute_features(samples)
            hyp = await self.compute_encoder_out(features)
            result = self.sp.decode(hyp)
            logging.info(f"hyp: {result}")
            await socket.send(result)


@torch.no_grad()
def main():
    nn_model_filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/pruned_transducer_stateless3/exp/cpu_jit.pt"  # noqa
    bpe_model_filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/data/lang_bpe_500/bpe.model"
    port = 6006  # the server will listen on this port
    offline_server = OfflineServer(
        nn_model_filename=nn_model_filename,
        bpe_model_filename=bpe_model_filename,
        num_device=2,
        feature_extractor_pool_size=5,
        nn_pool_size=5,
    )
    asyncio.run(offline_server.loop(port))


if __name__ == "__main__":
    torch.manual_seed(20220519)
    setup_logger("./log")
    main()
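The server's request path pairs an asyncio.Queue with one future per request: handle_connection submits (features, future) through compute_encoder_out, and feature_consumer_task drains up to batch_size items, runs the model once for the whole batch in a thread pool, and resolves each future with its result. Below is a stripped-down sketch of just that pattern; process_batch is a stand-in for run_model_and_do_greedy_search, and the 20 ms poll interval mirrors the server's sleep_time.

import asyncio
from typing import List


def process_batch(items: List[str]) -> List[str]:
    # Stand-in for the batched model call; runs in a worker thread.
    return [item.upper() for item in items]


async def consumer(queue: asyncio.Queue, batch_size: int = 5):
    while True:
        if queue.empty():
            await asyncio.sleep(0.02)  # wait 20 ms, as in the server
            continue
        batch = []
        try:
            while len(batch) < batch_size:
                batch.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            pass
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            None, process_batch, [item for item, _ in batch]
        )
        for (_, future), result in zip(batch, results):
            future.set_result(result)


async def request(queue: asyncio.Queue, item: str) -> str:
    # One request: enqueue the work together with a future and await it.
    future = asyncio.get_running_loop().create_future()
    await queue.put((item, future))
    return await future


async def main():
    queue = asyncio.Queue()
    asyncio.create_task(consumer(queue))
    print(await asyncio.gather(*(request(queue, w) for w in ["a", "b", "c"])))


asyncio.run(main())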
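To try the two scripts together (an assumed workflow; the commit itself does not document one), start bin/offline_server.py first and then run bin/offline_client.py in another terminal: the client connects to ws://localhost:6006, which matches the port hard-coded in the server's main(). Since main() passes num_device=2, the process expects two visible GPUs; per the __init__ docstring, CUDA_VISIBLE_DEVICES can be used to control which devices those are, or num_device can be set to 0 to stay on the CPU.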