From 3fd4f1d13f11292b494ba968ea865669dbe5ed5a Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 6 May 2022 18:21:08 +0800
Subject: [PATCH] First attempt to add WEB interface for emformer model.

---
 .flake8                                            |   1 +
 .../ASR/transducer_emformer/client/index.html      |  29 ++-
 .../ASR/transducer_emformer/client/main.js         |  60 ++++++
 .../ASR/transducer_emformer/server.py              | 182 ++++++++++++++++++
 .../transducer_emformer/streaming_decode.py        |   6 +-
 .../ASR/transducer_emformer/train.py               |   4 +-
 6 files changed, 273 insertions(+), 9 deletions(-)
 create mode 100644 egs/librispeech/ASR/transducer_emformer/client/main.js
 create mode 100755 egs/librispeech/ASR/transducer_emformer/server.py

diff --git a/.flake8 b/.flake8
index 9f54c46ee..982c229d4 100644
--- a/.flake8
+++ b/.flake8
@@ -17,6 +17,7 @@ exclude =
   .git,
   **/data/**,
   icefall/shared/make_kn_lm.py,
+  egs/librispeech/ASR/transducer_emformer/train.py,
   icefall/__init__.py
 
 ignore =

diff --git a/egs/librispeech/ASR/transducer_emformer/client/index.html b/egs/librispeech/ASR/transducer_emformer/client/index.html
index 85a21df49..5b6baa001 100644
--- a/egs/librispeech/ASR/transducer_emformer/client/index.html
+++ b/egs/librispeech/ASR/transducer_emformer/client/index.html
@@ -2,35 +2,50 @@
[hunk body lost to HTML-tag stripping; a hedged reconstruction sketch follows]
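
[What survives of the original hunk: a stylesheet <link> carrying
crossorigin="anonymous", the title/heading "Hello next-gen Kaldi", a file
chooser, a results area, and a <script> include. The sketch below is a
reconstruction, not the author's markup: the ids "file" and "results" and the
handler names come from main.js; the Bootstrap URL, the onload hook, and the
layout are assumptions.]

    <!doctype html>
    <html lang="en">
      <head>
        <meta charset="utf-8">
        <!-- crossorigin="anonymous" survives from the original hunk;
             the exact stylesheet URL is an assumption -->
        <link rel="stylesheet"
              href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css"
              crossorigin="anonymous">
        <title>Hello next-gen Kaldi</title>
      </head>
      <!-- initWebSocket() (defined in main.js) enables the file chooser once
           the WebSocket connection to server.py is open -->
      <body onload="initWebSocket()">
        <h1>Hello next-gen Kaldi</h1>
        <input type="file" id="file" onchange="onFileChange()" disabled>
        <div id="results"></div>
        <script src="main.js"></script>
      </body>
    </html>
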
diff --git a/egs/librispeech/ASR/transducer_emformer/client/main.js b/egs/librispeech/ASR/transducer_emformer/client/main.js
new file mode 100644
index 000000000..a25eb5330
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/main.js
@@ -0,0 +1,60 @@
+/**
+References
+https://developer.mozilla.org/en-US/docs/Web/API/FileList
+https://developer.mozilla.org/en-US/docs/Web/API/FileReader
+https://javascript.info/arraybuffer-binary-arrays
+https://developer.mozilla.org/zh-CN/docs/Web/API/WebSocket
+https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send
+*/
+
+var socket;
+
+// Expected to be invoked by index.html on page load; it also reconnects
+// automatically whenever the connection is closed.
+function initWebSocket() {
+  socket = new WebSocket("ws://localhost:6008/");
+
+  // Connection opened: enable the file chooser
+  socket.addEventListener(
+      'open',
+      function(event) { document.getElementById('file').disabled = false; });
+
+  // Connection closed: disable the file chooser and reconnect
+  socket.addEventListener('close', function(event) {
+    document.getElementById('file').disabled = true;
+    initWebSocket();
+  });
+
+  // Listen for messages: the server sends back the decoding result so far
+  socket.addEventListener('message', function(event) {
+    document.getElementById('results').innerHTML = event.data;
+    console.log('Received message: ', event.data);
+  });
+}
+
+function onFileChange() {
+  var files = document.getElementById("file").files;
+
+  if (files.length === 0) {
+    console.log('No file selected');
+    return;
+  }
+
+  console.log('files: ' + files);
+
+  const file = files[0];
+  console.log(file);
+  console.log('file.name ' + file.name);
+  console.log('file.type ' + file.type);
+  console.log('file.size ' + file.size);
+
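+  // server.py reads the payload with torch.frombuffer(..., dtype=torch.int16),
+  // i.e., as raw 16-bit PCM samples, so the WAV container header must be
+  // stripped on the client before the bytes are sent.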
+  let reader = new FileReader();
+  reader.onload = function() {
+    let view = new Int16Array(reader.result);
+    console.log('bytes: ' + view.byteLength);
+    // We assume the input file is a 16-bit PCM wav file with a canonical
+    // 44-byte header.
+    // TODO: add some checks here.
+    // Note: Int16Array.subarray() counts 16-bit elements rather than bytes,
+    // so the 44-byte header is skipped by passing a byte offset to the
+    // Int16Array constructor instead.
+    let body = new Int16Array(reader.result, 44);
+    socket.send(body);
+    socket.send(JSON.stringify({'eof' : 1}));
+  };
+
+  reader.readAsArrayBuffer(file);
+}

diff --git a/egs/librispeech/ASR/transducer_emformer/server.py b/egs/librispeech/ASR/transducer_emformer/server.py
new file mode 100755
index 000000000..35f66f60f
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/server.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python3
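+"""
+A hedged usage sketch (the flag values below are illustrative; the real
+options come from get_parser() in streaming_decode.py):
+
+    ./transducer_emformer/server.py \
+        --epoch 39 \
+        --avg 10 \
+        --exp-dir transducer_emformer/exp \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+The server listens on ws://localhost:6008. A WebSocket client such as
+client/index.html sends raw 16-bit PCM samples followed by a JSON text
+message like {"eof": 1}; the server replies with the decoded text so far.
+"""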
+import asyncio
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import websockets
+from streaming_decode import StreamList, get_parser, process_features
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import setup_logger
+
+g_params = None
+g_model = None
+g_sp = None
+
+
+def build_stream_list():
+    batch_size = 1  # will be made configurable later
+
+    stream_list = StreamList(
+        batch_size=batch_size,
+        context_size=g_params.context_size,
+        decoding_method=g_params.decoding_method,
+    )
+    return stream_list
+
+
+async def echo(websocket):
+    logging.info(f"connected: {websocket.remote_address}")
+
+    stream_list = build_stream_list()
+
+    # number of frames before subsampling
+    segment_length = g_model.encoder.segment_length
+
+    right_context_length = g_model.encoder.right_context_length
+
+    # We add 3 here since the subsampling method is using
+    # ((len - 1) // 2 - 1) // 2
+    chunk_length = (segment_length + 3) + right_context_length
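+    # Worked example (hedged; the real values come from the trained model):
+    # with segment_length = 16 and right_context_length = 4,
+    # chunk_length = (16 + 3) + 4 = 23. The extra 3 frames make the
+    # subsampled segment length exact: ((19 - 1) // 2 - 1) // 2 = 4 = 16 / 4
+    # output frames for the 19 segment frames, while the remaining 4 frames
+    # provide the right (future) context.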
+
+    async for message in websocket:
+        if isinstance(message, bytes):
+            # 16-bit PCM samples from the client, normalized to [-1, 1)
+            samples = torch.frombuffer(message, dtype=torch.int16)
+            samples = samples.to(torch.float32) / 32768
+            stream_list.accept_waveform(
+                audio_samples=[samples],
+                sampling_rate=g_params.sampling_rate,
+            )
+
+            while True:
+                features, active_streams = stream_list.build_batch(
+                    chunk_length=chunk_length,
+                    segment_length=segment_length,
+                )
+
+                if features is not None:
+                    process_features(
+                        model=g_model,
+                        features=features,
+                        streams=active_streams,
+                        params=g_params,
+                        sp=g_sp,
+                    )
+                    results = []
+                    for stream in stream_list.streams:
+                        text = g_sp.decode(stream.decoding_result())
+                        results.append(text)
+                    await websocket.send(results[0])
+                else:
+                    break
+        elif isinstance(message, str):
+            # The client signals end of input with a JSON text message,
+            # e.g., {"eof": 1}
+            stream_list[0].input_finished()
+            while True:
+                features, active_streams = stream_list.build_batch(
+                    chunk_length=chunk_length,
+                    segment_length=segment_length,
+                )
+
+                if features is not None:
+                    process_features(
+                        model=g_model,
+                        features=features,
+                        streams=active_streams,
+                        params=g_params,
+                        sp=g_sp,
+                    )
+                else:
+                    break
+
+            results = []
+            for stream in stream_list.streams:
+                text = g_sp.decode(stream.decoding_result())
+                results.append(text)
+
+            await websocket.send(results[0])
+            await websocket.close()
+
+    logging.info(f"Closed: {websocket.remote_address}")
+
+
+async def loop():
+    logging.info("started")
+    async with websockets.serve(echo, "", 6008):
+        await asyncio.Future()  # run forever
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+    setup_logger(f"{params.res_dir}/log-streaming-decode")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    params.device = device
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            # skip nonexistent checkpoints from "negative" epochs
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    global g_params, g_model, g_sp
+    g_params = params
+    g_model = model
+    g_sp = sp
+
+    asyncio.run(loop())
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20220506)
+    main()

diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
index 8ebfbb210..2064bd344 100755
--- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
@@ -233,6 +233,9 @@ class StreamList(object):
             for _ in range(batch_size)
         ]
 
+    def __getitem__(self, i) -> FeatureExtractionStream:
+        return self.streams[i]
+
     @property
     def done(self) -> bool:
         """Return True if all streams have reached end of utterance.
@@ -667,8 +670,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     params.device = device

diff --git a/egs/librispeech/ASR/transducer_emformer/train.py b/egs/librispeech/ASR/transducer_emformer/train.py
index 9798fe5e6..dae30f91b 100755
--- a/egs/librispeech/ASR/transducer_emformer/train.py
+++ b/egs/librispeech/ASR/transducer_emformer/train.py
@@ -378,6 +378,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
+        unk_id=params.unk_id,
         context_size=params.context_size,
     )
     return decoder
@@ -811,8 +812,9 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)